Skip to content

Commit 06a8049

Browse files
committed
fix(cfg/flag): fix call Elem on non-initialized Pointers
1 parent 884f4e3 commit 06a8049

File tree

3 files changed

+204
-5
lines changed

3 files changed

+204
-5
lines changed

cfg/source/flag/flag.go

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"reflect"
8+
"unsafe"
89

910
"github.com/krostar/cli"
1011
clicfg "github.com/krostar/cli/cfg"
@@ -101,7 +102,7 @@ func Source[T any](flagDest *T) clicfg.SourceFunc[T] {
101102
//
102103
// At the end of successful processing, the pointers map should be empty, indicating that all
103104
// flag values were successfully transferred to the config struct.
104-
func recursivelyWalkThroughReflectValue(pointers map[uintptr]struct{}, v1, v2 reflect.Value) error {
105+
func recursivelyWalkThroughReflectValue(pointers map[uintptr]struct{}, v1, v2 reflect.Value, applyWOs ...func() error) error {
105106
if len(pointers) == 0 {
106107
return nil
107108
}
@@ -114,18 +115,27 @@ func recursivelyWalkThroughReflectValue(pointers map[uintptr]struct{}, v1, v2 re
114115

115116
v1ptr := uintptr(v1.Addr().UnsafePointer())
116117
if _, ok := pointers[v1ptr]; !ok {
117-
return recursivelyWalkThroughReflectValue(pointers, v1.Elem(), v2.Elem())
118+
var applyWO func() error
119+
120+
v2, applyWO = ensurePointerInitialized(v2)
121+
applyWOs = append(applyWOs, applyWO)
122+
123+
return recursivelyWalkThroughReflectValue(pointers, v1.Elem(), v2.Elem(), applyWOs...)
118124
}
119125

120126
v2.Set(v1)
121127
delete(pointers, v1ptr)
122128

129+
if err := applyAllWritingOperations(applyWOs); err != nil {
130+
return fmt.Errorf("unable to apply all writing operations: %v", err)
131+
}
132+
123133
return nil
124134

125135
case reflect.Struct:
126136
var errs []error
127137
for i := range v1.NumField() {
128-
errs = append(errs, recursivelyWalkThroughReflectValue(pointers, v1.Field(i), v2.Field(i)))
138+
errs = append(errs, recursivelyWalkThroughReflectValue(pointers, v1.Field(i), v2.Field(i), applyWOs...))
129139
}
130140

131141
return errors.Join(errs...)
@@ -143,6 +153,57 @@ func recursivelyWalkThroughReflectValue(pointers map[uintptr]struct{}, v1, v2 re
143153
v2.Set(v1)
144154
delete(pointers, v1ptr)
145155

156+
if err := applyAllWritingOperations(applyWOs); err != nil {
157+
return fmt.Errorf("unable to apply all writing operations: %v", err)
158+
}
159+
146160
return nil
147161
}
148162
}
163+
164+
func ensurePointerInitialized(v reflect.Value) (reflect.Value, func() error) {
165+
if !v.IsNil() {
166+
return v, func() error { return nil }
167+
}
168+
169+
oldV := v
170+
newV := reflect.New(v.Type().Elem())
171+
172+
return newV, func() error {
173+
return reflectSetValue(oldV, newV)
174+
}
175+
}
176+
177+
func reflectSetValue(dst, src reflect.Value) error {
178+
if !dst.IsValid() || !src.IsValid() {
179+
return fmt.Errorf("invalid value: dst.IsValid=%v src.IsValid=%v", dst.IsValid(), src.IsValid())
180+
}
181+
182+
if dst.Type() != src.Type() {
183+
return fmt.Errorf("type mismatch: %v != %v", dst.Type(), src.Type())
184+
}
185+
186+
if dst.CanSet() {
187+
dst.Set(src)
188+
return nil
189+
}
190+
191+
if !dst.CanAddr() {
192+
return fmt.Errorf("dst is not addressable; type=%v", dst.Type())
193+
}
194+
195+
dst = reflect.NewAt(dst.Type(), unsafe.Pointer(dst.UnsafeAddr())).Elem()
196+
dst.Set(src)
197+
198+
return nil
199+
}
200+
201+
func applyAllWritingOperations(wo []func() error) error {
202+
for i := range wo {
203+
if err := wo[len(wo)-1-i](); err != nil {
204+
return err
205+
}
206+
}
207+
208+
return nil
209+
}

cfg/source/flag/flag_test.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package sourceflag
22

33
import (
4+
"reflect"
45
"strings"
56
"testing"
67

@@ -173,6 +174,138 @@ func Test_Source(t *testing.T) {
173174
err := src(ctx, new(configWithFlag))
174175
test.Assert(t, err != nil && strings.Contains(err.Error(), "some values where not found, make sure flag values all points to config"))
175176
})
177+
178+
t.Run("handling embedding results in cfg", func(t *testing.T) {
179+
type (
180+
simple struct{ Bar string }
181+
embedded struct{ *simple }
182+
)
183+
184+
var cfgForFlags simple
185+
186+
ctx := cli.NewCommandContext(test.Context(t))
187+
cli.SetInitializedFlagsInContext(ctx, []cli.Flag{cli.NewBuiltinFlag("bar", "", &cfgForFlags.Bar, "")}, nil)
188+
189+
flagsLocal, flagsPersistent := cli.GetInitializedFlagsFromContext(ctx)
190+
test.Require(t, len(flagsLocal) == 1 && len(flagsPersistent) == 0)
191+
test.Assert(t, flagsLocal[0].FromString("bar") == nil)
192+
193+
var cfg embedded
194+
test.Require(t, Source(&embedded{simple: &cfgForFlags})(ctx, &cfg) == nil)
195+
test.Assert(t, cfg.simple != nil && cfg.Bar == "bar")
196+
})
197+
}
198+
199+
func Test_ensurePointerInitialized(t *testing.T) {
200+
for name, tc := range map[string]struct {
201+
setup func() reflect.Value
202+
expectValue func(t *testing.T, applyWO func() error, o, v reflect.Value) (test.TestingT, bool)
203+
}{
204+
"non-nil pointer": {
205+
setup: func() reflect.Value {
206+
v := &struct{ Field int }{Field: 42}
207+
return reflect.ValueOf(&v).Elem()
208+
},
209+
expectValue: func(t *testing.T, _ func() error, o, v reflect.Value) (test.TestingT, bool) {
210+
a := v.IsValid() && !v.IsNil()
211+
b := o.Addr().Pointer() == v.Addr().Pointer()
212+
return t, a && b
213+
},
214+
},
215+
"nil pointer": {
216+
setup: func() reflect.Value {
217+
var v *struct{ Field int }
218+
return reflect.ValueOf(&v).Elem()
219+
},
220+
expectValue: func(t *testing.T, applyWO func() error, o, v reflect.Value) (test.TestingT, bool) {
221+
a := o.IsNil()
222+
b := v.IsValid() && !v.IsNil()
223+
err := applyWO()
224+
c := !o.IsNil()
225+
d := o.Pointer() == v.Pointer()
226+
return t, a && b && err == nil && c && d
227+
},
228+
},
229+
} {
230+
t.Run(name, func(t *testing.T) {
231+
original := tc.setup()
232+
result, applyWO := ensurePointerInitialized(original)
233+
test.Assert(tc.expectValue(t, applyWO, original, result))
234+
})
235+
}
236+
}
237+
238+
func Test_reflectSetValue(t *testing.T) {
239+
for name, tc := range map[string]struct {
240+
setup func() (reflect.Value, reflect.Value)
241+
unsafe bool
242+
expectValue func(t *testing.T, dst, src reflect.Value) (test.TestingT, bool)
243+
expectErrorMessage string
244+
}{
245+
"invalid dst": {
246+
setup: func() (reflect.Value, reflect.Value) {
247+
return reflect.Value{}, reflect.ValueOf(42)
248+
},
249+
expectErrorMessage: "invalid value: dst.IsValid=false",
250+
},
251+
"invalid src": {
252+
setup: func() (reflect.Value, reflect.Value) {
253+
var v int
254+
return reflect.ValueOf(&v).Elem(), reflect.Value{}
255+
},
256+
expectErrorMessage: "invalid value: dst.IsValid=true src.IsValid=false",
257+
},
258+
"type mismatch": {
259+
setup: func() (reflect.Value, reflect.Value) {
260+
var dst int
261+
src := "string"
262+
return reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src)
263+
},
264+
expectErrorMessage: "type mismatch",
265+
},
266+
"can set": {
267+
setup: func() (reflect.Value, reflect.Value) {
268+
var dst int
269+
src := 42
270+
return reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src)
271+
},
272+
expectValue: func(t *testing.T, dst, _ reflect.Value) (test.TestingT, bool) {
273+
return t, dst.Int() == 42
274+
},
275+
},
276+
"can set with unsafe": {
277+
setup: func() (reflect.Value, reflect.Value) {
278+
s := struct{ field int }{}
279+
src := 42
280+
return reflect.ValueOf(&s).Elem().FieldByName("field"), reflect.ValueOf(src)
281+
},
282+
unsafe: true,
283+
expectValue: func(t *testing.T, dst, _ reflect.Value) (test.TestingT, bool) {
284+
return t, dst.Int() == 42
285+
},
286+
},
287+
"not addressable": {
288+
setup: func() (reflect.Value, reflect.Value) {
289+
s := struct{ Field int }{}
290+
src := 42
291+
field := reflect.ValueOf(s).FieldByName("Field")
292+
return field, reflect.ValueOf(src)
293+
},
294+
expectErrorMessage: "dst is not addressable",
295+
},
296+
} {
297+
t.Run(name, func(t *testing.T) {
298+
dst, src := tc.setup()
299+
300+
err := reflectSetValue(dst, src)
301+
if tc.expectErrorMessage != "" {
302+
test.Require(t, err != nil && strings.Contains(err.Error(), tc.expectErrorMessage), err)
303+
} else {
304+
test.Require(t, err == nil, err)
305+
test.Assert(tc.expectValue(t, dst, src))
306+
}
307+
})
308+
}
176309
}
177310

178311
func ptrTo[T any](v T) *T { return &v }

nix/repo/data.nix

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@
1515
linters = ["depguard"];
1616
text = "import 'reflect' is not allowed";
1717
}
18+
{
19+
path = "cfg/source/flag/flag.go";
20+
linters = ["gosec"];
21+
text = "Use of unsafe calls should be audited";
22+
}
1823
{
1924
path = "double/internal/generator/";
20-
linters = ["errcheck" "gosec" "goconst"];
21-
text = "Error return value of `file.Close` is not checked|Potential file inclusion via variable|make it a constant";
25+
linters = ["gosec" "goconst"];
26+
text = "Potential file inclusion via variable|make it a constant";
2227
}
2328
];
2429
};

0 commit comments

Comments
 (0)