Skip to content

Commit 1e8b9e5

Browse files
author
Isabella Siu
committed
GODRIVER-663 Extend the Marshal* and Unmarshal* family functions with *WithContext functions
Change-Id: I261c3276a78b897039d8e9629b6eab22d608eedf
1 parent e8f52c4 commit 1e8b9e5

11 files changed

+354
-56
lines changed

bson/decoder.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,33 @@ var decPool = sync.Pool{
2828
// A Decoder reads and decodes BSON documents from a stream. It reads from a bsonrw.ValueReader as
2929
// the source of BSON data.
3030
type Decoder struct {
31-
r *bsoncodec.Registry
31+
dc bsoncodec.DecodeContext
3232
vr bsonrw.ValueReader
3333
}
3434

35-
// NewDecoder returns a new decoder that uses Registry reg to read from r.
36-
func NewDecoder(r *bsoncodec.Registry, vr bsonrw.ValueReader) (*Decoder, error) {
37-
if r == nil {
38-
return nil, errors.New("cannot create a new Decoder with a nil Registry")
35+
// NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr.
36+
func NewDecoder(vr bsonrw.ValueReader) (*Decoder, error) {
37+
if vr == nil {
38+
return nil, errors.New("cannot create a new Decoder with a nil ValueReader")
39+
}
40+
41+
return &Decoder{
42+
dc: bsoncodec.DecodeContext{Registry: DefaultRegistry},
43+
vr: vr,
44+
}, nil
45+
}
46+
47+
// NewDecoderWithContext returns a new decoder that uses DecodeContext dc to read from vr.
48+
func NewDecoderWithContext(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader) (*Decoder, error) {
49+
if dc.Registry == nil {
50+
dc.Registry = DefaultRegistry
3951
}
4052
if vr == nil {
4153
return nil, errors.New("cannot create a new Decoder with a nil ValueReader")
4254
}
4355

4456
return &Decoder{
45-
r: r,
57+
dc: dc,
4658
vr: vr,
4759
}, nil
4860
}
@@ -67,22 +79,28 @@ func (d *Decoder) Decode(val interface{}) error {
6779
return fmt.Errorf("argument to Decode must be a pointer to a type, but got %v", rval)
6880
}
6981
rval = rval.Elem()
70-
decoder, err := d.r.LookupDecoder(rval.Type())
82+
decoder, err := d.dc.LookupDecoder(rval.Type())
7183
if err != nil {
7284
return err
7385
}
74-
return decoder.DecodeValue(bsoncodec.DecodeContext{Registry: d.r}, d.vr, rval)
86+
return decoder.DecodeValue(d.dc, d.vr, rval)
7587
}
7688

77-
// Reset will reset the state of the decoder, using the same *Registry used in
78-
// the original construction but using r for reading.
89+
// Reset will reset the state of the decoder, using the same *DecodeContext used in
90+
// the original construction but using vr for reading.
7991
func (d *Decoder) Reset(vr bsonrw.ValueReader) error {
8092
d.vr = vr
8193
return nil
8294
}
8395

8496
// SetRegistry replaces the current registry of the decoder with r.
8597
func (d *Decoder) SetRegistry(r *bsoncodec.Registry) error {
86-
d.r = r
98+
d.dc.Registry = r
99+
return nil
100+
}
101+
102+
// SetContext replaces the current registry of the decoder with dc.
103+
func (d *Decoder) SetContext(dc bsoncodec.DecodeContext) error {
104+
d.dc = dc
87105
return nil
88106
}

bson/decoder_test.go

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func TestDecoderv2(t *testing.T) {
5151
} else {
5252
reg = DefaultRegistry
5353
}
54-
dec, err := NewDecoder(reg, vr)
54+
dec, err := NewDecoderWithContext(bsoncodec.DecodeContext{Registry: reg}, vr)
5555
noerr(t, err)
5656
err = dec.Decode(got)
5757
noerr(t, err)
@@ -64,7 +64,7 @@ func TestDecoderv2(t *testing.T) {
6464
t.Run("lookup error", func(t *testing.T) {
6565
type certainlydoesntexistelsewhereihope func(string, string) string
6666
cdeih := func(string, string) string { return "certainlydoesntexistelsewhereihope" }
67-
dec, err := NewDecoder(DefaultRegistry, bsonrw.NewBSONDocumentReader([]byte{}))
67+
dec, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{}))
6868
noerr(t, err)
6969
want := bsoncodec.ErrNoDecoder{Type: reflect.TypeOf(cdeih)}
7070
got := dec.Decode(&cdeih)
@@ -102,7 +102,7 @@ func TestDecoderv2(t *testing.T) {
102102
for _, tc := range testCases {
103103
t.Run(tc.name, func(t *testing.T) {
104104
unmarshaler := &testUnmarshaler{err: tc.err}
105-
dec, err := NewDecoder(DefaultRegistry, tc.vr)
105+
dec, err := NewDecoder(tc.vr)
106106
noerr(t, err)
107107
got := dec.Decode(unmarshaler)
108108
want := tc.err
@@ -123,7 +123,7 @@ func TestDecoderv2(t *testing.T) {
123123
want := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))
124124
unmarshaler := &testUnmarshaler{}
125125
vr := bsonrw.NewBSONDocumentReader(want)
126-
dec, err := NewDecoder(DefaultRegistry, vr)
126+
dec, err := NewDecoder(vr)
127127
noerr(t, err)
128128
err = dec.Decode(unmarshaler)
129129
noerr(t, err)
@@ -134,21 +134,39 @@ func TestDecoderv2(t *testing.T) {
134134
})
135135
})
136136
})
137-
t.Run("NewDecoderv2", func(t *testing.T) {
138-
t.Run("errors", func(t *testing.T) {
139-
_, got := NewDecoder(nil, bsonrw.ValueReader(nil))
140-
want := errors.New("cannot create a new Decoder with a nil Registry")
137+
t.Run("NewDecoder", func(t *testing.T) {
138+
t.Run("error", func(t *testing.T) {
139+
_, got := NewDecoder(nil)
140+
want := errors.New("cannot create a new Decoder with a nil ValueReader")
141141
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
142142
t.Errorf("Was expecting error but got different error. got %v; want %v", got, want)
143143
}
144-
_, got = NewDecoder(DefaultRegistry, nil)
145-
want = errors.New("cannot create a new Decoder with a nil ValueReader")
144+
})
145+
t.Run("success", func(t *testing.T) {
146+
got, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{}))
147+
noerr(t, err)
148+
if got == nil {
149+
t.Errorf("Was expecting a non-nil Decoder, but got <nil>")
150+
}
151+
})
152+
})
153+
t.Run("NewDecoderWithContext", func(t *testing.T) {
154+
t.Run("errors", func(t *testing.T) {
155+
dc := bsoncodec.DecodeContext{Registry: DefaultRegistry}
156+
_, got := NewDecoderWithContext(dc, nil)
157+
want := errors.New("cannot create a new Decoder with a nil ValueReader")
146158
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
147159
t.Errorf("Was expecting error but got different error. got %v; want %v", got, want)
148160
}
149161
})
150162
t.Run("success", func(t *testing.T) {
151-
got, err := NewDecoder(DefaultRegistry, bsonrw.NewBSONDocumentReader([]byte{}))
163+
got, err := NewDecoderWithContext(bsoncodec.DecodeContext{}, bsonrw.NewBSONDocumentReader([]byte{}))
164+
noerr(t, err)
165+
if got == nil {
166+
t.Errorf("Was expecting a non-nil Decoder, but got <nil>")
167+
}
168+
dc := bsoncodec.DecodeContext{Registry: DefaultRegistry}
169+
got, err = NewDecoderWithContext(dc, bsonrw.NewBSONDocumentReader([]byte{}))
152170
noerr(t, err)
153171
if got == nil {
154172
t.Errorf("Was expecting a non-nil Decoder, but got <nil>")
@@ -166,7 +184,7 @@ func TestDecoderv2(t *testing.T) {
166184
got.Bonus = 2
167185
data := docToBytes(bsonx.Doc{{"item", bsonx.String("canvas")}, {"qty", bsonx.Int32(4)}})
168186
vr := bsonrw.NewBSONDocumentReader(data)
169-
dec, err := NewDecoder(DefaultRegistry, vr)
187+
dec, err := NewDecoder(vr)
170188
noerr(t, err)
171189
err = dec.Decode(&got)
172190
noerr(t, err)
@@ -177,7 +195,8 @@ func TestDecoderv2(t *testing.T) {
177195
})
178196
t.Run("Reset", func(t *testing.T) {
179197
vr1, vr2 := bsonrw.NewBSONDocumentReader([]byte{}), bsonrw.NewBSONDocumentReader([]byte{})
180-
dec, err := NewDecoder(DefaultRegistry, vr1)
198+
dc := bsoncodec.DecodeContext{Registry: DefaultRegistry}
199+
dec, err := NewDecoderWithContext(dc, vr1)
181200
noerr(t, err)
182201
if dec.vr != vr1 {
183202
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1)
@@ -188,17 +207,33 @@ func TestDecoderv2(t *testing.T) {
188207
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2)
189208
}
190209
})
210+
t.Run("SetContext", func(t *testing.T) {
211+
dc1 := bsoncodec.DecodeContext{Registry: DefaultRegistry}
212+
dc2 := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
213+
dec, err := NewDecoderWithContext(dc1, bsonrw.NewBSONDocumentReader([]byte{}))
214+
noerr(t, err)
215+
if dec.dc != dc1 {
216+
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1)
217+
}
218+
err = dec.SetContext(dc2)
219+
noerr(t, err)
220+
if dec.dc != dc2 {
221+
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2)
222+
}
223+
})
191224
t.Run("SetRegistry", func(t *testing.T) {
192-
reg1, reg2 := DefaultRegistry, NewRegistryBuilder().Build()
193-
dec, err := NewDecoder(reg1, bsonrw.NewBSONDocumentReader([]byte{}))
225+
r1, r2 := DefaultRegistry, NewRegistryBuilder().Build()
226+
dc1 := bsoncodec.DecodeContext{Registry: r1}
227+
dc2 := bsoncodec.DecodeContext{Registry: r2}
228+
dec, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{}))
194229
noerr(t, err)
195-
if dec.r != reg1 {
196-
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.r, reg1)
230+
if dec.dc != dc1 {
231+
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1)
197232
}
198-
err = dec.SetRegistry(reg2)
233+
err = dec.SetRegistry(r2)
199234
noerr(t, err)
200-
if dec.r != reg2 {
201-
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.r, reg2)
235+
if dec.dc != dc2 {
236+
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2)
202237
}
203238
})
204239
}

