Skip to content

Commit 5ae4e01

Browse files
author
Divjot Arora
authored
GODRIVER-1020 Allow empty JSON string to decode into ObjectID (#384)
1 parent cd2cf48 commit 5ae4e01

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

bson/primitive/objectid.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,18 @@ func (id ObjectID) MarshalJSON() ([]byte, error) {
9393
return json.Marshal(id.Hex())
9494
}
9595

96-
// UnmarshalJSON populates the byte slice with the ObjectID. If the byte slice is 64 bytes long, it
96+
// UnmarshalJSON populates the byte slice with the ObjectID. If the byte slice is 24 bytes long, it
9797
// will be populated with the hex representation of the ObjectID. If the byte slice is twelve bytes
98-
// long, it will be populated with the BSON representation of the ObjectID. Otherwise, it will
99-
// return an error.
98+
// long, it will be populated with the BSON representation of the ObjectID. This method also accepts empty strings and
99+
// decodes them as NilObjectID. For any other inputs, an error will be returned.
100100
func (id *ObjectID) UnmarshalJSON(b []byte) error {
101+
// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer ObjectID field
102+
// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
103+
// enter the UnmarshalJSON hook.
104+
if string(b) == "null" {
105+
return nil
106+
}
107+
101108
var err error
102109
switch len(b) {
103110
case 12:
@@ -125,6 +132,12 @@ func (id *ObjectID) UnmarshalJSON(b []byte) error {
125132
}
126133
}
127134

135+
// An empty string is not a valid ObjectID, but we treat it as a special value that decodes as NilObjectID.
136+
if len(str) == 0 {
137+
copy(id[:], NilObjectID[:])
138+
return nil
139+
}
140+
128141
if len(str) != 24 {
129142
return fmt.Errorf("cannot unmarshal into an ObjectID, the length must be 24 but it is %d", len(str))
130143
}

bson/primitive/objectid_test.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
package primitive
88

99
import (
10-
"testing"
11-
1210
"encoding/binary"
1311
"encoding/hex"
12+
"encoding/json"
13+
"fmt"
14+
"testing"
1415
"time"
1516

1617
"github.com/stretchr/testify/require"
18+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1719
)
1820

1921
func TestNew(t *testing.T) {
@@ -148,3 +150,33 @@ func TestCounterOverflow(t *testing.T) {
148150
NewObjectID()
149151
require.Equal(t, uint32(0), objectIDCounter)
150152
}
153+
154+
func TestObjectID_UnmarshalJSON(t *testing.T) {
155+
oid := NewObjectID()
156+
157+
hexJSON := fmt.Sprintf(`{"foo": %q}`, oid.Hex())
158+
extJSON := fmt.Sprintf(`{"foo": {"$oid": %q}}`, oid.Hex())
159+
emptyStringJSON := `{"foo": ""}`
160+
nullJSON := `{"foo": null}`
161+
162+
testCases := []struct {
163+
name string
164+
jsonString string
165+
expected ObjectID
166+
}{
167+
{"hex bytes", hexJSON, oid},
168+
{"extended JSON", extJSON, oid},
169+
{"empty string", emptyStringJSON, NilObjectID},
170+
{"null", nullJSON, NilObjectID},
171+
}
172+
for _, tc := range testCases {
173+
t.Run(tc.name, func(t *testing.T) {
174+
var got map[string]ObjectID
175+
err := json.Unmarshal([]byte(tc.jsonString), &got)
176+
assert.Nil(t, err, "Unmarshal error: %v", err)
177+
178+
gotOid := got["foo"]
179+
assert.Equal(t, tc.expected, gotOid, "expected ObjectID %s, got %s", tc.expected, gotOid)
180+
})
181+
}
182+
}

0 commit comments

Comments
 (0)