Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions internal/driverutil/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

package driverutil

import (
"context"
"math"
"time"
)

// Operation Names should be sourced from the command reference documentation:
// https://www.mongodb.com/docs/manual/reference/command/
const (
Expand All @@ -30,3 +36,34 @@ const (
UpdateOp = "update" // UpdateOp is the name for updating
BulkWriteOp = "bulkWrite" // BulkWriteOp is the name for client-level bulk write
)

// CalculateMaxTimeMS calculates the maxTimeMS value to send to the server
// based on the context deadline and the minimum round trip time. If the
// calculated maxTimeMS is likely to cause a socket timeout, then this function
// will return 0 and false.
func CalculateMaxTimeMS(ctx context.Context, rttMin time.Duration) (int64, bool) {
deadline, ok := ctx.Deadline()
if !ok {
return 0, true
}

remainingTimeout := time.Until(deadline)

// Always round up to the next millisecond value so we never truncate the calculated
// maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms).
maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond)
if maxTimeMS <= 0 {
return 0, false
}

// The server will return a "BadValue" error if maxTimeMS is greater
// than the maximum positive int32 value (about 24.9 days). If the
// user specified a timeout value greater than that, omit maxTimeMS
// and let the client-side timeout handle cancelling the op if the
// timeout is ever reached.
if maxTimeMS > math.MaxInt32 {
return 0, true
}

return maxTimeMS, true
}
113 changes: 113 additions & 0 deletions internal/driverutil/operation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (C) MongoDB, Inc. 2025-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package driverutil

import (
"context"
"math"
"testing"
"time"

"go.mongodb.org/mongo-driver/v2/internal/assert"
)

func TestCalculateMaxTimeMS(t *testing.T) {
tests := []struct {
name string
ctx context.Context
rttMin time.Duration
wantZero bool
wantOk bool
wantPositive bool
wantExact int64
}{
{
name: "no deadline",
ctx: context.Background(),
rttMin: 10 * time.Millisecond,
wantZero: true,
wantOk: true,
wantPositive: false,
},
{
name: "deadline expired",
ctx: func() context.Context {
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) //nolint:govet
return ctx
}(),
wantZero: true,
wantOk: false,
wantPositive: false,
},
{
name: "remaining timeout < rttMin",
ctx: func() context.Context {
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(1*time.Millisecond)) //nolint:govet
return ctx
}(),
rttMin: 10 * time.Millisecond,
wantZero: true,
wantOk: false,
wantPositive: false,
},
{
name: "normal positive result",
ctx: func() context.Context {
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) //nolint:govet
return ctx
}(),
wantZero: false,
wantOk: true,
wantPositive: true,
},
{
name: "beyond maxInt32",
ctx: func() context.Context {
dur := time.Now().Add(time.Duration(math.MaxInt32+1000) * time.Millisecond)
ctx, _ := context.WithDeadline(context.Background(), dur) //nolint:govet
return ctx
}(),
wantZero: true,
wantOk: true,
wantPositive: false,
},
{
name: "round up to 1ms",
ctx: func() context.Context {
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(999*time.Microsecond)) //nolint:govet
return ctx
}(),
wantOk: true,
wantExact: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := CalculateMaxTimeMS(tt.ctx, tt.rttMin)

assert.Equal(t, tt.wantOk, got1)

if tt.wantExact > 0 && got != tt.wantExact {
t.Errorf("CalculateMaxTimeMS() got = %v, want %v", got, tt.wantExact)
}

if tt.wantZero && got != 0 {
t.Errorf("CalculateMaxTimeMS() got = %v, want 0", got)
}

if !tt.wantZero && got == 0 {
t.Errorf("CalculateMaxTimeMS() got = %v, want > 0", got)
}

if !tt.wantZero && tt.wantPositive && got <= 0 {
t.Errorf("CalculateMaxTimeMS() got = %v, want > 0", got)
}
})
}

}
70 changes: 70 additions & 0 deletions internal/integration/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/failpoint"
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
Expand Down Expand Up @@ -303,6 +304,75 @@ func TestCursor(t *testing.T) {
batchSize = sizeVal.Int32()
assert.Equal(mt, int32(4), batchSize, "expected batchSize 4, got %v", batchSize)
})

tailableAwaitDataCursorOpts := mtest.NewOptions().MinServerVersion("4.4").
Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)

mt.RunOpts("tailable awaitData cursor", tailableAwaitDataCursorOpts, func(mt *mtest.T) {
mt.Run("apply remaining timeoutMS if less than maxAwaitTimeMS", func(mt *mtest.T) {
initCollection(mt, mt.Coll)
mt.ClearEvents()

// Create a find cursor
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(100 * time.Millisecond)

cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
require.NoError(mt, err)

_ = mt.GetStartedEvent() // Empty find from started list.

defer cursor.Close(context.Background())

ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

// Iterate twice to force a getMore
cursor.Next(ctx)
cursor.Next(ctx)

cmd := mt.GetStartedEvent().Command

maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
require.NoError(mt, err)

got, ok := maxTimeMSRaw.AsInt64OK()
require.True(mt, ok)

assert.LessOrEqual(mt, got, int64(50))
})

mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", tailableAwaitDataCursorOpts, func(mt *mtest.T) {
initCollection(mt, mt.Coll)
mt.ClearEvents()

// Create a find cursor
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)

cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
require.NoError(mt, err)

_ = mt.GetStartedEvent() // Empty find from started list.

defer cursor.Close(context.Background())

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

// Iterate twice to force a getMore
cursor.Next(ctx)
cursor.Next(ctx)

cmd := mt.GetStartedEvent().Command

maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
require.NoError(mt, err)

got, ok := maxTimeMSRaw.AsInt64OK()
require.True(mt, ok)

assert.LessOrEqual(mt, got, int64(50))
})
})
}

