Skip to content

Commit 5514df9

Browse files
committed
Ensure _id is always first in inserted documents
GODRIVER-660 Change-Id: Ieed0294e8447d50fd49b00551d064522994260f7
1 parent 8991f81 commit 5514df9

File tree

4 files changed

+105
-33
lines changed

4 files changed

+105
-33
lines changed

mongo/mongo.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,30 @@ func transformAndEnsureID(registry *bsoncodec.Registry, val interface{}) (bsonx.
103103

104104
var id interface{}
105105

106-
v, err := bson.Raw(b).LookupErr("_id")
107-
switch err.(type) {
108-
case nil:
109-
if err := v.Unmarshal(&id); err != nil {
110-
return nil, nil, err
111-
}
106+
idx := d.IndexOf("_id")
107+
var idElem bsonx.Elem
108+
switch idx {
109+
case -1:
110+
idElem = bsonx.Elem{"_id", bsonx.ObjectID(primitive.NewObjectID())}
111+
d = append(d, bsonx.Elem{})
112+
copy(d[1:], d)
113+
d[0] = idElem
112114
default:
113-
oid := primitive.NewObjectID()
114-
d = append(d, bsonx.Elem{"_id", bsonx.ObjectID(oid)})
115-
id = oid
115+
idElem = d[idx]
116+
copy(d[1:idx+1], d[0:idx])
117+
d[0] = idElem
116118
}
119+
120+
t, data, err := idElem.Value.MarshalAppendBSONValue(buf[:0])
121+
if err != nil {
122+
return nil, nil, err
123+
}
124+
125+
err = bson.RawValue{Type: t, Value: data}.UnmarshalWithRegistry(registry, &id)
126+
if err != nil {
127+
return nil, nil, err
128+
}
129+
117130
return d, id, nil
118131
}
119132

mongo/mongo_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,58 @@ func TestTransformDocument(t *testing.T) {
8181
}
8282
}
8383

