Skip to content

Commit 3d3d895

Browse files
author
Divjot Arora
authored
GODRIVER-1416 Add bson.MarshalValue (#241)
1 parent d88f225 commit 3d3d895

File tree

4 files changed

+278
-22
lines changed

4 files changed

+278
-22
lines changed

bson/bsonrw/value_writer.go

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter {
5656
return vw
5757
}
5858

59+
// GetAtModeElement retrieves a ValueWriterFlusher from the pool and resets it to use w as the destination.
60+
func (bvwp *BSONValueWriterPool) GetAtModeElement(w io.Writer) ValueWriterFlusher {
61+
vw := bvwp.Get(w).(*valueWriter)
62+
vw.push(mElement)
63+
return vw
64+
}
65+
5966
// Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing
6067
// happens and ok will be false.
6168
func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
@@ -512,17 +519,8 @@ func (vw *valueWriter) WriteDocumentEnd() error {
512519
}
513520

514521
if vw.stack[vw.frame].mode == mTopLevel {
515-
if vw.w != nil {
516-
if sw, ok := vw.w.(*SliceWriter); ok {
517-
*sw = vw.buf
518-
} else {
519-
_, err = vw.w.Write(vw.buf)
520-
if err != nil {
521-
return err
522-
}
523-
// reset buffer
524-
vw.buf = vw.buf[:0]
525-
}
522+
if err = vw.Flush(); err != nil {
523+
return err
526524
}
527525
}
528526

@@ -537,6 +535,23 @@ func (vw *valueWriter) WriteDocumentEnd() error {
537535
return nil
538536
}
539537

538+
func (vw *valueWriter) Flush() error {
539+
if vw.w == nil {
540+
return nil
541+
}
542+
543+
if sw, ok := vw.w.(*SliceWriter); ok {
544+
*sw = vw.buf
545+
return nil
546+
}
547+
if _, err := vw.w.Write(vw.buf); err != nil {
548+
return err
549+
}
550+
// reset buffer
551+
vw.buf = vw.buf[:0]
552+
return nil
553+
}
554+
540555
func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
541556
if vw.stack[vw.frame].mode != mArray {
542557
return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})

bson/bsonrw/writer.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ type ValueWriter interface {
5555
WriteUndefined() error
5656
}
5757

58+
// ValueWriterFlusher is a superset of ValueWriter that exposes functionality to flush to the underlying buffer.
59+
type ValueWriterFlusher interface {
60+
ValueWriter
61+
Flush() error
62+
}
63+
5864
// BytesWriter is the interface used to write BSON bytes to a ValueWriter.
5965
// This interface is meant to be a superset of ValueWriter, so that types that
6066
// implement ValueWriter may also implement this interface.

bson/marshal.go

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ type ValueMarshaler interface {
3232
MarshalBSONValue() (bsontype.Type, []byte, error)
3333
}
3434

35-
// Marshal returns the BSON encoding of val.
35+
// Marshal returns the BSON encoding of val as a BSON document. If val is not a type that can be transformed into a
36+
// document, MarshalValue should be used instead.
3637
//
3738
// Marshal will use the default registry created by NewRegistry to recursively
3839
// marshal val into a []byte. Marshal will inspect struct tags and alter the
@@ -41,34 +42,37 @@ func Marshal(val interface{}) ([]byte, error) {
4142
return MarshalWithRegistry(DefaultRegistry, val)
4243
}
4344

44-
// MarshalAppend will append the BSON encoding of val to dst. If dst is not
45-
// large enough to hold the BSON encoding of val, dst will be grown.
45+
// MarshalAppend will encode val as a BSON document and append the bytes to dst. If dst is not large enough to hold the
46+
// bytes, it will be grown. If val is not a type that can be transformed into a document, MarshalValueAppend should be
47+
// used instead.
4648
func MarshalAppend(dst []byte, val interface{}) ([]byte, error) {
4749
return MarshalAppendWithRegistry(DefaultRegistry, dst, val)
4850
}
4951

50-
// MarshalWithRegistry returns the BSON encoding of val using Registry r.
52+
// MarshalWithRegistry returns the BSON encoding of val as a BSON document. If val is not a type that can be transformed
53+
// into a document, MarshalValueWithRegistry should be used instead.
5154
func MarshalWithRegistry(r *bsoncodec.Registry, val interface{}) ([]byte, error) {
5255
dst := make([]byte, 0, 256) // TODO: make the default cap a constant
5356
return MarshalAppendWithRegistry(r, dst, val)
5457
}
5558

56-
// MarshalWithContext returns the BSON encoding of val using EncodeContext ec.
59+
// MarshalWithContext returns the BSON encoding of val as a BSON document using EncodeContext ec. If val is not a type
60+
// that can be transformed into a document, MarshalValueWithContext should be used instead.
5761
func MarshalWithContext(ec bsoncodec.EncodeContext, val interface{}) ([]byte, error) {
5862
dst := make([]byte, 0, 256) // TODO: make the default cap a constant
5963
return MarshalAppendWithContext(ec, dst, val)
6064
}
6165

62-
// MarshalAppendWithRegistry will append the BSON encoding of val to dst using
63-
// Registry r. If dst is not large enough to hold the BSON encoding of val, dst
64-
// will be grown.
66+
// MarshalAppendWithRegistry will encode val as a BSON document using Registry r and append the bytes to dst. If dst is
67+
// not large enough to hold the bytes, it will be grown. If val is not a type that can be transformed into a document,
68+
// MarshalValueAppendWithRegistry should be used instead.
6569
func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) ([]byte, error) {
6670
return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val)
6771
}
6872