bson/encoder.go

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,33 @@ var encPool = sync.Pool{
2727
// An Encoder writes a serialization format to an output stream. It writes to a bsonrw.ValueWriter
2828
// as the destination of BSON data.
2929
type Encoder struct {
30-
r *bsoncodec.Registry
30+
ec bsoncodec.EncodeContext
3131
vw bsonrw.ValueWriter
3232
}
3333

34-
// NewEncoder returns a new encoder that uses Registry r to write to w.
35-
func NewEncoder(r *bsoncodec.Registry, vw bsonrw.ValueWriter) (*Encoder, error) {
36-
if r == nil {
37-
return nil, errors.New("cannot create a new Encoder with a nil Registry")
34+
// NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw.
35+
func NewEncoder(vw bsonrw.ValueWriter) (*Encoder, error) {
36+
if vw == nil {
37+
return nil, errors.New("cannot create a new Encoder with a nil ValueWriter")
38+
}
39+
40+
return &Encoder{
41+
ec: bsoncodec.EncodeContext{Registry: DefaultRegistry},
42+
vw: vw,
43+
}, nil
44+
}
45+
46+
// NewEncoderWithContext returns a new encoder that uses EncodeContext ec to write to vw.
47+
func NewEncoderWithContext(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter) (*Encoder, error) {
48+
if ec.Registry == nil {
49+
ec = bsoncodec.EncodeContext{Registry: DefaultRegistry}
3850
}
3951
if vw == nil {
4052
return nil, errors.New("cannot create a new Encoder with a nil ValueWriter")
4153
}
4254

4355
return &Encoder{
44-
r: r,
56+
ec: ec,
4557
vw: vw,
4658
}, nil
4759
}
@@ -60,14 +72,14 @@ func (e *Encoder) Encode(val interface{}) error {
6072
return bsonrw.Copier{}.CopyDocumentFromBytes(e.vw, buf)
6173
}
6274

63-
encoder, err := e.r.LookupEncoder(reflect.TypeOf(val))
75+
encoder, err := e.ec.LookupEncoder(reflect.TypeOf(val))
6476
if err != nil {
6577
return err
6678
}
67-
return encoder.EncodeValue(bsoncodec.EncodeContext{Registry: e.r}, e.vw, reflect.ValueOf(val))
79+
return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val))
6880
}
6981

