Skip to content

Commit b986070

Browse files
rfblue2Roland Fong
authored andcommitted
Allow passing aggregateopt to Aggregate
GODRIVER-272 Change-Id: Iec3a4446f5e6624510ddd1e75f498c2a64e5d929
1 parent 7c88c5c commit b986070

File tree

4 files changed

+33
-25
lines changed

4 files changed

+33
-25
lines changed

mongo/aggregateopt/aggregateopt.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ var aggregateBundle = new(AggregateBundle)
1313
// Aggregate is options for the aggregate() function
1414
type Aggregate interface {
1515
aggregate()
16-
ConvertOption() option.Optioner
16+
ConvertOption() option.AggregateOptioner
1717
}
1818

1919
// AggregateBundle is a bundle of Aggregate options
@@ -26,7 +26,7 @@ type AggregateBundle struct {
2626
func (ab *AggregateBundle) aggregate() {}
2727

2828
// ConvertOption implements the Aggregate interface
29-
func (ab *AggregateBundle) ConvertOption() option.Optioner { return nil }
29+
func (ab *AggregateBundle) ConvertOption() option.AggregateOptioner { return nil }
3030

3131
// BundleAggregate bundles Aggregate options
3232
func BundleAggregate(opts ...Aggregate) *AggregateBundle {
@@ -135,7 +135,7 @@ func (ab *AggregateBundle) bundleLength() int {
135135
}
136136

137137
// Unbundle transforms a bundle into a slice of options, optionally deduplicating
138-
func (ab *AggregateBundle) Unbundle(deduplicate bool) ([]option.Optioner, error) {
138+
func (ab *AggregateBundle) Unbundle(deduplicate bool) ([]option.AggregateOptioner, error) {
139139

140140
options, err := ab.unbundle()
141141
if err != nil {
@@ -166,14 +166,14 @@ func (ab *AggregateBundle) Unbundle(deduplicate bool) ([]option.Optioner, error)
166166
}
167167

168168
// Helper that recursively unwraps bundle into slice of options
169-
func (ab *AggregateBundle) unbundle() ([]option.Optioner, error) {
169+
func (ab *AggregateBundle) unbundle() ([]option.AggregateOptioner, error) {
170170
if ab == nil {
171171
return nil, nil
172172
}
173173

174174
listLen := ab.bundleLength()
175175

176-
options := make([]option.Optioner, listLen)
176+
options := make([]option.AggregateOptioner, listLen)
177177
index := listLen - 1
178178

179179
for listHead := ab; listHead != nil && listHead.option != nil; listHead = listHead.next {
@@ -263,7 +263,7 @@ type OptAllowDiskUse option.OptAllowDiskUse
263263
func (OptAllowDiskUse) aggregate() {}
264264

265265
// ConvertOption implements the Aggregate interface
266-
func (opt OptAllowDiskUse) ConvertOption() option.Optioner {
266+
func (opt OptAllowDiskUse) ConvertOption() option.AggregateOptioner {
267267
return option.OptAllowDiskUse(opt)
268268
}
269269

@@ -273,15 +273,15 @@ type OptBatchSize option.OptBatchSize
273273
func (OptBatchSize) aggregate() {}
274274

275275
// ConvertOption implements the Aggregate interface
276-
func (opt OptBatchSize) ConvertOption() option.Optioner {
276+
func (opt OptBatchSize) ConvertOption() option.AggregateOptioner {
277277
return option.OptBatchSize(opt)
278278
}
279279

280280
// OptBypassDocumentValidation allows the write to opt-out of document-level validation.
281281
type OptBypassDocumentValidation option.OptBypassDocumentValidation
282282

283283
// ConvertOption implements the Aggregate interface
284-
func (opt OptBypassDocumentValidation) ConvertOption() option.Optioner {
284+
func (opt OptBypassDocumentValidation) ConvertOption() option.AggregateOptioner {
285285
return option.OptBypassDocumentValidation(opt)
286286
}
287287

@@ -293,7 +293,7 @@ type OptCollation option.OptCollation
293293
func (OptCollation) aggregate() {}
294294

295295
// ConvertOption implements the Aggregate interface
296-
func (opt OptCollation) ConvertOption() option.Optioner {
296+
func (opt OptCollation) ConvertOption() option.AggregateOptioner {
297297
return option.OptCollation(opt)
298298
}
299299

@@ -303,7 +303,7 @@ type OptMaxTime option.OptMaxTime
303303
func (OptMaxTime) aggregate() {}
304304

305305
// ConvertOption implements the Aggregate interface
306-
func (opt OptMaxTime) ConvertOption() option.Optioner {
306+
func (opt OptMaxTime) ConvertOption() option.AggregateOptioner {
307307
return option.OptMaxTime(opt)
308308
}
309309

@@ -313,7 +313,7 @@ type OptComment option.OptComment
313313
func (OptComment) aggregate() {}
314314

315315
// ConvertOption implements the Aggregate interface
316-
func (opt OptComment) ConvertOption() option.Optioner {
316+
func (opt OptComment) ConvertOption() option.AggregateOptioner {
317317
return option.OptComment(opt)
318318
}
319319

@@ -323,6 +323,6 @@ type OptHint option.OptHint
323323
func (OptHint) aggregate() {}
324324

325325
// ConvertOption implements the Aggregate interface
326-
func (opt OptHint) ConvertOption() option.Optioner {
326+
func (opt OptHint) ConvertOption() option.AggregateOptioner {
327327
return option.OptHint(opt)
328328
}

mongo/collection.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/mongodb/mongo-go-driver/core/readconcern"
2020
"github.com/mongodb/mongo-go-driver/core/readpref"
2121
"github.com/mongodb/mongo-go-driver/core/writeconcern"
22+
"github.com/mongodb/mongo-go-driver/mongo/aggregateopt"
2223
)
2324

2425
// Collection performs operations on a given collection.
@@ -419,7 +420,7 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{},
419420
// *bson.Document. See TransformDocument for the list of valid types for
420421
// pipeline.
421422
func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{},
422-
opts ...option.AggregateOptioner) (Cursor, error) {
423+
opts ...aggregateopt.Aggregate) (Cursor, error) {
423424

424425
if ctx == nil {
425426
ctx = context.Background()
@@ -430,11 +431,17 @@ func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{},
430431
return nil, err
431432
}
432433

434+
// convert options into []option.Optioner and dedup
435+
aggOpts, err := aggregateopt.BundleAggregate(opts...).Unbundle(true)
436+
if err != nil {
437+
return nil, err
438+
}
439+
433440
oldns := coll.namespace()
434441
cmd := command.Aggregate{
435442
NS: command.Namespace{DB: oldns.DB, Collection: oldns.Collection},
436443
Pipeline: pipelineArr,
437-
Opts: opts,
444+
Opts: aggOpts,
438445
ReadPref: coll.readPreference,
439446
}
440447
return dispatch.Aggregate(ctx, cmd, coll.client.topology, coll.readSelector, coll.writeSelector, coll.writeConcern)

mongo/collection_internal_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/mongodb/mongo-go-driver/core/option"
1818
"github.com/mongodb/mongo-go-driver/core/writeconcern"
1919
"github.com/mongodb/mongo-go-driver/internal/testutil"
20+
"github.com/mongodb/mongo-go-driver/mongo/aggregateopt"
2021
"github.com/stretchr/testify/assert"
2122
"github.com/stretchr/testify/require"
2223
)
@@ -1005,7 +1006,7 @@ func TestCollection_Aggregate(t *testing.T) {
10051006
),
10061007
))
10071008

1008-
cursor, err := coll.Aggregate(context.Background(), pipeline)
1009+
cursor, err := coll.Aggregate(context.Background(), pipeline, aggregateopt.BundleAggregate())
10091010
require.Nil(t, err)
10101011

10111012
for i := 2; i < 5; i++ {
@@ -1025,7 +1026,7 @@ func TestCollection_Aggregate(t *testing.T) {
10251026
}
10261027
}
10271028

1028-
func testAggregateWithOptions(t *testing.T, createIndex bool, option option.AggregateOptioner) error {
1029+
func testAggregateWithOptions(t *testing.T, createIndex bool, opts aggregateopt.Aggregate) error {
10291030
coll := createTestCollection(t, nil, nil)
10301031
initCollection(t, coll)
10311032

@@ -1064,7 +1065,7 @@ func testAggregateWithOptions(t *testing.T, createIndex bool, option option.Aggr
10641065
),
10651066
))
10661067

1067-
cursor, err := coll.Aggregate(context.Background(), pipeline, option)
1068+
cursor, err := coll.Aggregate(context.Background(), pipeline, opts)
10681069
if err != nil {
10691070
return err
10701071
}
@@ -1107,10 +1108,9 @@ func TestCollection_Aggregate_IndexHint(t *testing.T) {
11071108

11081109
t.Parallel()
11091110

1110-
hint, err := Opt.Hint(bson.NewDocument(bson.EC.Int32("x", 1)))
1111-
require.NoError(t, err)
1111+
hint := aggregateopt.Hint(bson.NewDocument(bson.EC.Int32("x", 1)))
11121112

1113-
err = testAggregateWithOptions(t, true, hint)
1113+
err := testAggregateWithOptions(t, true, hint)
11141114
require.NoError(t, err)
11151115
}
11161116

@@ -1121,7 +1121,7 @@ func TestCollection_Aggregate_withOptions(t *testing.T) {
11211121

11221122
t.Parallel()
11231123

1124-
err := testAggregateWithOptions(t, false, Opt.AllowDiskUse(true))
1124+
err := testAggregateWithOptions(t, false, aggregateopt.AllowDiskUse(true))
11251125
require.NoError(t, err)
11261126
}
11271127

mongo/crud_spec_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/mongodb/mongo-go-driver/bson"
1515
"github.com/mongodb/mongo-go-driver/core/option"
1616
"github.com/mongodb/mongo-go-driver/internal/testutil/helpers"
17+
"github.com/mongodb/mongo-go-driver/mongo/aggregateopt"
1718
"github.com/stretchr/testify/require"
1819
)
1920

@@ -186,14 +187,14 @@ func aggregateTest(t *testing.T, db *Database, coll *Collection, test *testCase)
186187
t.Run(test.Description, func(t *testing.T) {
187188
pipeline := test.Operation.Arguments["pipeline"].([]interface{})
188189

189-
var opts []option.AggregateOptioner
190+
opts := aggregateopt.BundleAggregate()
190191

191192
if batchSize, found := test.Operation.Arguments["batchSize"]; found {
192-
opts = append(opts, Opt.BatchSize(int32(batchSize.(float64))))
193+
opts = opts.BatchSize(int32(batchSize.(float64)))
193194
}
194195

195196
if collation, found := test.Operation.Arguments["collation"]; found {
196-
opts = append(opts, Opt.Collation(collationFromMap(collation.(map[string]interface{}))))
197+
opts = opts.Collation(*collationFromMap(collation.(map[string]interface{})))
197198
}
198199

199200
out := false
@@ -203,7 +204,7 @@ func aggregateTest(t *testing.T, db *Database, coll *Collection, test *testCase)
203204
}
204205
}
205206

206-
cursor, err := coll.Aggregate(context.Background(), pipeline, opts...)
207+
cursor, err := coll.Aggregate(context.Background(), pipeline, opts)
207208
require.NoError(t, err)
208209

209210
if !out {

0 commit comments

Comments
 (0)