Skip to content

Commit 17bb285

Browse files
committed
Fix transforming of interface{} into agg pipeline
This commit fixes the transformAggregatePipeline function, which previously didn't properly handle arrays and slices. The function now does not accept a document, instead requiring either a slice, array, or a type that implements ValueMarshaler and returns a BSON array from its MarshalBSONValue method. GODRIVER-679 Change-Id: I5f1e4b6fe84a6e1db6dda187010d7f94e4a42e7b
1 parent 5277acc commit 17bb285

File tree

5 files changed

+204
-37
lines changed

5 files changed

+204
-37
lines changed

mongo/causal_consistency_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func createReadFuncMap(t *testing.T, dbName string, collName string) (*Client, *
8080
coll.writeConcern = writeconcern.New(writeconcern.WMajority())
8181

8282
functions := []CollFunction{
83-
{"Aggregate", coll, nil, func(mctx SessionContext) error { _, err := coll.Aggregate(mctx, emptyDoc); return err }},
83+
{"Aggregate", coll, nil, func(mctx SessionContext) error { _, err := coll.Aggregate(mctx, emptyArr); return err }},
8484
{"Count", coll, nil, func(mctx SessionContext) error { _, err := coll.Count(mctx, emptyDoc); return err }},
8585
{"Distinct", coll, nil, func(mctx SessionContext) error { _, err := coll.Distinct(mctx, "field", emptyDoc); return err }},
8686
{"Find", coll, nil, func(mctx SessionContext) error { _, err := coll.Find(mctx, emptyDoc); return err }},

mongo/change_stream_test.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,17 @@ func skipIfBelow32(t *testing.T) {
9797
}
9898

9999
func createCollectionStream(t *testing.T, dbName string, collName string, pipeline interface{}) (*Collection, Cursor) {
100+
if pipeline == nil {
101+
pipeline = Pipeline{}
102+
}
100103
client := createTestClient(t)
101104
return createStream(t, client, dbName, collName, pipeline)
102105
}
103106

104107
func createMonitoredStream(t *testing.T, dbName string, collName string, pipeline interface{}) (*Collection, Cursor) {
108+
if pipeline == nil {
109+
pipeline = Pipeline{}
110+
}
105111
client := createMonitoredClient(t, monitor)
106112
return createStream(t, client, dbName, collName, pipeline)
107113
}
@@ -169,7 +175,7 @@ func TestChangeStream(t *testing.T) {
169175
_, err := coll.InsertOne(context.Background(), bsonx.Doc{{"x", bsonx.Int32(1)}})
170176
require.NoError(t, err)
171177

172-
changes, err := coll.Watch(context.Background(), nil)
178+
changes, err := coll.Watch(context.Background(), Pipeline{})
173179
require.NoError(t, err)
174180
defer changes.Close(ctx)
175181

@@ -242,7 +248,7 @@ func TestChangeStream(t *testing.T) {
242248
_, err := coll.InsertOne(context.Background(), bsonx.Doc{{"x", bsonx.Int32(1)}})
243249
require.NoError(t, err)
244250

245-
_, err = coll.Watch(context.Background(), nil)
251+
_, err = coll.Watch(context.Background(), Pipeline{})
246252
require.Error(t, err)
247253
if _, ok := err.(command.Error); !ok {
248254
t.Errorf("Should have returned command error, but got %T", err)
@@ -282,7 +288,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) {
282288
t.Run("TestTrackResumeToken", func(t *testing.T) {
283289
// Stream must continuously track last seen resumeToken
284290

285-
coll, stream := createCollectionStream(t, "TrackTokenDB", "TrackTokenColl", bsonx.Doc{})
291+
coll, stream := createCollectionStream(t, "TrackTokenDB", "TrackTokenColl", Pipeline{})
286292
defer closeCursor(stream)
287293

288294
cs := stream.(*changeStream)

mongo/mongo.go

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919

2020
"github.com/mongodb/mongo-go-driver/bson"
2121
"github.com/mongodb/mongo-go-driver/bson/bsoncodec"
22+
"github.com/mongodb/mongo-go-driver/bson/bsontype"
2223
"github.com/mongodb/mongo-go-driver/bson/primitive"
2324
)
2425

@@ -165,42 +166,29 @@ func ensureDollarKey(doc bsonx.Doc) error {
165166
func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsonx.Arr, error) {
166167
pipelineArr := bsonx.Arr{}
167168
switch t := pipeline.(type) {
168-
case Pipeline:
169-
for _, d := range t {
170-
doc, err := transformDocument(registry, d)
171-
if err != nil {
172-
return nil, err
173-
}
174-
pipelineArr = append(pipelineArr, bsonx.Document(doc))
175-
}
176-
case bsonx.Arr:
177-
pipelineArr = make(bsonx.Arr, len(t))
178-
copy(pipelineArr, t)
179-
case []bsonx.Doc:
180-
pipelineArr = bsonx.Arr{}
181-
182-
for _, doc := range t {
183-
pipelineArr = append(pipelineArr, bsonx.Document(doc))
169+
case bsoncodec.ValueMarshaler:
170+
btype, val, err := t.MarshalBSONValue()
171+
if err != nil {
172+
return nil, err
184173
}
185-
case []interface{}:
186-
pipelineArr = bsonx.Arr{}
187-
188-
for _, val := range t {
189-
doc, err := transformDocument(registry, val)
190-
if err != nil {
191-
return nil, err
192-
}
193-
194-
pipelineArr = append(pipelineArr, bsonx.Document(doc))
174+
if btype != bsontype.Array {
175+
return nil, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array)
195176
}
196-
default:
197-
p, err := transformDocument(registry, pipeline)
177+
err = pipelineArr.UnmarshalBSONValue(btype, val)
198178
if err != nil {
199179
return nil, err
200180
}
201-
202-
for _, elem := range p {
203-
pipelineArr = append(pipelineArr, elem.Value)
181+
default:
182+
val := reflect.ValueOf(t)
183+
if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) {
184+
return nil, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind())
185+
}
186+
for idx := 0; idx < val.Len(); idx++ {
187+
elem, err := transformDocument(registry, val.Index(idx).Interface())
188+
if err != nil {
189+
return nil, err
190+
}
191+
pipelineArr = append(pipelineArr, bsonx.Document(elem))
204192
}
205193
}
206194

mongo/mongo_test.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@ package mongo
88

99
import (
1010
"errors"
11+
"fmt"
1112
"testing"
1213

1314
"github.com/google/go-cmp/cmp"
1415
"github.com/mongodb/mongo-go-driver/bson"
16+
"github.com/mongodb/mongo-go-driver/bson/bsoncodec"
17+
"github.com/mongodb/mongo-go-driver/bson/bsontype"
18+
"github.com/mongodb/mongo-go-driver/bson/primitive"
1519
"github.com/mongodb/mongo-go-driver/x/bsonx"
20+
"github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
1621
)
1722

1823
func TestTransformDocument(t *testing.T) {
@@ -62,6 +67,161 @@ func TestTransformDocument(t *testing.T) {
6267
}
6368
}
6469

70+
func TestTransformAggregatePipeline(t *testing.T) {
71+
index, arr := bsoncore.AppendArrayStart(nil)
72+
dindex, arr := bsoncore.AppendDocumentElementStart(arr, "0")
73+
arr = bsoncore.AppendInt32Element(arr, "$limit", 12345)
74+
arr, _ = bsoncore.AppendDocumentEnd(arr, dindex)
75+
arr, _ = bsoncore.AppendArrayEnd(arr, index)
76+
77+
testCases := []struct {
78+
name string
79+
pipeline interface{}
80+
arr bsonx.Arr
81+
err error
82+
}{
83+
{"Pipeline/error", Pipeline{{{"hello", func() {}}}}, bsonx.Arr{}, MarshalError{Value: primitive.D{}}},
84+
{
85+
"Pipeline/success",
86+
Pipeline{{{"hello", "world"}}, {{"pi", 3.14159}}},
87+
bsonx.Arr{
88+
bsonx.Document(bsonx.Doc{{"hello", bsonx.String("world")}}),
89+
bsonx.Document(bsonx.Doc{{"pi", bsonx.Double(3.14159)}}),
90+
},
91+
nil,
92+
},
93+
{
94+
"bsonx.Arr",
95+
bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}})},
96+
bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}})},
97+
nil,
98+
},
99+
{
100+
"[]bsonx.Doc",
101+
[]bsonx.Doc{{{"$limit", bsonx.Int32(12345)}}},
102+
bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}})},
103+
nil,
104+
},
105+
{
106+
"primitive.A/error",
107+
primitive.A{"5"},
108+
bsonx.Arr{},
109+
MarshalError{Value: string("")},
110+
},
111+
{
112+
"primitive.A/success",
113+
primitive.A{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}},
114+
bsonx.Arr{
115+
bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}}),
116+
bsonx.Document(bsonx.Doc{{"$count", bsonx.String("foobar")}}),
117+
},
118+
nil,
119+
},
120+
{
121+
"bson.A/error",
122+
bson.A{"5"},
123+
bsonx.Arr{},
124+
MarshalError{Value: string("")},
125+
},
126+
{
127+
"bson.A/success",
128+
bson.A{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}},
129+
bsonx.Arr{
130+
bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}}),
131+
bsonx.Document(bsonx.Doc{{"$count", bsonx.String("foobar")}}),
132+
},
133+
nil,
134+
},
135+
{
136+
"[]interface{}/error",
137+
[]interface{}{"5"},
138+
bsonx.Arr{},
139+
MarshalError{Value: string("")},
140+
},
141+
{
142+
"[]interface{}/success",
143+
[]interface{}{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}},
144+
bsonx.Arr{
145+
bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}}),
146+
bsonx.Document(bsonx.Doc{{"$count", bsonx.String("foobar")}}),
147+
},
148+
nil,
149+
},
150+
{
151+
"bsoncodec.ValueMarshaler/MarshalBSONValue error",
152+
bvMarsh{err: errors.New("MarshalBSONValue error")},
153+
bsonx.Arr{},
154+
errors.New("MarshalBSONValue error"),
155+
},
156+
{
157+
"bsoncodec.ValueMarshaler/not array",
158+
bvMarsh{t: bsontype.String},
159+
bsonx.Arr{},
160+
fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", bsontype.String, bsontype.Array),
161+
},
162+
{
163+
"bsoncodec.ValueMarshaler/UnmarshalBSONValue error",
164+
bvMarsh{t: bsontype.Array},
165+
bsonx.Arr{},
166+
bsoncore.NewInsufficientBytesError(nil, nil),
167+
},
168+
{
169+
"bsoncodec.ValueMarshaler/success",
170+
bvMarsh{t: bsontype.Array, data: arr},
171+
bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}})},
172+
nil,
173+
},
174+
{
175+
"nil",
176+
nil,
177+
bsonx.Arr{},
178+
errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid"),
179+
},
180+
{
181+
"not array or slice",
182+
int64(42),
183+
bsonx.Arr{},
184+
errors.New("can only transform slices and arrays into aggregation pipelines, but got int64"),
185+
},
186+
{
187+
"array/error",
188+
[1]interface{}{int64(42)},
189+
bsonx.Arr{},
190+
MarshalError{Value: int64(0)},
191+
},
192+
{
193+
"array/success",
194+
[1]interface{}{primitive.D{{"$limit", int64(12345)}}},
195+
bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int64(12345)}})},
196+
nil,
197+
},
198+
{
199+
"slice/error",
200+
[]interface{}{int64(42)},
201+
bsonx.Arr{},
202+
MarshalError{Value: int64(0)},
203+
},
204+
{
205+
"slice/success",
206+
[]interface{}{primitive.D{{"$limit", int64(12345)}}},
207+
bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int64(12345)}})},
208+
nil,
209+
},
210+
}
211+
212+
for _, tc := range testCases {
213+
t.Run(tc.name, func(t *testing.T) {
214+
arr, err := transformAggregatePipeline(bson.NewRegistryBuilder().Build(), tc.pipeline)
215+
if !cmp.Equal(err, tc.err, cmp.Comparer(compareErrors)) {
216+
t.Errorf("Error does not match expected error. got %v; want %v", err, tc.err)
217+
}
218+
if !cmp.Equal(arr, tc.arr, cmp.AllowUnexported(bsonx.Val{})) {
219+
t.Errorf("Returned array does not match expected array. got %v; want %v", arr, tc.arr)
220+
}
221+
})
222+
}
223+
}
224+
65225
func compareErrors(err1, err2 error) bool {
66226
if err1 == nil && err2 == nil {
67227
return true
@@ -91,3 +251,15 @@ func (b bMarsh) MarshalBSON() ([]byte, error) {
91251
type reflectStruct struct {
92252
Foo string
93253
}
254+
255+
var _ bsoncodec.ValueMarshaler = bvMarsh{}
256+
257+
type bvMarsh struct {
258+
t bsontype.Type
259+
data []byte
260+
err error
261+
}
262+
263+
func (b bvMarsh) MarshalBSONValue() (bsontype.Type, []byte, error) {
264+
return b.t, b.data, b.err
265+
}

mongo/sessions_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ type CollFunction struct {
5757

5858
var ctx = context.Background()
5959
var emptyDoc = bsonx.Doc{}
60+
var emptyArr = bsonx.Arr{}
6061
var updateDoc = bsonx.Doc{{"$inc", bsonx.Document(bsonx.Doc{{"x", bsonx.Int32(1)}})}}
6162
var doc = bsonx.Doc{{"x", bsonx.Int32(1)}}
6263
var doc2 = bsonx.Doc{{"y", bsonx.Int32(1)}}
@@ -102,7 +103,7 @@ func createFuncMap(t *testing.T, dbName string, collName string, monitored bool)
102103
{"UpdateOne", coll, nil, func(mctx SessionContext) error { _, err := coll.UpdateOne(mctx, emptyDoc, updateDoc); return err }},
103104
{"UpdateMany", coll, nil, func(mctx SessionContext) error { _, err := coll.UpdateMany(mctx, emptyDoc, updateDoc); return err }},
104105
{"ReplaceOne", coll, nil, func(mctx SessionContext) error { _, err := coll.ReplaceOne(mctx, emptyDoc, emptyDoc); return err }},
105-
{"Aggregate", coll, nil, func(mctx SessionContext) error { _, err := coll.Aggregate(mctx, emptyDoc); return err }},
106+
{"Aggregate", coll, nil, func(mctx SessionContext) error { _, err := coll.Aggregate(mctx, emptyArr); return err }},
106107
{"Count", coll, nil, func(mctx SessionContext) error { _, err := coll.Count(mctx, emptyDoc); return err }},
107108
{"Distinct", coll, nil, func(mctx SessionContext) error { _, err := coll.Distinct(mctx, "field", emptyDoc); return err }},
108109
{"Find", coll, nil, func(mctx SessionContext) error { _, err := coll.Find(mctx, emptyDoc); return err }},
@@ -386,7 +387,7 @@ func TestSessions(t *testing.T) {
386387
}{
387388
{"ServerStatus", reflect.ValueOf(db.RunCommand), []interface{}{ctx, serverStatusDoc}, []interface{}{ctx, serverStatusDoc}},
388389
{"InsertOne", reflect.ValueOf(coll.InsertOne), []interface{}{ctx, doc}, []interface{}{ctx, doc2}},
389-
{"Aggregate", reflect.ValueOf(coll.Aggregate), []interface{}{ctx, emptyDoc}, []interface{}{ctx, emptyDoc}},
390+
{"Aggregate", reflect.ValueOf(coll.Aggregate), []interface{}{ctx, emptyArr}, []interface{}{ctx, emptyArr}},
390391
{"Find", reflect.ValueOf(coll.Find), []interface{}{ctx, emptyDoc}, []interface{}{ctx, emptyDoc}},
391392
}
392393

0 commit comments

Comments
 (0)