84+
func TestTransformAndEnsureID(t *testing.T) {
85+
t.Run("newly added _id should be first element", func(t *testing.T) {
86+
doc := bson.D{{"foo", "bar"}, {"baz", "qux"}, {"hello", "world"}}
87+
want := bsonx.Doc{
88+
{"_id", bsonx.Null()}, {"foo", bsonx.String("bar")},
89+
{"baz", bsonx.String("qux")}, {"hello", bsonx.String("world")},
90+
}
91+
got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc)
92+
noerr(t, err)
93+
oid, ok := id.(primitive.ObjectID)
94+
if !ok {
95+
t.Fatalf("Expected returned id to be a %T, but was %T", primitive.ObjectID{}, id)
96+
}
97+
want[0] = bsonx.Elem{"_id", bsonx.ObjectID(oid)}
98+
if diff := cmp.Diff(got, want, cmp.AllowUnexported(bsonx.Elem{}, bsonx.Val{})); diff != "" {
99+
t.Errorf("Returned documents differ: (-got +want)\n%s", diff)
100+
}
101+
})
102+
t.Run("existing _id should be first element", func(t *testing.T) {
103+
doc := bson.D{{"foo", "bar"}, {"baz", "qux"}, {"_id", 3.14159}, {"hello", "world"}}
104+
want := bsonx.Doc{
105+
{"_id", bsonx.Double(3.14159)}, {"foo", bsonx.String("bar")},
106+
{"baz", bsonx.String("qux")}, {"hello", bsonx.String("world")},
107+
}
108+
got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc)
109+
noerr(t, err)
110+
_, ok := id.(float64)
111+
if !ok {
112+
t.Fatalf("Expected returned id to be a %T, but was %T", float64(0), id)
113+
}
114+
if diff := cmp.Diff(got, want, cmp.AllowUnexported(bsonx.Elem{}, bsonx.Val{})); diff != "" {
115+
t.Errorf("Returned documents differ: (-got +want)\n%s", diff)
116+
}
117+
})
118+
t.Run("existing _id as first element should remain first element", func(t *testing.T) {
119+
doc := bson.D{{"_id", 3.14159}, {"foo", "bar"}, {"baz", "qux"}, {"hello", "world"}}
120+
want := bsonx.Doc{
121+
{"_id", bsonx.Double(3.14159)}, {"foo", bsonx.String("bar")},
122+
{"baz", bsonx.String("qux")}, {"hello", bsonx.String("world")},
123+
}
124+
got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc)
125+
noerr(t, err)
126+
_, ok := id.(float64)
127+
if !ok {
128+
t.Fatalf("Expected returned id to be a %T, but was %T", float64(0), id)
129+
}
130+
if diff := cmp.Diff(got, want, cmp.AllowUnexported(bsonx.Elem{}, bsonx.Val{})); diff != "" {
131+
t.Errorf("Returned documents differ: (-got +want)\n%s", diff)
132+
}
133+
})
134+
}
135+
84136
func TestTransformAggregatePipeline(t *testing.T) {
85137
index, arr := bsoncore.AppendArrayStart(nil)
86138
dindex, arr := bsoncore.AppendDocumentElementStart(arr, "0")

x/bsonx/document.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,17 @@ func (d Doc) Prepend(key string, val Val) Doc {
8181
// does not have an element with that key, the element is appended to the
8282
// document instead.
8383
func (d Doc) Set(key string, val Val) Doc {
84-
idx := d.indexOf(key)
84+
idx := d.IndexOf(key)
8585
if idx == -1 {
8686
return append(d, Elem{Key: key, Value: val})
8787
}
8888
d[idx] = Elem{Key: key, Value: val}
8989
return d
9090
}
9191

92-
func (d Doc) indexOf(key string) int {
92+
// IndexOf returns the index of the first element with a key of key, or -1 if no element with a key
93+
// was found.
94+
func (d Doc) IndexOf(key string) int {
9395
for i, e := range d {
9496
if e.Key == key {
9597
return i
@@ -100,7 +102,7 @@ func (d Doc) indexOf(key string) int {
100102

101103
// Delete removes the element with key if it exists and returns the updated Doc.
102104
func (d Doc) Delete(key string) Doc {
103-
idx := d.indexOf(key)
105+
idx := d.IndexOf(key)
104106
if idx == -1 {
105107
return d
106108
}
@@ -147,7 +149,7 @@ func (d Doc) LookupElementErr(key ...string) (Elem, error) {
147149

148150
var elem Elem
149151
var err error
150-
idx := d.indexOf(key[0])
152+
idx := d.IndexOf(key[0])
151153
if idx == -1 {
152154
return Elem{}, KeyNotFound{Key: key}
153155
}

x/bsonx/value.go

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -160,64 +160,69 @@ func (v Val) Interface() interface{} {
160160

161161
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
162162
func (v Val) MarshalBSONValue() (bsontype.Type, []byte, error) {
163+
return v.MarshalAppendBSONValue(nil)
164+
}
165+
166+
// MarshalAppendBSONValue is similar to MarshalBSONValue, but allows the caller to specify a slice
167+
// to add the bytes to.
168+
func (v Val) MarshalAppendBSONValue(dst []byte) (bsontype.Type, []byte, error) {
163169
t := v.Type()
164-
var data []byte
165170
switch v.Type() {
166171
case bsontype.Double:
167-
data = bsoncore.AppendDouble(data, v.Double())
172+
dst = bsoncore.AppendDouble(dst, v.Double())
168173
case bsontype.String:
169-
data = bsoncore.AppendString(data, v.String())
174+
dst = bsoncore.AppendString(dst, v.String())
170175
case bsontype.EmbeddedDocument:
171176
switch v.primitive.(type) {
172177
case Doc:
173-
t, data, _ = v.primitive.(Doc).MarshalBSONValue() // Doc.MarshalBSONValue never returns an error.
178+
t, dst, _ = v.primitive.(Doc).MarshalBSONValue() // Doc.MarshalBSONValue never returns an error.
174179
case MDoc:
175-
t, data, _ = v.primitive.(MDoc).MarshalBSONValue() // MDoc.MarshalBSONValue never returns an error.
180+
t, dst, _ = v.primitive.(MDoc).MarshalBSONValue() // MDoc.MarshalBSONValue never returns an error.
176181
}
177182
case bsontype.Array:
178-
t, data, _ = v.Array().MarshalBSONValue() // Arr.MarshalBSON never returns an error.
183+
t, dst, _ = v.Array().MarshalBSONValue() // Arr.MarshalBSON never returns an error.
179184
case bsontype.Binary:
180185
subtype, bindata := v.Binary()
181-
data = bsoncore.AppendBinary(data, subtype, bindata)
186+
dst = bsoncore.AppendBinary(dst, subtype, bindata)
182187
case bsontype.Undefined:
183188
case bsontype.ObjectID:
184-
data = bsoncore.AppendObjectID(data, v.ObjectID())
189+
dst = bsoncore.AppendObjectID(dst, v.ObjectID())
185190
case bsontype.Boolean:
186-
data = bsoncore.AppendBoolean(data, v.Boolean())
191+
dst = bsoncore.AppendBoolean(dst, v.Boolean())
187192
case bsontype.DateTime:
188-
data = bsoncore.AppendDateTime(data, int64(v.DateTime()))
193+
dst = bsoncore.AppendDateTime(dst, int64(v.DateTime()))
189194
case bsontype.Null:
190195
case bsontype.Regex:
191196
pattern, options := v.Regex()
192-
data = bsoncore.AppendRegex(data, pattern, options)
197+
dst = bsoncore.AppendRegex(dst, pattern, options)
193198
case bsontype.DBPointer:
194199
ns, ptr := v.DBPointer()
195-
data = bsoncore.AppendDBPointer(data, ns, ptr)
200+
dst = bsoncore.AppendDBPointer(dst, ns, ptr)
196201
case bsontype.JavaScript:
197-
data = bsoncore.AppendJavaScript(data, string(v.JavaScript()))
202+
dst = bsoncore.AppendJavaScript(dst, string(v.JavaScript()))
198203
case bsontype.Symbol:
199-
data = bsoncore.AppendSymbol(data, string(v.Symbol()))
204+
dst = bsoncore.AppendSymbol(dst, string(v.Symbol()))
200205
case bsontype.CodeWithScope:
201206
code, doc := v.CodeWithScope()
202207
var scope []byte
203208
scope, _ = doc.MarshalBSON() // Doc.MarshalBSON never returns an error.
204-
data = bsoncore.AppendCodeWithScope(data, code, scope)
209+
dst = bsoncore.AppendCodeWithScope(dst, code, scope)
205210
case bsontype.Int32:
206-
data = bsoncore.AppendInt32(data, v.Int32())
211+
dst = bsoncore.AppendInt32(dst, v.Int32())
207212
case bsontype.Timestamp:
208213
t, i := v.Timestamp()
209-
data = bsoncore.AppendTimestamp(data, t, i)
214+
dst = bsoncore.AppendTimestamp(dst, t, i)
210215
case bsontype.Int64:
211-
data = bsoncore.AppendInt64(data, v.Int64())
216+
dst = bsoncore.AppendInt64(dst, v.Int64())
212217
case bsontype.Decimal128:
213-
data = bsoncore.AppendDecimal128(data, v.Decimal128())
218+
dst = bsoncore.AppendDecimal128(dst, v.Decimal128())
214219
case bsontype.MinKey:
215220
case bsontype.MaxKey:
216221
default:
217222
panic(fmt.Errorf("invalid BSON type %v", t))
218223
}
219224

220-
return t, data, nil
225+
return t, dst, nil
221226
}
222227

223228
// UnmarshalBSONValue implements the bsoncodec.ValueUnmarshaler interface.

0 commit comments

Comments
 (0)