Skip to content

Commit 6d53408

Browse files
author
Divjot Arora
committed
Fix aggregate maxTimeMS integration test.
GODRIVER-813 Change-Id: Icdaba2d0768109ca17aba7f4dca1c69d6fc6606c
1 parent 2cd4b88 commit 6d53408

File tree

4 files changed

+181
-182
lines changed

4 files changed

+181
-182
lines changed

internal/testutil/config.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ package testutil
99
import (
1010
"context"
1111
"fmt"
12+
"math"
1213
"os"
1314
"reflect"
15+
"strconv"
1416
"strings"
1517
"sync"
1618
"testing"
@@ -283,3 +285,29 @@ func Integration(t *testing.T) {
283285
t.Skip("skipping integration test in short mode")
284286
}
285287
}
288+
289+
// compareVersions compares two version number strings (i.e. positive integers separated by
290+
// periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is
291+
// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11.
292+
//
293+
// Returns a positive int if version1 is greater than version2, a negative int if version1 is less
294+
// than version2, and 0 if version1 is equal to version2.
295+
func CompareVersions(t *testing.T, v1 string, v2 string) int {
296+
n1 := strings.Split(v1, ".")
297+
n2 := strings.Split(v2, ".")
298+
299+
for i := 0; i < int(math.Min(float64(len(n1)), float64(len(n2)))); i++ {
300+
i1, err := strconv.Atoi(n1[i])
301+
require.NoError(t, err)
302+
303+
i2, err := strconv.Atoi(n2[i])
304+
require.NoError(t, err)
305+
306+
difference := i1 - i2
307+
if difference != 0 {
308+
return difference
309+
}
310+
}
311+
312+
return 0
313+
}
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package integration
8+
9+
import (
10+
"context"
11+
"github.com/mongodb/mongo-go-driver/bson"
12+
"github.com/mongodb/mongo-go-driver/event"
13+
"github.com/mongodb/mongo-go-driver/internal/testutil"
14+
"github.com/mongodb/mongo-go-driver/mongo/options"
15+
"github.com/mongodb/mongo-go-driver/mongo/readpref"
16+
"github.com/mongodb/mongo-go-driver/x/bsonx"
17+
"github.com/mongodb/mongo-go-driver/x/mongo/driver"
18+
"github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
19+
"github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
20+
"github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid"
21+
"github.com/mongodb/mongo-go-driver/x/network/command"
22+
"github.com/mongodb/mongo-go-driver/x/network/description"
23+
"github.com/stretchr/testify/require"
24+
"testing"
25+
"time"
26+
)
27+
28+
func setUpMonitor() (*event.CommandMonitor, chan *event.CommandStartedEvent, chan *event.CommandSucceededEvent, chan *event.CommandFailedEvent) {
29+
started := make(chan *event.CommandStartedEvent, 1)
30+
succeeded := make(chan *event.CommandSucceededEvent, 1)
31+
failed := make(chan *event.CommandFailedEvent, 1)
32+
33+
return &event.CommandMonitor{
34+
Started: func(ctx context.Context, e *event.CommandStartedEvent) {
35+
started <- e
36+
},
37+
Succeeded: func(ctx context.Context, e *event.CommandSucceededEvent) {
38+
succeeded <- e
39+
},
40+
Failed: func(ctx context.Context, e *event.CommandFailedEvent) {
41+
failed <- e
42+
},
43+
}, started, succeeded, failed
44+
}
45+
46+
func skipIfBelow32(ctx context.Context, t *testing.T, topo *topology.Topology) {
47+
server, err := topo.SelectServer(ctx, description.WriteSelector())
48+
noerr(t, err)
49+
50+
versionCmd := bsonx.Doc{{"serverStatus", bsonx.Int32(1)}}
51+
serverStatus, err := testutil.RunCommand(t, server.Server, dbName, versionCmd)
52+
version, err := serverStatus.LookupErr("version")
53+
54+
if testutil.CompareVersions(t, version.StringValue(), "3.2") < 0 {
55+
t.Skip()
56+
}
57+
}
58+
59+
func TestAggregate(t *testing.T) {
60+
t.Run("TestMaxTimeMSInGetMore", func(t *testing.T) {
61+
ctx := context.Background()
62+
monitor, started, succeeded, failed := setUpMonitor()
63+
dbName := "TestAggMaxTimeDB"
64+
collName := "TestAggMaxTimeColl"
65+
top := testutil.MonitoredTopology(t, dbName, monitor)
66+
clearChannels(started, succeeded, failed)
67+
skipIfBelow32(ctx, t, top)
68+
69+
clientID, err := uuid.New()
70+
noerr(t, err)
71+
72+
ns := command.Namespace{
73+
DB: dbName,
74+
Collection: collName,
75+
}
76+
pool := &session.Pool{}
77+
78+
clearChannels(started, succeeded, failed)
79+
_, err = driver.Insert(
80+
ctx,
81+
command.Insert{
82+
NS: ns,
83+
Docs: []bsonx.Doc{
84+
{{"x", bsonx.Int32(1)}},
85+
{{"x", bsonx.Int32(1)}},
86+
{{"x", bsonx.Int32(1)}},
87+
},
88+
},
89+
top,
90+
description.WriteSelector(),
91+
clientID,
92+
pool,
93+
false,
94+
)
95+
noerr(t, err)
96+
97+
clearChannels(started, succeeded, failed)
98+
cmd := command.Aggregate{
99+
NS: ns,
100+
Pipeline: bsonx.Arr{},
101+
}
102+
batchCursor, err := driver.Aggregate(
103+
ctx,
104+
cmd,
105+
top,
106+
description.ReadPrefSelector(readpref.Primary()),
107+
description.WriteSelector(),
108+
clientID,
109+
pool,
110+
bson.DefaultRegistry,
111+
options.Aggregate().SetMaxAwaitTime(10*time.Millisecond).SetBatchSize(2),
112+
)
113+
noerr(t, err)
114+
115+
var e *event.CommandStartedEvent
116+
select {
117+
case e = <-started:
118+
case <-time.After(200 * time.Millisecond):
119+
t.Fatal("timed out waiting for aggregate")
120+
}
121+
122+
require.Equal(t, "aggregate", e.CommandName)
123+
124+
clearChannels(started, succeeded, failed)
125+
// first Next() should automatically return true
126+
require.True(t, batchCursor.Next(ctx), "expected true from first Next, got false")
127+
clearChannels(started, succeeded, failed)
128+
batchCursor.Next(ctx) // should do getMore
129+
130+
select {
131+
case e = <-started:
132+
case <-time.After(200 * time.Millisecond):
133+
t.Fatal("timed out waiting for getMore")
134+
}
135+
require.Equal(t, "getMore", e.CommandName)
136+
_, err = e.Command.LookupErr("maxTimeMS")
137+
noerr(t, err)
138+
})
139+
}
140+
141+
func clearChannels(s chan *event.CommandStartedEvent, succ chan *event.CommandSucceededEvent, f chan *event.CommandFailedEvent) {
142+
for len(s) > 0 {
143+
<-s
144+
}
145+
for len(succ) > 0 {
146+
<-succ
147+
}
148+
for len(f) > 0 {
149+
<-f
150+
}
151+
}

