diff --git a/decode.go b/decode.go index a365a2e..f16f4fb 100644 --- a/decode.go +++ b/decode.go @@ -91,6 +91,13 @@ func (d *Decoder) Decode(v interface{}) error { } func (d *Decoder) unmarshal(pval *plistValue, v reflect.Value) error { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + // check for empty interface v type if v.Kind() == reflect.Interface && v.NumMethod() == 0 { val := reflect.ValueOf(d.valueInterface(pval)) @@ -101,13 +108,6 @@ func (d *Decoder) unmarshal(pval *plistValue, v reflect.Value) error { return nil } - if v.Kind() == reflect.Ptr { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - v = v.Elem() - } - unmarshalerType := reflect.TypeOf((*Unmarshaler)(nil)).Elem() if v.CanInterface() && v.Type().Implements(unmarshalerType) { diff --git a/decode_test.go b/decode_test.go index e0b9d75..59c7faa 100644 --- a/decode_test.go +++ b/decode_test.go @@ -555,3 +555,64 @@ func TestXMLPlutilParity(t *testing.T) { } } } + +type testVal struct { + s string + b bool +} + +func (v *testVal) UnmarshalPlist(f func(interface{}) error) (err error) { + var val interface{} + err = f(&val) + if err != nil { + return err + } + switch value := val.(type) { + case string: + v.s = value + case bool: + v.b = value + } + return nil +} + +type nestedType struct { + Val *testVal `plist:"val"` + Val2 *testVal `plist:"val2"` +} + +// TestDecodeCustomType tests decoding a type that decodes into multiple types +// based on the underlying plist type +func TestDecodeCustomType(t *testing.T) { + p := ` + + + + val + val + val2 + + + ` + r := bytes.NewBuffer([]byte(p)) + decoder := NewXMLDecoder(r) + typ := new(nestedType) + err := decoder.Decode(typ) + if err != nil { + t.Fatalf("could not read profile: %v", err) + } + + if typ.Val == nil { + t.Fatal("unexpected nil for typ.Val") + } + if have, want := typ.Val.s, "val"; have != want { + t.Errorf("typ.Val: have %v, want %v", have, want) + } + + if typ.Val2 == nil { + t.Fatal("unexpected nil for typ.Val2") + } + if have, want := typ.Val2.b, true; have != want { + t.Errorf("typ.Val2: have %v, want %v", have, want) + } +}