70-
// Reset will reset the state of the encoder, using the same *Registry used in
82+
// Reset will reset the state of the encoder, using the same *EncodeContext used in
7183
// the original construction but using vw.
7284
func (e *Encoder) Reset(vw bsonrw.ValueWriter) error {
7385
e.vw = vw
@@ -76,6 +88,12 @@ func (e *Encoder) Reset(vw bsonrw.ValueWriter) error {
7688

7789
// SetRegistry replaces the current registry of the encoder with r.
7890
func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error {
79-
e.r = r
91+
e.ec.Registry = r
92+
return nil
93+
}
94+
95+
// SetContext replaces the current EncodeContext of the encoder with er.
96+
func (e *Encoder) SetContext(ec bsoncodec.EncodeContext) error {
97+
e.ec = ec
8098
return nil
8199
}

bson/encoder_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ func TestEncoderEncode(t *testing.T) {
4444
got := make(bsonrw.SliceWriter, 0, 1024)
4545
vw, err := bsonrw.NewBSONValueWriter(&got)
4646
noerr(t, err)
47-
reg := DefaultRegistry
48-
enc, err := NewEncoder(reg, vw)
47+
enc, err := NewEncoder(vw)
4948
noerr(t, err)
5049
err = enc.Encode(tc.val)
5150
noerr(t, err)
@@ -103,8 +102,7 @@ func TestEncoderEncode(t *testing.T) {
103102
vw, err = bsonrw.NewBSONValueWriter(&b)
104103
noerr(t, err)
105104
}
106-
107-
enc, err := NewEncoder(DefaultRegistry, vw)
105+
enc, err := NewEncoder(vw)
108106
noerr(t, err)
109107
got := enc.Encode(marshaler)
110108
want := tc.wanterr

bson/marshal.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,23 @@ func MarshalWithRegistry(r *bsoncodec.Registry, val interface{}) ([]byte, error)
5353
return MarshalAppendWithRegistry(r, dst, val)
5454
}
5555

56+
// MarshalWithContext returns the BSON encoding of val using EncodeContext ec.
57+
func MarshalWithContext(ec bsoncodec.EncodeContext, val interface{}) ([]byte, error) {
58+
dst := make([]byte, 0, 256) // TODO: make the default cap a constant
59+
return MarshalAppendWithContext(ec, dst, val)
60+
}
61+
5662
// MarshalAppendWithRegistry will append the BSON encoding of val to dst using
5763
// Registry r. If dst is not large enough to hold the BSON encoding of val, dst
5864
// will be grown.
5965
func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) ([]byte, error) {
66+
return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val)
67+
}
68+
69+
// MarshalAppendWithContext will append the BSON encoding of val to dst using
70+
// EncodeContext ec. If dst is not large enough to hold the BSON encoding of val, dst
71+
// will be grown.
72+
func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) {
6073
sw := new(bsonrw.SliceWriter)
6174
*sw = dst
6275
vw := bvwPool.Get(sw)
@@ -69,7 +82,7 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{
6982
if err != nil {
7083
return nil, err
7184
}
72-
err = enc.SetRegistry(r)
85+
err = enc.SetContext(ec)
7386
if err != nil {
7487
return nil, err
7588
}
@@ -97,13 +110,26 @@ func MarshalExtJSONAppend(dst []byte, val interface{}, canonical, escapeHTML boo
97110
// MarshalExtJSONWithRegistry returns the extended JSON encoding of val using Registry r.
98111
func MarshalExtJSONWithRegistry(r *bsoncodec.Registry, val interface{}, canonical, escapeHTML bool) ([]byte, error) {
99112
dst := make([]byte, 0, defaultDstCap)
100-
return MarshalExtJSONAppendWithRegistry(r, dst, val, canonical, escapeHTML)
113+
return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML)
114+
}
115+
116+
// MarshalExtJSONWithContext returns the extended JSON encoding of val using Registry r.
117+
func MarshalExtJSONWithContext(ec bsoncodec.EncodeContext, val interface{}, canonical, escapeHTML bool) ([]byte, error) {
118+
dst := make([]byte, 0, defaultDstCap)
119+
return MarshalExtJSONAppendWithContext(ec, dst, val, canonical, escapeHTML)
101120
}
102121

103122
// MarshalExtJSONAppendWithRegistry will append the extended JSON encoding of
104123
// val to dst using Registry r. If dst is not large enough to hold the BSON
105124
// encoding of val, dst will be grown.
106125
func MarshalExtJSONAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) {
126+
return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML)
127+
}
128+
129+
// MarshalExtJSONAppendWithContext will append the extended JSON encoding of
130+
// val to dst using Registry r. If dst is not large enough to hold the BSON
131+
// encoding of val, dst will be grown.
132+
func MarshalExtJSONAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) {
107133
sw := new(bsonrw.SliceWriter)
108134
*sw = dst
109135
ejvw := extjPool.Get(sw, canonical, escapeHTML)
@@ -116,7 +142,7 @@ func MarshalExtJSONAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val int
116142
if err != nil {
117143
return nil, err
118144
}
119-
err = enc.SetRegistry(r)
145+
err = enc.SetContext(ec)
120146
if err != nil {
121147
return nil, err
122148
}

0 commit comments

Comments
 (0)