Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ func (d *Decoder) Decode(v interface{}) error {
}

func (d *Decoder) unmarshal(pval *plistValue, v reflect.Value) error {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}

// check for empty interface v type
if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
val := reflect.ValueOf(d.valueInterface(pval))
Expand All @@ -101,13 +108,6 @@ func (d *Decoder) unmarshal(pval *plistValue, v reflect.Value) error {
return nil
}

if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}

unmarshalerType := reflect.TypeOf((*Unmarshaler)(nil)).Elem()

if v.CanInterface() && v.Type().Implements(unmarshalerType) {
Expand Down
61 changes: 61 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,64 @@ func TestXMLPlutilParity(t *testing.T) {
}
}
}

type testVal struct {
s string
b bool
}

func (v *testVal) UnmarshalPlist(f func(interface{}) error) (err error) {
var val interface{}
err = f(&val)
if err != nil {
return err
}
switch value := val.(type) {
case string:
v.s = value
case bool:
v.b = value
}
return nil
}

type nestedType struct {
Val *testVal `plist:"val"`
Val2 *testVal `plist:"val2"`
}

// TestDecodeCustomType tests decoding a type that decodes into multiple types
// based on the underlying plist type
func TestDecodeCustomType(t *testing.T) {
p := `<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "https://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>val</key>
<string>val</string>
<key>val2</key>
<true></true>
</dict>
</plist>`
r := bytes.NewBuffer([]byte(p))
decoder := NewXMLDecoder(r)
typ := new(nestedType)
err := decoder.Decode(typ)
if err != nil {
t.Fatalf("could not read profile: %v", err)
}

if typ.Val == nil {
t.Fatal("unexpected nil for typ.Val")
}
if have, want := typ.Val.s, "val"; have != want {
t.Errorf("typ.Val: have %v, want %v", have, want)
}

if typ.Val2 == nil {
t.Fatal("unexpected nil for typ.Val2")
}
if have, want := typ.Val2.b, true; have != want {
t.Errorf("typ.Val2: have %v, want %v", have, want)
}
}
Loading