Skip to content

Commit 7b75c3c

Browse files
author
Divjot Arora
authored
GODRIVER-1590 Use MarshalValue in transformValue helper (#386)
1 parent 5ae4e01 commit 7b75c3c

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

mongo/errors.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ var ErrClientDisconnected = errors.New("client is disconnected")
2626
// ErrNilDocument is returned when a nil document is passed to a CRUD method.
2727
var ErrNilDocument = errors.New("document is nil")
2828

29+
// ErrNilValue is returned when a nil value is passed to a CRUD method.
30+
var ErrNilValue = errors.New("value is nil")
31+
2932
// ErrEmptySlice is returned when an empty slice is passed to a CRUD method that requires a non-empty slice.
3033
var ErrEmptySlice = errors.New("must provide at least one element in input slice")
3134

mongo/mongo.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -430,17 +430,20 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, chec
430430
}
431431

432432
func transformValue(registry *bsoncodec.Registry, val interface{}) (bsoncore.Value, error) {
433-
switch conv := val.(type) {
434-
case string:
435-
return bsoncore.Value{Type: bsontype.String, Data: bsoncore.AppendString(nil, conv)}, nil
436-
default:
437-
doc, err := transformBsoncoreDocument(registry, val)
438-
if err != nil {
439-
return bsoncore.Value{}, err
440-
}
433+
if registry == nil {
434+
registry = bson.DefaultRegistry
435+
}
436+
if val == nil {
437+
return bsoncore.Value{}, ErrNilValue
438+
}
441439

442-
return bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: doc}, nil
440+
buf := make([]byte, 0, 256)
441+
bsonType, bsonValue, err := bson.MarshalValueAppendWithRegistry(registry, buf[:0], val)
442+
if err != nil {
443+
return bsoncore.Value{}, MarshalError{Value: val, Err: err}
443444
}
445+
446+
return bsoncore.Value{Type: bsonType, Data: bsonValue}, nil
444447
}
445448

446449
// Build the aggregation pipeline for the CountDocument command.

mongo/mongo_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,38 @@ func TestMongoHelpers(t *testing.T) {
274274
})
275275
}
276276
})
277+
t.Run("transform value", func(t *testing.T) {
278+
valueMarshaler := bvMarsh{
279+
t: bsontype.String,
280+
data: bsoncore.AppendString(nil, "foo"),
281+
}
282+
doc := bson.D{{"x", 1}}
283+
docBytes, _ := bson.Marshal(doc)
284+
285+
testCases := []struct {
286+
name string
287+
value interface{}
288+
err error
289+
bsonType bsontype.Type
290+
bsonValue []byte
291+
}{
292+
{"nil document", nil, ErrNilValue, 0, nil},
293+
{"value marshaler", valueMarshaler, nil, valueMarshaler.t, valueMarshaler.data},
294+
{"document", doc, nil, bsontype.EmbeddedDocument, docBytes},
295+
}
296+
for _, tc := range testCases {
297+
t.Run(tc.name, func(t *testing.T) {
298+
res, err := transformValue(nil, tc.value)
299+
if tc.err != nil {
300+
assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
301+
return
302+
}
303+
304+
assert.Equal(t, tc.bsonType, res.Type, "expected BSON type %s, got %s", tc.bsonType, res.Type)
305+
assert.Equal(t, tc.bsonValue, res.Data, "expected BSON data %v, got %v", tc.bsonValue, res.Data)
306+
})
307+
}
308+
})
277309
}
278310

279311
var _ bson.Marshaler = bMarsh{}

0 commit comments

Comments
 (0)