69-
// MarshalAppendWithContext will append the BSON encoding of val to dst using
70-
// EncodeContext ec. If dst is not large enough to hold the BSON encoding of val, dst
71-
// will be grown.
73+
// MarshalAppendWithContext will encode val as a BSON document using Registry r and EncodeContext ec and append the
74+
// bytes to dst. If dst is not large enough to hold the bytes, it will be grown. If val is not a type that can be
75+
// transformed into a document, MarshalValueAppendWithContext should be used instead.
7276
func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) {
7377
sw := new(bsonrw.SliceWriter)
7478
*sw = dst
@@ -95,6 +99,69 @@ func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interf
9599
return *sw, nil
96100
}
97101

102+
// MarshalValue returns the BSON encoding of val.
103+
//
104+
// MarshalValue will use bson.DefaultRegistry to transform val into a BSON value. If val is a struct, this function will
105+
// inspect struct tags and alter the marshalling process accordingly.
106+
func MarshalValue(val interface{}) (bsontype.Type, []byte, error) {
107+
return MarshalValueWithRegistry(DefaultRegistry, val)
108+
}
109+
110+
// MarshalValueAppend will append the BSON encoding of val to dst. If dst is not large enough to hold the BSON encoding
111+
// of val, dst will be grown.
112+
func MarshalValueAppend(dst []byte, val interface{}) (bsontype.Type, []byte, error) {
113+
return MarshalValueAppendWithRegistry(DefaultRegistry, dst, val)
114+
}
115+
116+
// MarshalValueWithRegistry returns the BSON encoding of val using Registry r.
117+
func MarshalValueWithRegistry(r *bsoncodec.Registry, val interface{}) (bsontype.Type, []byte, error) {
118+
dst := make([]byte, 0, defaultDstCap)
119+
return MarshalValueAppendWithRegistry(r, dst, val)
120+
}
121+
122+
// MarshalValueWithContext returns the BSON encoding of val using EncodeContext ec.
123+
func MarshalValueWithContext(ec bsoncodec.EncodeContext, val interface{}) (bsontype.Type, []byte, error) {
124+
dst := make([]byte, 0, defaultDstCap)
125+
return MarshalValueAppendWithContext(ec, dst, val)
126+
}
127+
128+
// MarshalValueAppendWithRegistry will append the BSON encoding of val to dst using Registry r. If dst is not large
129+
// enough to hold the BSON encoding of val, dst will be grown.
130+
func MarshalValueAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) (bsontype.Type, []byte, error) {
131+
return MarshalValueAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val)
132+
}
133+
134+
// MarshalValueAppendWithContext will append the BSON encoding of val to dst using EncodeContext ec. If dst is not large
135+
// enough to hold the BSON encoding of val, dst will be grown.
136+
func MarshalValueAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) (bsontype.Type, []byte, error) {
137+
// get a ValueWriter configured to write to dst
138+
sw := new(bsonrw.SliceWriter)
139+
*sw = dst
140+
vwFlusher := bvwPool.GetAtModeElement(sw)
141+
142+
// get an Encoder and encode the value
143+
enc := encPool.Get().(*Encoder)
144+
defer encPool.Put(enc)
145+
if err := enc.Reset(vwFlusher); err != nil {
146+
return 0, nil, err
147+
}
148+
if err := enc.SetContext(ec); err != nil {
149+
return 0, nil, err
150+
}
151+
if err := enc.Encode(val); err != nil {
152+
return 0, nil, err
153+
}
154+
155+
// flush the bytes written because we cannot guarantee that a full document has been written
156+
// after the flush, *sw will be in the format
157+
// [value type, 0 (null byte to indicate end of empty element name), value bytes..]
158+
if err := vwFlusher.Flush(); err != nil {
159+
return 0, nil, err
160+
}
161+
buffer := *sw
162+
return bsontype.Type(buffer[0]), buffer[2:], nil
163+
}
164+
98165
// MarshalExtJSON returns the extended JSON encoding of val.
99166
func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) {
100167
return MarshalExtJSONWithRegistry(DefaultRegistry, val, canonical, escapeHTML)

bson/marshal_value_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package bson
8+
9+
import (
10+
"io"
11+
"testing"
12+
13+
"go.mongodb.org/mongo-driver/bson/bsoncodec"
14+
"go.mongodb.org/mongo-driver/bson/bsontype"
15+
"go.mongodb.org/mongo-driver/bson/primitive"
16+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
17+
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
18+
)
19+
20+
// helper type for testing MarshalValue that implements io.Reader
21+
type marshalValueInterfaceInner struct {
22+
Foo int
23+
}
24+
25+
var _ io.Reader = marshalValueInterfaceInner{}
26+
27+
func (marshalValueInterfaceInner) Read([]byte) (int, error) {
28+
return 0, nil
29+
}
30+
31+
// helper type for testing MarshalValue that contains an interface
32+
type marshalValueInterfaceOuter struct {
33+
Reader io.Reader
34+
}
35+
36+
// helper type for testing MarshalValue that implements ValueMarshaler
37+
type marshalValueMarshaler struct {
38+
Foo int
39+
}
40+
41+
var _ ValueMarshaler = marshalValueMarshaler{}
42+
43+
func (mvi marshalValueMarshaler) MarshalBSONValue() (bsontype.Type, []byte, error) {
44+
return bsontype.Int32, bsoncore.AppendInt32(nil, int32(mvi.Foo)), nil
45+
}
46+
47+
type marshalValueStruct struct {
48+
Foo int
49+
}
50+
51+
type marshalValueTestCase struct {
52+
name string
53+
val interface{}
54+
expectedType bsontype.Type
55+
expectedBytes []byte
56+
}
57+
58+
func TestMarshalValue(t *testing.T) {
59+
oid := primitive.NewObjectID()
60+
regex := primitive.Regex{Pattern: "pattern", Options: "imx"}
61+
dbPointer := primitive.DBPointer{DB: "db", Pointer: primitive.NewObjectID()}
62+
codeWithScope := primitive.CodeWithScope{Code: "code", Scope: D{{"a", "b"}}}
63+
idx, scopeCore := bsoncore.AppendDocumentStart(nil)
64+
scopeCore = bsoncore.AppendStringElement(scopeCore, "a", "b")
65+
scopeCore, _ = bsoncore.AppendDocumentEnd(scopeCore, idx)
66+
decimal128 := primitive.NewDecimal128(5, 10)
67+
interfaceTest := marshalValueInterfaceOuter{
68+
Reader: marshalValueInterfaceInner{
69+
Foo: 10,
70+
},
71+
}
72+
interfaceCore, err := Marshal(interfaceTest)
73+
assert.Nil(t, err, "Marshal error: %v", err)
74+
structTest := marshalValueStruct{Foo: 10}
75+
structCore, err := Marshal(structTest)
76+
assert.Nil(t, err, "Marshal error: %v", err)
77+
78+
marshalValueTestCases := []marshalValueTestCase{
79+
{"double", 3.14, bsontype.Double, bsoncore.AppendDouble(nil, 3.14)},
80+
{"string", "hello world", bsontype.String, bsoncore.AppendString(nil, "hello world")},
81+
{"binary", primitive.Binary{1, []byte{1, 2}}, bsontype.Binary, bsoncore.AppendBinary(nil, 1, []byte{1, 2})},
82+
{"undefined", primitive.Undefined{}, bsontype.Undefined, []byte{}},
83+
{"object id", oid, bsontype.ObjectID, bsoncore.AppendObjectID(nil, oid)},
84+
{"boolean", true, bsontype.Boolean, bsoncore.AppendBoolean(nil, true)},
85+
{"datetime", primitive.DateTime(5), bsontype.DateTime, bsoncore.AppendDateTime(nil, 5)},
86+
{"null", primitive.Null{}, bsontype.Null, []byte{}},
87+
{"regex", regex, bsontype.Regex, bsoncore.AppendRegex(nil, regex.Pattern, regex.Options)},
88+
{"dbpointer", dbPointer, bsontype.DBPointer, bsoncore.AppendDBPointer(nil, dbPointer.DB, dbPointer.Pointer)},
89+
{"javascript", primitive.JavaScript("js"), bsontype.JavaScript, bsoncore.AppendJavaScript(nil, "js")},
90+
{"symbol", primitive.Symbol("symbol"), bsontype.Symbol, bsoncore.AppendSymbol(nil, "symbol")},
91+
{"code with scope", codeWithScope, bsontype.CodeWithScope, bsoncore.AppendCodeWithScope(nil, "code", scopeCore)},
92+
{"int32", 5, bsontype.Int32, bsoncore.AppendInt32(nil, 5)},
93+
{"int64", int64(5), bsontype.Int64, bsoncore.AppendInt64(nil, 5)},
94+
{"timestamp", primitive.Timestamp{T: 1, I: 5}, bsontype.Timestamp, bsoncore.AppendTimestamp(nil, 1, 5)},
95+
{"decimal128", decimal128, bsontype.Decimal128, bsoncore.AppendDecimal128(nil, decimal128)},
96+
{"min key", primitive.MinKey{}, bsontype.MinKey, []byte{}},
97+
{"max key", primitive.MaxKey{}, bsontype.MaxKey, []byte{}},
98+
{"struct", structTest, bsontype.EmbeddedDocument, structCore},
99+
{"interface", interfaceTest, bsontype.EmbeddedDocument, interfaceCore},
100+
{"D", D{{"foo", 10}}, bsontype.EmbeddedDocument, structCore},
101+
{"M", M{"foo": 10}, bsontype.EmbeddedDocument, structCore},
102+
{"ValueMarshaler", marshalValueMarshaler{Foo: 10}, bsontype.Int32, bsoncore.AppendInt32(nil, 10)},
103+
}
104+
105+
t.Run("MarshalValue", func(t *testing.T) {
106+
for _, tc := range marshalValueTestCases {
107+
t.Run(tc.name, func(t *testing.T) {
108+
valueType, valueBytes, err := MarshalValue(tc.val)
109+
assert.Nil(t, err, "MarshalValue error: %v", err)
110+
compareMarshalValueResults(t, tc, valueType, valueBytes)
111+
})
112+
}
113+
})
114+
t.Run("MarshalValueAppend", func(t *testing.T) {
115+
for _, tc := range marshalValueTestCases {
116+
t.Run(tc.name, func(t *testing.T) {
117+
valueType, valueBytes, err := MarshalValueAppend(nil, tc.val)
118+
assert.Nil(t, err, "MarshalValueAppend error: %v", err)
119+
compareMarshalValueResults(t, tc, valueType, valueBytes)
120+
})
121+
}
122+
})
123+
t.Run("MarshalValueWithRegistry", func(t *testing.T) {
124+
for _, tc := range marshalValueTestCases {
125+
t.Run(tc.name, func(t *testing.T) {
126+
valueType, valueBytes, err := MarshalValueWithRegistry(DefaultRegistry, tc.val)
127+
assert.Nil(t, err, "MarshalValueWithRegistry error: %v", err)
128+
compareMarshalValueResults(t, tc, valueType, valueBytes)
129+
})
130+
}
131+
})
132+
t.Run("MarshalValueWithContext", func(t *testing.T) {
133+
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
134+
for _, tc := range marshalValueTestCases {
135+
t.Run(tc.name, func(t *testing.T) {
136+
valueType, valueBytes, err := MarshalValueWithContext(ec, tc.val)
137+
assert.Nil(t, err, "MarshalValueWithContext error: %v", err)
138+
compareMarshalValueResults(t, tc, valueType, valueBytes)
139+
})
140+
}
141+
})
142+
t.Run("MarshalValueAppendWithRegistry", func(t *testing.T) {
143+
for _, tc := range marshalValueTestCases {
144+
t.Run(tc.name, func(t *testing.T) {
145+
valueType, valueBytes, err := MarshalValueAppendWithRegistry(DefaultRegistry, nil, tc.val)
146+
assert.Nil(t, err, "MarshalValueAppendWithRegistry error: %v", err)
147+
compareMarshalValueResults(t, tc, valueType, valueBytes)
148+
})
149+
}
150+
})
151+
t.Run("MarshalValueAppendWithContext", func(t *testing.T) {
152+
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
153+
for _, tc := range marshalValueTestCases {
154+
t.Run(tc.name, func(t *testing.T) {
155+
valueType, valueBytes, err := MarshalValueAppendWithContext(ec, nil, tc.val)
156+
assert.Nil(t, err, "MarshalValueWithContext error: %v", err)
157+
compareMarshalValueResults(t, tc, valueType, valueBytes)
158+
})
159+
}
160+
})
161+
}
162+
163+
func compareMarshalValueResults(t *testing.T, tc marshalValueTestCase, gotType bsontype.Type, gotBytes []byte) {
164+
t.Helper()
165+
expectedValue := RawValue{Type: tc.expectedType, Value: tc.expectedBytes}
166+
gotValue := RawValue{Type: gotType, Value: gotBytes}
167+
assert.Equal(t, expectedValue, gotValue, "value mismatch; expected %s, got %s", expectedValue, gotValue)
168+
}

0 commit comments

Comments
 (0)