Skip to content

Commit 9eea702

Browse files
DRIVERS-3473 Add client-side validation for maxAwaitTime+op timeout
1 parent 32c7b85 commit 9eea702

File tree

7 files changed

+135
-141
lines changed

7 files changed

+135
-141
lines changed

internal/mongoutil/mongoutil.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
package mongoutil
88

99
import (
10+
"context"
1011
"reflect"
12+
"time"
1113

1214
"go.mongodb.org/mongo-driver/v2/mongo/options"
1315
)
@@ -83,3 +85,23 @@ func HostsFromURI(uri string) ([]string, error) {
8385

8486
return opts.Hosts, nil
8587
}
88+
89+
// ValidMaxAwaitTimeMS will return "false" if maxAwaitTimeMS is set, timeoutMS
90+
// is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to
91+
// timeoutMS. Otherwise, the timeouts are valid.
92+
func ValidMaxAwaitTimeMS(ctx context.Context, timeout, maxAwaiTime *time.Duration) bool {
93+
if maxAwaiTime == nil {
94+
return true
95+
}
96+
97+
if deadline, ok := ctx.Deadline(); ok {
98+
ctxTimeout := time.Until(deadline)
99+
timeout = &ctxTimeout
100+
}
101+
102+
if timeout == nil {
103+
return true
104+
}
105+
106+
return *timeout <= 0 || *maxAwaiTime < *timeout
107+
}

internal/mongoutil/mongoutil_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
package mongoutil
88

99
import (
10+
"context"
1011
"strings"
1112
"testing"
13+
"time"
1214

15+
"go.mongodb.org/mongo-driver/v2/internal/assert"
1316
"go.mongodb.org/mongo-driver/v2/mongo/options"
1417
)
1518

@@ -32,3 +35,87 @@ func BenchmarkNewOptions(b *testing.B) {
3235
}
3336
})
3437
}
38+
39+
func TestValidChangeStreamTimeouts(t *testing.T) {
40+
t.Parallel()
41+
42+
newDurPtr := func(dur time.Duration) *time.Duration {
43+
return &dur
44+
}
45+
46+
tests := []struct {
47+
name string
48+
parent context.Context
49+
maxAwaitTimeout, timeout *time.Duration
50+
wantTimeout time.Duration
51+
want bool
52+
}{
53+
{
54+
name: "no context deadline and no timeouts",
55+
parent: context.Background(),
56+
maxAwaitTimeout: nil,
57+
timeout: nil,
58+
wantTimeout: 0,
59+
want: true,
60+
},
61+
{
62+
name: "no context deadline and maxAwaitTimeout",
63+
parent: context.Background(),
64+
maxAwaitTimeout: newDurPtr(1),
65+
timeout: nil,
66+
wantTimeout: 0,
67+
want: true,
68+
},
69+
{
70+
name: "no context deadline and timeout",
71+
parent: context.Background(),
72+
maxAwaitTimeout: nil,
73+
timeout: newDurPtr(1),
74+
wantTimeout: 0,
75+
want: true,
76+
},
77+
{
78+
name: "no context deadline and maxAwaitTime gt timeout",
79+
parent: context.Background(),
80+
maxAwaitTimeout: newDurPtr(2),
81+
timeout: newDurPtr(1),
82+
wantTimeout: 0,
83+
want: false,
84+
},
85+
{
86+
name: "no context deadline and maxAwaitTime lt timeout",
87+
parent: context.Background(),
88+
maxAwaitTimeout: newDurPtr(1),
89+
timeout: newDurPtr(2),
90+
wantTimeout: 0,
91+
want: true,
92+
},
93+
{
94+
name: "no context deadline and maxAwaitTime eq timeout",
95+
parent: context.Background(),
96+
maxAwaitTimeout: newDurPtr(1),
97+
timeout: newDurPtr(1),
98+
wantTimeout: 0,
99+
want: false,
100+
},
101+
{
102+
name: "no context deadline and maxAwaitTime with negative timeout",
103+
parent: context.Background(),
104+
maxAwaitTimeout: newDurPtr(1),
105+
timeout: newDurPtr(-1),
106+
wantTimeout: 0,
107+
want: true,
108+
},
109+
}
110+
111+
for _, test := range tests {
112+
test := test // Capture the range variable
113+
114+
t.Run(test.name, func(t *testing.T) {
115+
t.Parallel()
116+
117+
got := ValidMaxAwaitTimeMS(test.parent, test.timeout, test.maxAwaitTimeout)
118+
assert.Equal(t, test.want, got)
119+
})
120+
}
121+
}