type tryNextCursor interface {
Expand Down
15 changes: 15 additions & 0 deletions internal/integration/unified/collection_operation_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"context"
"errors"
"fmt"
"strings"
"time"

"go.mongodb.org/mongo-driver/v2/bson"
Expand Down Expand Up @@ -1485,6 +1486,20 @@ func createFindCursor(ctx context.Context, operation *operation) (*cursorResult,
opts.SetSkip(int64(val.Int32()))
case "sort":
opts.SetSort(val.Document())
case "timeoutMode":
return nil, newSkipTestError("timeoutMode is not supported")
case "cursorType":
switch strings.ToLower(val.StringValue()) {
case "tailable":
opts.SetCursorType(options.Tailable)
case "tailableawait":
opts.SetCursorType(options.TailableAwait)
case "nontailable":
opts.SetCursorType(options.NonTailable)
}
case "maxAwaitTimeMS":
maxAwaitTimeMS := time.Duration(val.Int32()) * time.Millisecond
opts.SetMaxAwaitTime(maxAwaitTimeMS)
default:
return nil, fmt.Errorf("unrecognized find option %q", key)
}
Expand Down
22 changes: 19 additions & 3 deletions internal/spectest/skip.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,10 @@ var skipTests = map[string][]string{
"TestUnifiedSpec/client-side-operations-timeout/tests/retryability-timeoutMS.json/operation_is_retried_multiple_times_for_non-zero_timeoutMS_-_aggregate_on_collection",
"TestUnifiedSpec/client-side-operations-timeout/tests/retryability-timeoutMS.json/operation_is_retried_multiple_times_for_non-zero_timeoutMS_-_aggregate_on_database",
"TestUnifiedSpec/client-side-operations-timeout/tests/gridfs-find.json/timeoutMS_applied_to_find_command",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_-_failure",
},

// TODO(GODRIVER-3411): Tests require "getMore" with "maxTimeMS" settings. Not
Expand Down Expand Up @@ -443,7 +447,6 @@ var skipTests = map[string][]string{
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/change_stream_can_be_iterated_again_if_previous_iteration_times_out",
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/timeoutMS_is_refreshed_for_getMore_-_failure",
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS",
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
},

// Unknown CSOT:
Expand Down Expand Up @@ -579,12 +582,10 @@ var skipTests = map[string][]string{
"TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-timeoutMS.json",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_timeoutMode_is_cursor_lifetime",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_-_failure",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_maxAwaitTimeMS_if_less_than_remaining_timeout",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-non-awaitData.json/error_if_timeoutMode_is_cursor_lifetime",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-non-awaitData.json/timeoutMS_applied_to_find",
Expand Down Expand Up @@ -814,6 +815,21 @@ var skipTests = map[string][]string{
"TestUnifiedSpec/transactions-convenient-api/tests/unified/transaction-options.json/withTransaction_explicit_transaction_options_override_client_options",
"TestUnifiedSpec/transactions-convenient-api/tests/unified/commit.json/withTransaction_commits_after_callback_returns",
},

// GODRIVER-3473: the implementation of DRIVERS-2868 makes it clear that the
// Go Driver does not correctly implement the following validation for
// tailable awaitData cursors:
//
// Drivers MUST error if this option is set, timeoutMS is set to a
// non-zero value, and maxAwaitTimeMS is greater than or equal to
// timeoutMS.
//
// Once GODRIVER-3473 is completed, we can continue running these tests.
"When constructing tailable awaitData cusors must validate, timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to timeoutMS (GODRIVER-3473)": {
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
},
}

// CheckSkip checks if the fully-qualified test name matches a list of skipped test names for a given reason.
Expand Down
30 changes: 28 additions & 2 deletions x/mongo/driver/batch_cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,40 @@ func (bc *BatchCursor) getMore(ctx context.Context) {

bc.err = Operation{
CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) {
// If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then use
// send remaining TimeoutMS - minRoundTripTime allowing the server an
// opportunity to respond with an empty batch.
var maxTimeMS int64
if bc.maxAwaitTime != nil {
_, ctxDeadlineSet := ctx.Deadline()

if ctxDeadlineSet {
rttMonitor := bc.Server().RTTMonitor()

var ok bool
maxTimeMS, ok = driverutil.CalculateMaxTimeMS(ctx, rttMonitor.Min())
if !ok && maxTimeMS <= 0 {
return nil, fmt.Errorf(
"calculated server-side timeout (%v ms) is less than or equal to 0 (%v): %w",
maxTimeMS,
rttMonitor.Stats(),
ErrDeadlineWouldBeExceeded)
}
}

if !ctxDeadlineSet || bc.maxAwaitTime.Milliseconds() < maxTimeMS {
maxTimeMS = bc.maxAwaitTime.Milliseconds()
}
}

dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id)
dst = bsoncore.AppendStringElement(dst, "collection", bc.collection)
if numToReturn > 0 {
dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn)
}

if bc.maxAwaitTime != nil && *bc.maxAwaitTime > 0 {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(*bc.maxAwaitTime)/int64(time.Millisecond))
if maxTimeMS > 0 {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", maxTimeMS)
}

comment, err := codecutil.MarshalValue(bc.comment, bc.encoderFn)
Expand Down
Loading
Loading