Skip to content

Commit c4dde16

Browse files
authored
GODRIVER-1782 correctly mark aggregation pipeline with $out (#537)
1 parent 9e2aca8 commit c4dde16

File tree

2 files changed

+65
-31
lines changed

2 files changed

+65
-31
lines changed

mongo/integration/crud_prose_test.go

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"testing"
1313

1414
"go.mongodb.org/mongo-driver/bson"
15+
"go.mongodb.org/mongo-driver/bson/bsontype"
1516
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1617
"go.mongodb.org/mongo-driver/mongo"
1718
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
@@ -145,6 +146,14 @@ func TestHintErrors(t *testing.T) {
145146
})
146147
}
147148

149+
type testValueMarshaler struct {
150+
val []bson.D
151+
}
152+
153+
func (tvm testValueMarshaler) MarshalBSONValue() (bsontype.Type, []byte, error) {
154+
return bson.MarshalValue(tvm.val)
155+
}
156+
148157
func TestAggregatePrimaryPreferredReadPreference(t *testing.T) {
149158
primaryPrefClientOpts := options.Client().
150159
SetWriteConcern(mtest.MajorityWc).
@@ -155,39 +164,62 @@ func TestAggregatePrimaryPreferredReadPreference(t *testing.T) {
155164
MinServerVersion("4.1.0") // Consistent with tests in aggregate-out-readConcern.json
156165

157166
mt := mtest.New(t, mtOpts)
158-
mt.Run("aggregate $out with non-primary read ppreference", func(mt *mtest.T) {
167+
mt.Run("aggregate $out with non-primary read preference", func(mt *mtest.T) {
159168
doc, err := bson.Marshal(bson.D{
160169
{"_id", 1},
161170
{"x", 11},
162171
})
163172
assert.Nil(mt, err, "Marshal error: %v", err)
164-
_, err = mt.Coll.InsertOne(mtest.Background, doc)
165-
assert.Nil(mt, err, "InsertOne error: %v", err)
166-
167-
mt.ClearEvents()
168173
outputCollName := "aggregate-read-pref-primary-preferred-output"
169-
outStage := bson.D{
170-
{"$out", outputCollName},
174+
testCases := []struct {
175+
name string
176+
pipeline interface{}
177+
}{
178+
{
179+
"pipeline",
180+
mongo.Pipeline{bson.D{{"$out", outputCollName}}},
181+
},
182+
{
183+
"doc slice",
184+
[]bson.D{{{"$out", outputCollName}}},
185+
},
186+
{
187+
"bson a",
188+
bson.A{bson.D{{"$out", outputCollName}}},
189+
},
190+
{
191+
"valueMarshaler",
192+
testValueMarshaler{[]bson.D{{{"$out", outputCollName}}}},
193+
},
194+
}
195+
for _, tc := range testCases {
196+
mt.Run(tc.name, func(mt *mtest.T) {
197+
_, err = mt.Coll.InsertOne(mtest.Background, doc)
198+
assert.Nil(mt, err, "InsertOne error: %v", err)
199+
200+
mt.ClearEvents()
201+
202+
cursor, err := mt.Coll.Aggregate(mtest.Background, tc.pipeline)
203+
assert.Nil(mt, err, "Aggregate error: %v", err)
204+
_ = cursor.Close(mtest.Background)
205+
206+
// Assert that the output collection contains the document we expect.
207+
outputColl := mt.CreateCollection(mtest.Collection{Name: outputCollName}, false)
208+
cursor, err = outputColl.Find(mtest.Background, bson.D{})
209+
assert.Nil(mt, err, "Find error: %v", err)
210+
defer cursor.Close(mtest.Background)
211+
212+
assert.True(mt, cursor.Next(mtest.Background), "expected Next to return true, got false")
213+
assert.True(mt, bytes.Equal(doc, cursor.Current), "expected document %s, got %s", bson.Raw(doc), cursor.Current)
214+
assert.False(mt, cursor.Next(mtest.Background), "unexpected document returned by Find: %s", cursor.Current)
215+
216+
// Assert that no read preference was sent to the server.
217+
evt := mt.GetStartedEvent()
218+
assert.Equal(mt, "aggregate", evt.CommandName, "expected command 'aggregate', got '%s'", evt.CommandName)
219+
_, err = evt.Command.LookupErr("$readPreference")
220+
assert.NotNil(mt, err, "expected command %s to not contain $readPreference", evt.Command)
221+
})
171222
}
172-
cursor, err := mt.Coll.Aggregate(mtest.Background, mongo.Pipeline{outStage})
173-
assert.Nil(mt, err, "Aggregate error: %v", err)
174-
_ = cursor.Close(mtest.Background)
175-
176-
// Assert that the output collection contains the document we expect.
177-
outputColl := mt.CreateCollection(mtest.Collection{Name: outputCollName}, false)
178-
cursor, err = outputColl.Find(mtest.Background, bson.D{})
179-
assert.Nil(mt, err, "Find error: %v", err)
180-
defer cursor.Close(mtest.Background)
181-
182-
assert.True(mt, cursor.Next(mtest.Background), "expected Next to return true, got false")
183-
assert.True(mt, bytes.Equal(doc, cursor.Current), "expected document %s, got %s", bson.Raw(doc), cursor.Current)
184-
assert.False(mt, cursor.Next(mtest.Background), "unexpected document returned by Find: %s", cursor.Current)
185-
186-
// Assert that no read preference was sent to the server.
187-
evt := mt.GetStartedEvent()
188-
assert.Equal(mt, "aggregate", evt.CommandName, "expected command 'aggregate', got '%s'", evt.CommandName)
189-
_, err = evt.Command.LookupErr("$readPreference")
190-
assert.NotNil(mt, err, "expected command %s to not contain $readPreference", evt.Command)
191223
})
192224
}
193225

mongo/mongo.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,13 @@ func transformAggregatePipelinev2(registry *bsoncodec.Registry, pipeline interfa
306306

307307
var hasOutputStage bool
308308
pipelineDoc := bsoncore.Document(val)
309-
if _, err := pipelineDoc.LookupErr("$out"); err == nil {
310-
hasOutputStage = true
311-
}
312-
if _, err := pipelineDoc.LookupErr("$merge"); err == nil {
313-
hasOutputStage = true
309+
values, _ := pipelineDoc.Values()
310+
if pipelineLen := len(values); pipelineLen > 0 {
311+
if finalDoc, ok := values[pipelineLen-1].DocumentOK(); ok {
312+
if elem, err := finalDoc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
313+
hasOutputStage = true
314+
}
315+
}
314316
}
315317

316318
return pipelineDoc, hasOutputStage, nil

0 commit comments

Comments
 (0)