internal/spectest/skip.go

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,6 @@ var skipTests = map[string][]string{
582582
"TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-operation-timeoutMS.json/timeoutMS_applied_to_withTransaction",
583583
"TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-timeoutMS.json",
584584
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_timeoutMode_is_cursor_lifetime",
585-
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS",
586585
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find",
587586
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set",
588587
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set",
@@ -817,19 +816,8 @@ var skipTests = map[string][]string{
817816
"TestUnifiedSpec/transactions-convenient-api/tests/unified/commit.json/withTransaction_commits_after_callback_returns",
818817
},
819818

820-
// GODRIVER-3473: the implementation of DRIVERS-2868 makes it clear that the
821-
// Go Driver does not correctly implement the following validation for
822-
// tailable awaitData cursors:
823-
//
824-
// Drivers MUST error if this option is set, timeoutMS is set to a
825-
// non-zero value, and maxAwaitTimeMS is greater than or equal to
826-
// timeoutMS.
827-
//
828-
// Once GODRIVER-3473 is completed, we can continue running these tests.
829-
"When constructing tailable awaitData cusors must validate, timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to timeoutMS (GODRIVER-3473)": {
819+
"Address CSOT Compliance Issue in for Timeout Handling in Cursor Constructors (GODRIVER-3480)": {
830820
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS",
831-
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
832-
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
833821
},
834822
}
835823

mongo/change_stream.go

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"fmt"
1313
"reflect"
1414
"strconv"
15-
"time"
1615

1716
"go.mongodb.org/mongo-driver/v2/bson"
1817
"go.mongodb.org/mongo-driver/v2/internal/csot"
@@ -103,33 +102,6 @@ type changeStreamConfig struct {
103102
crypt driver.Crypt
104103
}
105104

106-
// validChangeStreamTimeouts will return "false" if maxAwaitTimeMS is set,
107-
// timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or
108-
// equal to timeoutMS. Otherwise, the timeouts are valid.
109-
func validChangeStreamTimeouts(ctx context.Context, cs *ChangeStream) bool {
110-
if cs.options == nil || cs.client == nil {
111-
return true
112-
}
113-
114-
maxAwaitTime := cs.options.MaxAwaitTime
115-
timeout := cs.client.timeout
116-
117-
if maxAwaitTime == nil {
118-
return true
119-
}
120-
121-
if deadline, ok := ctx.Deadline(); ok {
122-
ctxTimeout := time.Until(deadline)
123-
timeout = &ctxTimeout
124-
}
125-
126-
if timeout == nil {
127-
return true
128-
}
129-
130-
return *timeout <= 0 || *maxAwaitTime < *timeout
131-
}
132-
133105
func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{},
134106
opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) {
135107
if ctx == nil {
@@ -145,6 +117,10 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in
145117
return nil, err
146118
}
147119

120+
if c := config.client; c != nil && !mongoutil.ValidMaxAwaitTimeMS(ctx, c.timeout, args.MaxAwaitTime) {
121+
return nil, fmt.Errorf("MaxAwaitTime must be less than the operation timeout")
122+
}
123+
148124
cs := &ChangeStream{
149125
client: config.client,
150126
bsonOpts: config.bsonOpts,
@@ -696,7 +672,11 @@ func (cs *ChangeStream) next(ctx context.Context, nonBlocking bool) bool {
696672
}
697673

698674
func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) {
699-
if !validChangeStreamTimeouts(ctx, cs) {
675+
// Sending a maxAwaitTimeMS option to the server that is less than or equal to
676+
// the operation timeout will result in a socket timeout error. This block
677+
// short-circuits that behavior.
678+
if csOpts := cs.options; csOpts != nil && cs.client != nil &&
679+
!mongoutil.ValidMaxAwaitTimeMS(ctx, cs.client.timeout, csOpts.MaxAwaitTime) {
700680
cs.err = fmt.Errorf("MaxAwaitTime must be less than the operation timeout")
701681

702682
return

mongo/change_stream_test.go

Lines changed: 0 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,9 @@
77
package mongo
88

99
import (
10-
"context"
1110
"testing"
12-
"time"
1311

1412
"go.mongodb.org/mongo-driver/v2/internal/assert"
15-
"go.mongodb.org/mongo-driver/v2/mongo/options"
1613
)
1714

1815
func TestChangeStream(t *testing.T) {
@@ -30,96 +27,3 @@ func TestChangeStream(t *testing.T) {
3027
assert.Nil(t, err, "Close error: %v", err)
3128
})
3229
}
33-
34-
func TestValidChangeStreamTimeouts(t *testing.T) {
35-
t.Parallel()
36-
37-
newDurPtr := func(dur time.Duration) *time.Duration {
38-
return &dur
39-
}
40-
41-
tests := []struct {
42-
name string
43-
parent context.Context
44-
maxAwaitTimeout, timeout *time.Duration
45-
wantTimeout time.Duration
46-
want bool
47-
}{
48-
{
49-
name: "no context deadline and no timeouts",
50-
parent: context.Background(),
51-
maxAwaitTimeout: nil,
52-
timeout: nil,
53-
wantTimeout: 0,
54-
want: true,
55-
},
56-
{
57-
name: "no context deadline and maxAwaitTimeout",
58-
parent: context.Background(),
59-
maxAwaitTimeout: newDurPtr(1),
60-
timeout: nil,
61-
wantTimeout: 0,
62-
want: true,
63-
},
64-
{
65-
name: "no context deadline and timeout",
66-
parent: context.Background(),
67-
maxAwaitTimeout: nil,
68-
timeout: newDurPtr(1),
69-
wantTimeout: 0,
70-
want: true,
71-
},
72-
{
73-
name: "no context deadline and maxAwaitTime gt timeout",
74-
parent: context.Background(),
75-
maxAwaitTimeout: newDurPtr(2),
76-
timeout: newDurPtr(1),
77-
wantTimeout: 0,
78-
want: false,
79-
},
80-
{
81-
name: "no context deadline and maxAwaitTime lt timeout",
82-
parent: context.Background(),
83-
maxAwaitTimeout: newDurPtr(1),
84-
timeout: newDurPtr(2),
85-
wantTimeout: 0,
86-
want: true,
87-
},
88-
{
89-
name: "no context deadline and maxAwaitTime eq timeout",
90-
parent: context.Background(),
91-
maxAwaitTimeout: newDurPtr(1),
92-
timeout: newDurPtr(1),
93-
wantTimeout: 0,
94-
want: false,
95-
},
96-
{
97-
name: "no context deadline and maxAwaitTime with negative timeout",
98-
parent: context.Background(),
99-
maxAwaitTimeout: newDurPtr(1),
100-
timeout: newDurPtr(-1),
101-
wantTimeout: 0,
102-
want: true,
103-
},
104-
}
105-
106-
for _, test := range tests {
107-
test := test // Capture the range variable
108-
109-
t.Run(test.name, func(t *testing.T) {
110-
t.Parallel()
111-
112-
cs := &ChangeStream{
113-
options: &options.ChangeStreamOptions{
114-
MaxAwaitTime: test.maxAwaitTimeout,
115-
},
116-
client: &Client{
117-
timeout: test.timeout,
118-
},
119-
}
120-
121-
got := validChangeStreamTimeouts(test.parent, cs)
122-
assert.Equal(t, test.want, got)
123-
})
124-
}
125-
}

mongo/collection.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,13 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption
954954
return nil, err
955955
}
956956

957+
// Sending a maxAwaitTimeMS option to the server that is less than or equal to
958+
// the operation timeout will result in a socket timeout error. This block
959+
// short-circuits that behavior.
960+
if c := a.client; c != nil && !mongoutil.ValidMaxAwaitTimeMS(a.ctx, c.timeout, args.MaxAwaitTime) {
961+
return nil, fmt.Errorf("MaxAwaitTime must be less than the operation timeout")
962+
}
963+
957964
cursorOpts := a.client.createBaseCursorOptions()
958965

959966
cursorOpts.MarshalValueEncoderFn = newEncoderFn(a.bsonOpts, a.registry)
@@ -1347,11 +1354,17 @@ func (coll *Collection) find(
13471354
omitMaxTimeMS bool,
13481355
args *options.FindOptions,
13491356
) (cur *Cursor, err error) {
1350-
13511357
if ctx == nil {
13521358
ctx = context.Background()
13531359
}
13541360

1361+
// Sending a maxAwaitTimeMS option to the server that is less than or equal to
1362+
// the operation timeout will result in a socket timeout error. This block
1363+
// short-circuits that behavior.
1364+
if c := coll.client; c != nil && !mongoutil.ValidMaxAwaitTimeMS(ctx, c.timeout, args.MaxAwaitTime) {
1365+
return nil, fmt.Errorf("MaxAwaitTime must be less than the operation timeout")
1366+
}
1367+
13551368
f, err := marshal(filter, coll.bsonOpts, coll.registry)
13561369
if err != nil {
13571370
return nil, err

x/mongo/driver/batch_cursor.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,8 @@ func (bc *BatchCursor) getMore(ctx context.Context) {
381381

382382
bc.err = Operation{
383383
CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) {
384-
// If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then use
385-
// send remaining TimeoutMS - minRoundTripTime allowing the server an
384+
// If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then send
385+
// remaining TimeoutMS - minRoundTripTime, allowing the server an
386386
// opportunity to respond with an empty batch.
387387
var maxTimeMS int64
388388
if bc.maxAwaitTime != nil {

0 commit comments

Comments
 (0)