Skip to content

Commit 9671363

Browse files
authored
encodeMap() error handling and TestEncoder_UnmarshallableTypes() (#674)
* encodeMap() error handling and TestEncoder_UnmarshallableTypes() * more TestEncoder_UnmarshallableTypes test cases
1 parent 944206a commit 9671363

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

encode.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
492492
if value := e.encodePtrAnchor(v, column); value != nil {
493493
return value, nil
494494
}
495-
return e.encodeMap(ctx, v, column), nil
495+
return e.encodeMap(ctx, v, column)
496496
default:
497497
return nil, fmt.Errorf("unknown value type %s", v.Type().String())
498498
}
@@ -684,7 +684,7 @@ func (e *Encoder) isTagAndMapNode(node ast.Node) bool {
684684
return ok && e.isMapNode(tn.Value)
685685
}
686686

687-
func (e *Encoder) encodeMap(ctx context.Context, value reflect.Value, column int) ast.Node {
687+
func (e *Encoder) encodeMap(ctx context.Context, value reflect.Value, column int) (ast.Node, error) {
688688
node := ast.Mapping(token.New("", "", e.pos(column)), e.isFlowStyle)
689689
keys := make([]interface{}, len(value.MapKeys()))
690690
for i, k := range value.MapKeys() {
@@ -698,7 +698,7 @@ func (e *Encoder) encodeMap(ctx context.Context, value reflect.Value, column int
698698
v := value.MapIndex(k)
699699
value, err := e.encodeValue(ctx, v, column)
700700
if err != nil {
701-
return nil
701+
return nil, err
702702
}
703703
if e.isMapNode(value) {
704704
value.AddColumn(e.indentNum)
@@ -724,7 +724,7 @@ func (e *Encoder) encodeMap(ctx context.Context, value reflect.Value, column int
724724
))
725725
e.setSmartAnchor(vRef, keyText)
726726
}
727-
return node
727+
return node, nil
728728
}
729729

730730
// IsZeroer is used to check whether an object is zero to determine

encode_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"strings"
1212
"testing"
1313
"time"
14+
"unsafe"
1415

1516
"github.com/goccy/go-yaml"
1617
"github.com/goccy/go-yaml/ast"
@@ -1345,6 +1346,81 @@ func TestEncoder_MultipleDocuments(t *testing.T) {
13451346
}
13461347
}
13471348

1349+
func TestEncoder_UnmarshallableTypes(t *testing.T) {
1350+
for _, test := range []struct {
1351+
desc string
1352+
input any
1353+
expectedErr string
1354+
}{
1355+
{
1356+
desc: "channel",
1357+
input: make(chan int),
1358+
expectedErr: "unknown value type chan int",
1359+
},
1360+
{
1361+
desc: "function",
1362+
input: func() {},
1363+
expectedErr: "unknown value type func()",
1364+
},
1365+
{
1366+
desc: "complex number",
1367+
input: complex(10, 11),
1368+
expectedErr: "unknown value type complex128",
1369+
},
1370+
{
1371+
desc: "unsafe pointer",
1372+
input: unsafe.Pointer(&struct{}{}),
1373+
expectedErr: "unknown value type unsafe.Pointer",
1374+
},
1375+
{
1376+
desc: "uintptr",
1377+
input: uintptr(0x1234),
1378+
expectedErr: "unknown value type uintptr",
1379+
},
1380+
{
1381+
desc: "map with channel",
1382+
input: map[string]any{"key": make(chan string)},
1383+
expectedErr: "unknown value type chan string",
1384+
},
1385+
{
1386+
desc: "nested map with func",
1387+
input: map[string]any{
1388+
"a": map[string]any{
1389+
"b": func(_ string) {},
1390+
},
1391+
},
1392+
expectedErr: "unknown value type func(string)",
1393+
},
1394+
{
1395+
desc: "slice with channel",
1396+
input: []any{make(chan bool)},
1397+
expectedErr: "unknown value type chan bool",
1398+
},
1399+
{
1400+
desc: "nested slice with complex number",
1401+
input: []any{[]any{complex(10, 11)}},
1402+
expectedErr: "unknown value type complex128",
1403+
},
1404+
{
1405+
desc: "struct with unsafe pointer",
1406+
input: struct {
1407+
Field unsafe.Pointer `yaml:"field"`
1408+
}{},
1409+
expectedErr: "unknown value type unsafe.Pointer",
1410+
},
1411+
} {
1412+
t.Run(test.desc, func(t *testing.T) {
1413+
var buf bytes.Buffer
1414+
err := yaml.NewEncoder(&buf).Encode(test.input)
1415+
if err == nil {
1416+
t.Errorf("expect error:\n%s\nbut got none\n", test.expectedErr)
1417+
} else if err.Error() != test.expectedErr {
1418+
t.Errorf("expect error:\n%s\nactual\n%s\n", test.expectedErr, err)
1419+
}
1420+
})
1421+
}
1422+
}
1423+
13481424
func ExampleMarshal_node() {
13491425
type T struct {
13501426
Text ast.Node `yaml:"text"`

0 commit comments

Comments
 (0)