x/network/integration/aggregate_test.go

Lines changed: 0 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,17 @@ package integration
88

99
import (
1010
"context"
11-
"fmt"
12-
"os"
1311
"strings"
1412
"testing"
15-
"time"
1613

1714
"github.com/mongodb/mongo-go-driver/bson"
1815
"github.com/mongodb/mongo-go-driver/internal/testutil"
19-
"github.com/mongodb/mongo-go-driver/internal/testutil/israce"
2016
"github.com/mongodb/mongo-go-driver/mongo/writeconcern"
2117
"github.com/mongodb/mongo-go-driver/x/bsonx"
22-
"github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
23-
"github.com/mongodb/mongo-go-driver/x/mongo/driver"
2418
"github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
2519
"github.com/mongodb/mongo-go-driver/x/network/address"
2620
"github.com/mongodb/mongo-go-driver/x/network/command"
2721
"github.com/mongodb/mongo-go-driver/x/network/description"
28-
"github.com/stretchr/testify/assert"
2922
)
3023

3124
func TestCommandAggregate(t *testing.T) {
@@ -158,146 +151,3 @@ func TestCommandAggregate(t *testing.T) {
158151
noerr(t, err)
159152
})
160153
}
161-
162-
func TestAggregatePassesMaxAwaitTimeMSThroughToGetMore(t *testing.T) {
163-
if os.Getenv("TOPOLOGY") != "replica_set" {
164-
t.Skip()
165-
}
166-
167-
startedChan, succeededChan, failedChan, monitor := initMonitor()
168-
169-
dbName := fmt.Sprintf("mongo-go-driver-%d-agg", os.Getpid())
170-
colName := testutil.ColName(t)
171-
172-
server, err := testutil.MonitoredTopology(t, dbName, monitor).SelectServer(context.Background(), description.WriteSelector())
173-
noerr(t, err)
174-
175-
versionCmd := bsonx.Doc{{"serverStatus", bsonx.Int32(1)}}
176-
serverStatus, err := testutil.RunCommand(t, server.Server, dbName, versionCmd)
177-
version, err := serverStatus.LookupErr("version")
178-
179-
if compareVersions(t, version.StringValue(), "3.6") < 0 {
180-
t.Skip()
181-
}
182-
183-
// create capped collection
184-
createCmd := bsonx.Doc{
185-
{"create", bsonx.String(colName)},
186-
{"capped", bsonx.Boolean(true)},
187-
{"size", bsonx.Int32(1000)}}
188-
_, err = testutil.RunCommand(t, server.Server, dbName, createCmd)
189-
noerr(t, err)
190-
191-
conn, err := server.Connection(context.Background())
192-
noerr(t, err)
193-
194-
// create an aggregate command that results with a TAILABLEAWAIT cursor
195-
result, err := (&command.Aggregate{
196-
NS: command.Namespace{DB: dbName, Collection: testutil.ColName(t)},
197-
Pipeline: bsonx.Arr{
198-
bsonx.Document(bsonx.Doc{
199-
{"$changeStream", bsonx.Document(bsonx.Doc{})}}),
200-
bsonx.Document(bsonx.Doc{
201-
{"$match", bsonx.Document(bsonx.Doc{
202-
{"fullDocument._id", bsonx.Document(bsonx.Doc{{"$gte", bsonx.Int32(1)}})},
203-
})}})},
204-
Opts: []bsonx.Elem{{"batchSize", bsonx.Int32(2)}},
205-
CursorOpts: []bsonx.Elem{
206-
{"batchSize", bsonx.Int32(2)},
207-
{"maxTimeMS", bsonx.Int64(50)},
208-
},
209-
}).RoundTrip(context.Background(), server.SelectedDescription(), conn)
210-
noerr(t, err)
211-
212-
cursor, err := driver.NewBatchCursor(
213-
bsoncore.Document(result), nil, nil, server.Server,
214-
bsonx.Elem{"batchSize", bsonx.Int32(2)}, bsonx.Elem{"maxTimeMS", bsonx.Int64(50)},
215-
)
216-
noerr(t, err)
217-
218-
// insert some documents
219-
insertCmd := bsonx.Doc{
220-
{"insert", bsonx.String(colName)},
221-
{"documents", bsonx.Array(bsonx.Arr{
222-
bsonx.Document(bsonx.Doc{{"_id", bsonx.Int32(1)}}),
223-
bsonx.Document(bsonx.Doc{{"_id", bsonx.Int32(2)}}),
224-
bsonx.Document(bsonx.Doc{{"_id", bsonx.Int32(3)}})})}}
225-
_, err = testutil.RunCommand(t, server.Server, dbName, insertCmd)
226-
227-
// wait a bit between insert and getMore commands
228-
time.Sleep(time.Millisecond * 100)
229-
if israce.Enabled {
230-
time.Sleep(time.Millisecond * 400) // wait a little longer when race detector is enabled.
231-
}
232-
233-
ctx, cancel := context.WithCancel(context.Background())
234-
if israce.Enabled {
235-
time.AfterFunc(time.Millisecond*2000, cancel)
236-
} else {
237-
time.AfterFunc(time.Millisecond*900, cancel)
238-
}
239-
for cursor.Next(ctx) {
240-
}
241-
242-
// allow for iteration over range chan
243-
close(startedChan)
244-
close(succeededChan)
245-
close(failedChan)
246-
247-
// no commands should have failed
248-
if len(failedChan) != 0 {
249-
t.Errorf("%d command(s) failed", len(failedChan))
250-
}
251-
252-
// check the expected commands were started
253-
for started := range startedChan {
254-
switch started.CommandName {
255-
case "aggregate":
256-
assert.Equal(t, 2, int(started.Command.Lookup("cursor", "batchSize").Int32()))
257-
assert.Equal(t, started.Command.Lookup("maxAwaitTimeMS"), bson.RawValue{},
258-
"Should not have sent maxAwaitTimeMS in find command")
259-
case "getMore":
260-
assert.Equal(t, 2, int(started.Command.Lookup("batchSize").Int32()))
261-
assert.Equal(t, 50, int(started.Command.Lookup("maxTimeMS").Int64()),
262-
"Should have sent maxTimeMS in getMore command")
263-
default:
264-
continue
265-
}
266-
}
267-
268-
// to keep track of seen documents
269-
id := 1
270-
271-
// check expected commands succeeded
272-
for succeeded := range succeededChan {
273-
switch succeeded.CommandName {
274-
case "aggregate":
275-
assert.Equal(t, 1, int(succeeded.Reply.Lookup("ok").Double()))
276-
277-
actual, err := succeeded.Reply.Lookup("cursor", "firstBatch").Array().Values()
278-
assert.NoError(t, err)
279-
280-
for _, v := range actual {
281-
assert.Equal(t, id, int(v.Document().Lookup("fullDocument", "_id").Int32()))
282-
id++
283-
}
284-
case "getMore":
285-
assert.Equal(t, "getMore", succeeded.CommandName)
286-
assert.Equal(t, 1, int(succeeded.Reply.Lookup("ok").Double()))
287-
288-
actual, err := succeeded.Reply.Lookup("cursor", "nextBatch").Array().Values()
289-
assert.NoError(t, err)
290-
291-
for _, v := range actual {
292-
assert.Equal(t, id, int(v.Document().Lookup("fullDocument", "_id").Int32()))
293-
id++
294-
}
295-
default:
296-
continue
297-
}
298-
}
299-
300-
if id <= 3 {
301-
t.Errorf("not all documents returned; last seen id = %d", id-1)
302-
}
303-
}

0 commit comments

Comments
 (0)