|
1 | 1 | package sourceflag |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "reflect" |
4 | 5 | "strings" |
5 | 6 | "testing" |
6 | 7 |
|
@@ -173,6 +174,138 @@ func Test_Source(t *testing.T) { |
173 | 174 | err := src(ctx, new(configWithFlag)) |
174 | 175 | test.Assert(t, err != nil && strings.Contains(err.Error(), "some values where not found, make sure flag values all points to config")) |
175 | 176 | }) |
| 177 | + |
| 178 | + t.Run("handling embedding results in cfg", func(t *testing.T) { |
| 179 | + type ( |
| 180 | + simple struct{ Bar string } |
| 181 | + embedded struct{ *simple } |
| 182 | + ) |
| 183 | + |
| 184 | + var cfgForFlags simple |
| 185 | + |
| 186 | + ctx := cli.NewCommandContext(test.Context(t)) |
| 187 | + cli.SetInitializedFlagsInContext(ctx, []cli.Flag{cli.NewBuiltinFlag("bar", "", &cfgForFlags.Bar, "")}, nil) |
| 188 | + |
| 189 | + flagsLocal, flagsPersistent := cli.GetInitializedFlagsFromContext(ctx) |
| 190 | + test.Require(t, len(flagsLocal) == 1 && len(flagsPersistent) == 0) |
| 191 | + test.Assert(t, flagsLocal[0].FromString("bar") == nil) |
| 192 | + |
| 193 | + var cfg embedded |
| 194 | + test.Require(t, Source(&embedded{simple: &cfgForFlags})(ctx, &cfg) == nil) |
| 195 | + test.Assert(t, cfg.simple != nil && cfg.Bar == "bar") |
| 196 | + }) |
| 197 | +} |
| 198 | + |
| 199 | +func Test_ensurePointerInitialized(t *testing.T) { |
| 200 | + for name, tc := range map[string]struct { |
| 201 | + setup func() reflect.Value |
| 202 | + expectValue func(t *testing.T, applyWO func() error, o, v reflect.Value) (test.TestingT, bool) |
| 203 | + }{ |
| 204 | + "non-nil pointer": { |
| 205 | + setup: func() reflect.Value { |
| 206 | + v := &struct{ Field int }{Field: 42} |
| 207 | + return reflect.ValueOf(&v).Elem() |
| 208 | + }, |
| 209 | + expectValue: func(t *testing.T, _ func() error, o, v reflect.Value) (test.TestingT, bool) { |
| 210 | + a := v.IsValid() && !v.IsNil() |
| 211 | + b := o.Addr().Pointer() == v.Addr().Pointer() |
| 212 | + return t, a && b |
| 213 | + }, |
| 214 | + }, |
| 215 | + "nil pointer": { |
| 216 | + setup: func() reflect.Value { |
| 217 | + var v *struct{ Field int } |
| 218 | + return reflect.ValueOf(&v).Elem() |
| 219 | + }, |
| 220 | + expectValue: func(t *testing.T, applyWO func() error, o, v reflect.Value) (test.TestingT, bool) { |
| 221 | + a := o.IsNil() |
| 222 | + b := v.IsValid() && !v.IsNil() |
| 223 | + err := applyWO() |
| 224 | + c := !o.IsNil() |
| 225 | + d := o.Pointer() == v.Pointer() |
| 226 | + return t, a && b && err == nil && c && d |
| 227 | + }, |
| 228 | + }, |
| 229 | + } { |
| 230 | + t.Run(name, func(t *testing.T) { |
| 231 | + original := tc.setup() |
| 232 | + result, applyWO := ensurePointerInitialized(original) |
| 233 | + test.Assert(tc.expectValue(t, applyWO, original, result)) |
| 234 | + }) |
| 235 | + } |
| 236 | +} |
| 237 | + |
| 238 | +func Test_reflectSetValue(t *testing.T) { |
| 239 | + for name, tc := range map[string]struct { |
| 240 | + setup func() (reflect.Value, reflect.Value) |
| 241 | + unsafe bool |
| 242 | + expectValue func(t *testing.T, dst, src reflect.Value) (test.TestingT, bool) |
| 243 | + expectErrorMessage string |
| 244 | + }{ |
| 245 | + "invalid dst": { |
| 246 | + setup: func() (reflect.Value, reflect.Value) { |
| 247 | + return reflect.Value{}, reflect.ValueOf(42) |
| 248 | + }, |
| 249 | + expectErrorMessage: "invalid value: dst.IsValid=false", |
| 250 | + }, |
| 251 | + "invalid src": { |
| 252 | + setup: func() (reflect.Value, reflect.Value) { |
| 253 | + var v int |
| 254 | + return reflect.ValueOf(&v).Elem(), reflect.Value{} |
| 255 | + }, |
| 256 | + expectErrorMessage: "invalid value: dst.IsValid=true src.IsValid=false", |
| 257 | + }, |
| 258 | + "type mismatch": { |
| 259 | + setup: func() (reflect.Value, reflect.Value) { |
| 260 | + var dst int |
| 261 | + src := "string" |
| 262 | + return reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src) |
| 263 | + }, |
| 264 | + expectErrorMessage: "type mismatch", |
| 265 | + }, |
| 266 | + "can set": { |
| 267 | + setup: func() (reflect.Value, reflect.Value) { |
| 268 | + var dst int |
| 269 | + src := 42 |
| 270 | + return reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src) |
| 271 | + }, |
| 272 | + expectValue: func(t *testing.T, dst, _ reflect.Value) (test.TestingT, bool) { |
| 273 | + return t, dst.Int() == 42 |
| 274 | + }, |
| 275 | + }, |
| 276 | + "can set with unsafe": { |
| 277 | + setup: func() (reflect.Value, reflect.Value) { |
| 278 | + s := struct{ field int }{} |
| 279 | + src := 42 |
| 280 | + return reflect.ValueOf(&s).Elem().FieldByName("field"), reflect.ValueOf(src) |
| 281 | + }, |
| 282 | + unsafe: true, |
| 283 | + expectValue: func(t *testing.T, dst, _ reflect.Value) (test.TestingT, bool) { |
| 284 | + return t, dst.Int() == 42 |
| 285 | + }, |
| 286 | + }, |
| 287 | + "not addressable": { |
| 288 | + setup: func() (reflect.Value, reflect.Value) { |
| 289 | + s := struct{ Field int }{} |
| 290 | + src := 42 |
| 291 | + field := reflect.ValueOf(s).FieldByName("Field") |
| 292 | + return field, reflect.ValueOf(src) |
| 293 | + }, |
| 294 | + expectErrorMessage: "dst is not addressable", |
| 295 | + }, |
| 296 | + } { |
| 297 | + t.Run(name, func(t *testing.T) { |
| 298 | + dst, src := tc.setup() |
| 299 | + |
| 300 | + err := reflectSetValue(dst, src) |
| 301 | + if tc.expectErrorMessage != "" { |
| 302 | + test.Require(t, err != nil && strings.Contains(err.Error(), tc.expectErrorMessage), err) |
| 303 | + } else { |
| 304 | + test.Require(t, err == nil, err) |
| 305 | + test.Assert(tc.expectValue(t, dst, src)) |
| 306 | + } |
| 307 | + }) |
| 308 | + } |
176 | 309 | } |
177 | 310 |
|
178 | 311 | func ptrTo[T any](v T) *T { return &v } |
0 commit comments