Skip to content

Commit 55ea16d

Browse files
author
Divjot Arora
committed
Change UnmarshalExtJSON to return error for invalid JSON.
GODRIVER-694 Change-Id: I66ba147cd0d71f41cd8ed9b34e21d9b56add6044
1 parent 48f45a6 commit 55ea16d

File tree

5 files changed

+34
-15
lines changed

5 files changed

+34
-15
lines changed

bson/bson_corpus_spec_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,9 @@ func runTest(t *testing.T, file string) {
315315
err := UnmarshalExtJSON([]byte(s), true, &doc)
316316
expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description))
317317
case "0x13":
318-
ejvr := bsonrw.NewExtJSONValueReader(strings.NewReader(s), true)
319-
_, err := ejvr.ReadDecimal128()
318+
ejvr, err := bsonrw.NewExtJSONValueReader(strings.NewReader(s), true)
319+
expectNoError(t, err, fmt.Sprintf("error creating value reader: %s", err))
320+
_, err = ejvr.ReadDecimal128()
320321
expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description))
321322
default:
322323
t.Errorf("Update test to check for parse errors for type %s", test.BsonType)

bson/bsonrw/extjson_reader.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,9 @@ func NewExtJSONValueReaderPool() *ExtJSONValueReaderPool {
3232
}
3333

3434
// Get retrieves a ValueReader from the pool and uses src as the underlying ExtJSON.
35-
func (bvrp *ExtJSONValueReaderPool) Get(r io.Reader, canonical bool) ValueReader {
35+
func (bvrp *ExtJSONValueReaderPool) Get(r io.Reader, canonical bool) (ValueReader, error) {
3636
vr := bvrp.pool.Get().(*extJSONValueReader)
37-
vr = vr.reset(r, canonical)
38-
return vr
37+
return vr.reset(r, canonical)
3938
}
4039

4140
// Put inserts a ValueReader into the pool. If the ValueReader is not a ExtJSON ValueReader nothing
@@ -46,7 +45,7 @@ func (bvrp *ExtJSONValueReaderPool) Put(vr ValueReader) (ok bool) {
4645
return false
4746
}
4847

49-
bvr = bvr.reset(nil, false)
48+
bvr, _ = bvr.reset(nil, false)
5049
bvrp.pool.Put(bvr)
5150
return true
5251
}
@@ -68,22 +67,21 @@ type extJSONValueReader struct {
6867
// NewExtJSONValueReader creates a new ValueReader from a given io.Reader
6968
// It will interpret the JSON of r as canonical or relaxed according to the
7069
// given canonical flag
71-
func NewExtJSONValueReader(r io.Reader, canonical bool) ValueReader {
70+
func NewExtJSONValueReader(r io.Reader, canonical bool) (ValueReader, error) {
7271
return newExtJSONValueReader(r, canonical)
7372
}
7473

75-
func newExtJSONValueReader(r io.Reader, canonical bool) *extJSONValueReader {
74+
func newExtJSONValueReader(r io.Reader, canonical bool) (*extJSONValueReader, error) {
7675
ejvr := new(extJSONValueReader)
7776
return ejvr.reset(r, canonical)
7877
}
7978

80-
func (ejvr *extJSONValueReader) reset(r io.Reader, canonical bool) *extJSONValueReader {
79+
func (ejvr *extJSONValueReader) reset(r io.Reader, canonical bool) (*extJSONValueReader, error) {
8180
p := newExtJSONParser(r, canonical)
8281
typ, err := p.peekType()
8382

8483
if err != nil {
85-
// TODO: invalid JSON--return error message?
86-
return nil
84+
return nil, ErrInvalidJSON
8785
}
8886

8987
var m mode
@@ -104,7 +102,7 @@ func (ejvr *extJSONValueReader) reset(r io.Reader, canonical bool) *extJSONValue
104102
return &extJSONValueReader{
105103
p: p,
106104
stack: stack,
107-
}
105+
}, nil
108106
}
109107

110108
func (ejvr *extJSONValueReader) advanceFrame() {

bson/bsonrw/extjson_reader_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,10 @@ func TestReadMultipleTopLevelDocuments(t *testing.T) {
105105
for _, tc := range testCases {
106106
t.Run(tc.name, func(t *testing.T) {
107107
r := strings.NewReader(tc.input)
108-
vr := NewExtJSONValueReader(r, false)
108+
vr, err := NewExtJSONValueReader(r, false)
109+
if err != nil {
110+
t.Fatalf("expected no error, but got %v", err)
111+
}
109112

110113
actual, err := readAllDocuments(vr)
111114
if err != nil {

bson/unmarshal.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,23 @@ func UnmarshalExtJSON(data []byte, canonical bool, val interface{}) error {
6464
// Registry r and stores the result in the value pointed to by val. If val is
6565
// nil or not a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError.
6666
func UnmarshalExtJSONWithRegistry(r *bsoncodec.Registry, data []byte, canonical bool, val interface{}) error {
67-
ejvr := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), canonical)
67+
ejvr, err := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), canonical)
68+
if err != nil {
69+
return err
70+
}
71+
6872
return unmarshalFromReader(bsoncodec.DecodeContext{Registry: r}, ejvr, val)
6973
}
7074

7175
// UnmarshalExtJSONWithContext parses the extended JSON-encoded data using
7276
// DecodeContext dc and stores the result in the value pointed to by val. If val is
7377
// nil or not a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError.
7478
func UnmarshalExtJSONWithContext(dc bsoncodec.DecodeContext, data []byte, canonical bool, val interface{}) error {
75-
ejvr := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), canonical)
79+
ejvr, err := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), canonical)
80+
if err != nil {
81+
return err
82+
}
83+
7684
return unmarshalFromReader(dc, ejvr, val)
7785
}
7886

bson/unmarshal_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/google/go-cmp/cmp"
1414
"github.com/mongodb/mongo-go-driver/bson/bsoncodec"
15+
"github.com/mongodb/mongo-go-driver/bson/bsonrw"
1516
)
1617

1718
func TestUnmarshal(t *testing.T) {
@@ -81,6 +82,14 @@ func TestUnmarshalExtJSONWithRegistry(t *testing.T) {
8182
t.Errorf("Did not unmarshal as expected. got %v; want %v", got, want)
8283
}
8384
})
85+
86+
t.Run("UnmarshalExtJSONInvalidInput", func(t *testing.T) {
87+
data := []byte("invalid")
88+
err := UnmarshalExtJSONWithRegistry(DefaultRegistry, data, true, &M{})
89+
if err != bsonrw.ErrInvalidJSON {
90+
t.Fatalf("wanted ErrInvalidJSON, got %v", err)
91+
}
92+
})
8493
}
8594

8695
func TestUnmarshalExtJSONWithContext(t *testing.T) {

0 commit comments

Comments
 (0)