Skip to content

Commit a80ae1d

Browse files
authored
Add support for keys implement TextMarshaler & TextUnmarshaler from MapCodec (#946)
1 parent 0cdb185 commit a80ae1d

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

bson/bson_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"bytes"
1111
"fmt"
1212
"reflect"
13+
"strconv"
1314
"strings"
1415
"testing"
1516
"time"
@@ -140,6 +141,29 @@ func (kb *keyBool) UnmarshalKey(key string) error {
140141
return nil
141142
}
142143

144+
type keyStruct struct {
145+
val int64
146+
}
147+
148+
func (k keyStruct) MarshalText() (text []byte, err error) {
149+
str := strconv.FormatInt(k.val, 10)
150+
151+
return []byte(str), nil
152+
}
153+
154+
func (k *keyStruct) UnmarshalText(text []byte) error {
155+
val, err := strconv.ParseInt(string(text), 10, 64)
156+
if err != nil {
157+
return err
158+
}
159+
160+
*k = keyStruct{
161+
val: val,
162+
}
163+
164+
return nil
165+
}
166+
143167
func TestMapCodec(t *testing.T) {
144168
t.Run("EncodeKeysWithStringer", func(t *testing.T) {
145169
strstr := stringerString("foo")
@@ -163,6 +187,7 @@ func TestMapCodec(t *testing.T) {
163187
})
164188
}
165189
})
190+
166191
t.Run("keys implements keyMarshaler and keyUnmarshaler", func(t *testing.T) {
167192
mapObj := map[keyBool]int{keyBool(true): 1}
168193

@@ -179,6 +204,25 @@ func TestMapCodec(t *testing.T) {
179204
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
180205

181206
})
207+
208+
t.Run("keys implements encoding.TextMarshaler and encoding.TextUnmarshaler", func(t *testing.T) {
209+
mapObj := map[keyStruct]int{
210+
{val: 10}: 100,
211+
}
212+
213+
doc, err := Marshal(mapObj)
214+
assert.Nil(t, err, "Marshal error: %v", err)
215+
idx, want := bsoncore.AppendDocumentStart(nil)
216+
want = bsoncore.AppendInt32Element(want, "10", 100)
217+
want, _ = bsoncore.AppendDocumentEnd(want, idx)
218+
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
219+
220+
var got map[keyStruct]int
221+
err = Unmarshal(doc, &got)
222+
assert.Nil(t, err, "Unmarshal error: %v", err)
223+
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
224+
225+
})
182226
}
183227

184228
func TestExtJSONEscapeKey(t *testing.T) {

bson/bsoncodec/map_codec.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package bsoncodec
88

99
import (
10+
"encoding"
1011
"fmt"
1112
"reflect"
1213
"strconv"
@@ -230,6 +231,19 @@ func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) {
230231
}
231232
return "", err
232233
}
234+
// keys implement encoding.TextMarshaler are marshaled.
235+
if km, ok := val.Interface().(encoding.TextMarshaler); ok {
236+
if val.Kind() == reflect.Ptr && val.IsNil() {
237+
return "", nil
238+
}
239+
240+
buf, err := km.MarshalText()
241+
if err != nil {
242+
return "", err
243+
}
244+
245+
return string(buf), nil
246+
}
233247

234248
switch val.Kind() {
235249
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -241,6 +255,7 @@ func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) {
241255
}
242256

243257
var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
258+
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
244259

245260
func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
246261
keyVal := reflect.ValueOf(key)
@@ -252,6 +267,12 @@ func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value,
252267
v := keyVal.Interface().(KeyUnmarshaler)
253268
err = v.UnmarshalKey(key)
254269
keyVal = keyVal.Elem()
270+
// Try to decode encoding.TextUnmarshalers.
271+
case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
272+
keyVal = reflect.New(keyType)
273+
v := keyVal.Interface().(encoding.TextUnmarshaler)
274+
err = v.UnmarshalText([]byte(key))
275+
keyVal = keyVal.Elem()
255276
// Otherwise, go to type specific behavior
256277
default:
257278
switch keyType.Kind() {

0 commit comments

Comments
 (0)