Skip to content

Commit 8c8d4c3

Browse files
authored
Merge pull request #682 from benluddy/textunmarshaler-bytestringtostring
Use TextUnmarshaler on byte strings with ByteStringToStringAllowed.
2 parents a89c3ce + 9bdebd2 commit 8c8d4c3

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

decode.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,7 +1570,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
15701570
return err
15711571
}
15721572
copied = copied || converted
1573-
return fillByteString(t, b, !copied, v, d.dm.byteStringToString, d.dm.binaryUnmarshaler)
1573+
return fillByteString(t, b, !copied, v, d.dm.byteStringToString, d.dm.binaryUnmarshaler, d.dm.textUnmarshaler)
15741574

15751575
case cborTypeTextString:
15761576
b, err := d.parseTextString()
@@ -1629,7 +1629,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
16291629
return nil
16301630
}
16311631
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
1632-
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler)
1632+
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler, d.dm.textUnmarshaler)
16331633
}
16341634
if bi.IsUint64() {
16351635
return fillPositiveInt(t, bi.Uint64(), v)
@@ -1652,7 +1652,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
16521652
return nil
16531653
}
16541654
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
1655-
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler)
1655+
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler, d.dm.textUnmarshaler)
16561656
}
16571657
if bi.IsInt64() {
16581658
return fillNegativeInt(t, bi.Int64(), v)
@@ -3180,7 +3180,7 @@ func fillFloat(t cborType, val float64, v reflect.Value) error {
31803180
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
31813181
}
31823182

3183-
func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode, bum BinaryUnmarshalerMode) error {
3183+
func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode, bum BinaryUnmarshalerMode, tum TextUnmarshalerMode) error {
31843184
if bum == BinaryUnmarshalerByteString && reflect.PointerTo(v.Type()).Implements(typeBinaryUnmarshaler) {
31853185
if v.CanAddr() {
31863186
v = v.Addr()
@@ -3193,9 +3193,26 @@ func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts B
31933193
}
31943194
return errors.New("cbor: cannot set new value for " + v.Type().String())
31953195
}
3196-
if bsts != ByteStringToStringForbidden && v.Kind() == reflect.String {
3197-
v.SetString(string(val))
3198-
return nil
3196+
if bsts != ByteStringToStringForbidden {
3197+
if tum == TextUnmarshalerTextString && reflect.PointerTo(v.Type()).Implements(typeTextUnmarshaler) {
3198+
if v.CanAddr() {
3199+
v = v.Addr()
3200+
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
3201+
// The contract of TextUnmarshaler forbids retaining the input
3202+
// bytes, so no copying is required even if val is shared.
3203+
if err := u.UnmarshalText(val); err != nil {
3204+
return fmt.Errorf("cbor: cannot unmarshal text for %s: %w", v.Type(), err)
3205+
}
3206+
return nil
3207+
}
3208+
}
3209+
return errors.New("cbor: cannot set new value for " + v.Type().String())
3210+
}
3211+
3212+
if v.Kind() == reflect.String {
3213+
v.SetString(string(val))
3214+
return nil
3215+
}
31993216
}
32003217
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
32013218
src := val

decode_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10779,6 +10779,15 @@ func TestTextUnmarshalerMode(t *testing.T) {
1077910779
in: []byte("\x65hello"), // "hello"
1078010780
want: testTextUnmarshaler("hello"),
1078110781
},
10782+
{
10783+
name: "UnmarshalText is called for byte string with TextUnmarshalerTextString and ByteStringToStringAllowed",
10784+
opts: DecOptions{
10785+
TextUnmarshaler: TextUnmarshalerTextString,
10786+
ByteStringToString: ByteStringToStringAllowed,
10787+
},
10788+
in: []byte("\x45hello"), // 'hello'
10789+
want: testTextUnmarshaler("UnmarshalText"),
10790+
},
1078210791
} {
1078310792
t.Run(tc.name, func(t *testing.T) {
1078410793
dm, err := tc.opts.DecMode()

0 commit comments

Comments
 (0)