Skip to content

Commit 6b8a0c2

Browse files
committed
Plan section access improvement
1 parent 3438327 commit 6b8a0c2

File tree

2 files changed

+90
-30
lines changed

2 files changed

+90
-30
lines changed

internals/plan/extensions_test.go

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,40 @@ var extensionTests = []struct {
266266
},
267267
errorString: "cannot validate plan section .* cannot find .* as required by .*",
268268
},
269-
// Index 8: Load file layers
269+
// Index 8: Load nothing
270+
{
271+
extensions: []extension{
272+
extension{
273+
field: "x-field",
274+
ext: &xExtension{},
275+
},
276+
extension{
277+
field: "y-field",
278+
ext: &yExtension{},
279+
},
280+
},
281+
files: []*planInput{
282+
&planInput{
283+
order: 1,
284+
label: "layer-x",
285+
yaml: `
286+
summary: x
287+
description: desc-x
288+
`,
289+
},
290+
&planInput{
291+
order: 2,
292+
label: "layer-y",
293+
yaml: `
294+
summary: y
295+
description: desc-y
296+
`,
297+
},
298+
},
299+
result: &planResult{},
300+
resultYaml: string("{}\n"),
301+
},
302+
// Index 9: Load file layers
270303
{
271304
extensions: []extension{
272305
extension{
@@ -369,12 +402,26 @@ func (ps *planSuite) TestPlanExtensions(c *C) {
369402
} else {
370403
if planTest.result != nil {
371404
// Verify "x-field" data.
372-
x := p.Section(xField).(*xSection)
373-
c.Assert(x.Entries, DeepEquals, planTest.result.x.Entries)
405+
var x *xSection
406+
err = p.Section(xField, &x)
407+
c.Assert(err, IsNil)
408+
if planTest.result.x != nil {
409+
c.Assert(x.Entries, DeepEquals, planTest.result.x.Entries)
410+
} else {
411+
// No section must result in nil.
412+
c.Assert(x, IsNil)
413+
}
374414

375415
// Verify "y-field" data.
376-
y := p.Section(yField).(*ySection)
377-
c.Assert(y.Entries, DeepEquals, planTest.result.y.Entries)
416+
var y *ySection
417+
err = p.Section(yField, &y)
418+
c.Assert(err, IsNil)
419+
if planTest.result.y != nil {
420+
c.Assert(y.Entries, DeepEquals, planTest.result.y.Entries)
421+
} else {
422+
// No section must result in nil.
423+
c.Assert(x, IsNil)
424+
}
378425
}
379426

380427
// Verify combined plan YAML.
@@ -430,16 +477,22 @@ func (x xExtension) CombineSections(sections ...plan.LayerSection) (plan.LayerSe
430477
}
431478

432479
func (x xExtension) ValidatePlan(p *plan.Plan) error {
433-
planXSection := p.Section(xField)
434-
if planXSection != nil {
435-
planYSection := p.Section(yField)
436-
if planYSection == nil {
480+
var xs *xSection
481+
err := p.Section(xField, &xs)
482+
if err != nil {
483+
return err
484+
}
485+
if xs != nil {
486+
var ys *ySection
487+
err = p.Section(yField, &ys)
488+
if err != nil {
489+
return err
490+
}
491+
if ys == nil {
437492
return fmt.Errorf("cannot validate %v field without %v field", xField, yField)
438493
}
439-
ys := planYSection.(*ySection)
440494

441495
// Make sure every Y field in X refer to an existing Y entry.
442-
xs := planXSection.(*xSection)
443496
for xEntryField, xEntryValue := range xs.Entries {
444497
for _, yReference := range xEntryValue.Y {
445498
found := false

internals/plan/plan.go

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2021 Canonical Ltd
1+
// Copyright (c) 2024 Canonical Ltd
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -90,10 +90,28 @@ type Plan struct {
9090
Sections map[string]LayerSection `yaml:",inline,omitempty"`
9191
}
9292

93-
// Section retrieves a section from the plan. Returns nil if
94-
// the field does not exist.
95-
func (p *Plan) Section(field string) LayerSection {
96-
return p.Sections[field]
93+
// Section retrieves a section from the plan, if it exists.
94+
func (p *Plan) Section(field string, out any) error {
95+
if _, found := layerExtensions[field]; !found {
96+
return fmt.Errorf("cannot find registered extension for field %q", field)
97+
}
98+
99+
outVal := reflect.ValueOf(out)
100+
if outVal.Kind() != reflect.Ptr || outVal.IsNil() {
101+
return fmt.Errorf("cannot read non pointer to section interface type %q", outVal.Kind())
102+
}
103+
104+
section, exists := p.Sections[field]
105+
if exists {
106+
sectionVal := reflect.ValueOf(section)
107+
sectionType := sectionVal.Type()
108+
outValPtrType := outVal.Elem().Type()
109+
if !sectionType.AssignableTo(outValPtrType) {
110+
return fmt.Errorf("cannot assign value of type %s to out argument of type %s", sectionType, outValPtrType)
111+
}
112+
outVal.Elem().Set(sectionVal)
113+
}
114+
return nil
97115
}
98116

99117
type Layer struct {
@@ -108,17 +126,6 @@ type Layer struct {
108126
Sections map[string]LayerSection `yaml:",inline,omitempty"`
109127
}
110128

111-
// addSection adds a new section to the layer.
112-
func (layer *Layer) addSection(field string, section LayerSection) {
113-
layer.Sections[field] = section
114-
}
115-
116-
// Section retrieves a layer section from a layer. Returns nil if
117-
// the field does not exist.
118-
func (layer *Layer) Section(field string) LayerSection {
119-
return layer.Sections[field]
120-
}
121-
122129
type Service struct {
123130
// Basic details
124131
Name string `yaml:"-"`
@@ -708,7 +715,7 @@ func CombineLayers(layers ...*Layer) (*Layer, error) {
708715
for field, extension := range layerExtensions {
709716
var sections []LayerSection
710717
for _, layer := range layers {
711-
if section := layer.Section(field); section != nil {
718+
if section := layer.Sections[field]; section != nil {
712719
sections = append(sections, section)
713720
}
714721
}
@@ -723,7 +730,7 @@ func CombineLayers(layers ...*Layer) (*Layer, error) {
723730
}
724731
// We support the ability for a valid combine to result in an omitted section.
725732
if combinedSection != nil {
726-
combined.addSection(field, combinedSection)
733+
combined.Sections[field] = combinedSection
727734
}
728735
}
729736
}
@@ -1182,7 +1189,7 @@ func ParseLayer(order int, label string, data []byte) (*Layer, error) {
11821189
}
11831190
}
11841191
if extendedSection != nil {
1185-
layer.addSection(field, extendedSection)
1192+
layer.Sections[field] = extendedSection
11861193
}
11871194
} else {
11881195
// At the top level we do not ignore keys we do not understand.

0 commit comments

Comments
 (0)