diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 332175496a..d77769d3da 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -240,9 +240,38 @@ functions: args: [*task-runner, pr-task] send-perf-data: - - command: perf.send + # Here we begin to generate the request to send the data to SPS + - command: shell.exec params: - file: src/go.mongodb.org/mongo-driver/perf.json + script: | + # We use the requester expansion to determine whether the data is from a mainline evergreen run or not + if [ "${requester}" == "commit" ]; then + is_mainline=true + else + is_mainline=false + fi + + # We parse the username out of the order_id as patches append that in and SPS does not need that information + parsed_order_id=$(echo "${revision_order_id}" | awk -F'_' '{print $NF}') + # Submit the performance data to the SPS endpoint + response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" -X 'POST' \ + "https://performance-monitoring-api.corp.mongodb.com/raw_perf_results/cedar_report?project=${project_id}&version=${version_id}&variant=${build_variant}&order=$parsed_order_id&task_name=${task_name}&task_id=${task_id}&execution=${execution}&mainline=$is_mainline" \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d @src/go.mongodb.org/mongo-driver/perf.json) + + http_status=$(echo "$response" | grep "HTTP_STATUS" | awk -F':' '{print $2}') + response_body=$(echo "$response" | sed '/HTTP_STATUS/d') + + # We want to throw an error if the data was not successfully submitted + if [ "$http_status" -ne 200 ]; then + echo "Error: Received HTTP status $http_status" + echo "Response Body: $response_body" + exit 1 + fi + + echo "Response Body: $response_body" + echo "HTTP Status: $http_status" run-enterprise-auth-tests: - command: ec2.assume_role diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..1fad82a6a2 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @mongodb/dbx-go diff --git a/bson/benchmark_test.go b/bson/benchmark_test.go index c77f2dde6b..2b7d6d293a 100644 --- a/bson/benchmark_test.go +++ b/bson/benchmark_test.go @@ -19,6 +19,20 @@ import ( "testing" ) +var encodetestBsonD D + +func init() { + b, err := Marshal(encodetestInstance) + if err != nil { + panic(fmt.Sprintf("error marshling struct: %v", err)) + } + + err = Unmarshal(b, &encodetestBsonD) + if err != nil { + panic(fmt.Sprintf("error unmarshaling BSON: %v", err)) + } +} + type encodetest struct { Field1String string Field1Int64 int64 @@ -184,7 +198,7 @@ func readExtJSONFile(filename string) map[string]interface{} { func BenchmarkMarshal(b *testing.B) { cases := []struct { desc string - value interface{} + value any }{ { desc: "simple struct", @@ -194,6 +208,10 @@ func BenchmarkMarshal(b *testing.B) { desc: "nested struct", value: nestedInstance, }, + { + desc: "simple D", + value: encodetestBsonD, + }, { desc: "deep_bson.json.gz", value: readExtJSONFile("deep_bson.json.gz"), @@ -208,119 +226,211 @@ func BenchmarkMarshal(b *testing.B) { }, } - for _, tc := range cases { - b.Run(tc.desc, func(b *testing.B) { - b.Run("BSON", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _, err := Marshal(tc.value) - if err != nil { - b.Errorf("error marshalling BSON: %s", err) + b.Run("BSON", func(b *testing.B) { + for _, tc := range cases { + tc := tc // Capture range variable. + + b.Run(tc.desc, func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := Marshal(tc.value) + if err != nil { + b.Errorf("error marshalling BSON: %s", err) + } } - } + }) }) + } + }) - b.Run("extJSON", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _, err := MarshalExtJSON(tc.value, true, false) - if err != nil { - b.Errorf("error marshalling extended JSON: %s", err) + b.Run("extJSON", func(b *testing.B) { + for _, tc := range cases { + tc := tc // Capture range variable. + + b.Run(tc.desc, func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := MarshalExtJSON(tc.value, true, false) + if err != nil { + b.Errorf("error marshalling extended JSON: %s", err) + } } - } + }) }) + } + }) - b.Run("JSON", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _, err := json.Marshal(tc.value) - if err != nil { - b.Errorf("error marshalling JSON: %s", err) + b.Run("JSON", func(b *testing.B) { + for _, tc := range cases { + tc := tc // Capture range variable. + + b.Run(tc.desc, func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := json.Marshal(tc.value) + if err != nil { + b.Errorf("error marshalling JSON: %s", err) + } } - } + }) }) - }) - } + } + }) } func BenchmarkUnmarshal(b *testing.B) { - cases := []struct { + type testcase struct { desc string - value interface{} - }{ + value any + dst func() any + } + + cases := []testcase{ { desc: "simple struct", value: encodetestInstance, + dst: func() any { return &encodetest{} }, }, { desc: "nested struct", value: nestedInstance, + dst: func() any { return &encodetest{} }, }, + } + + inputs := []struct { + name string + value any + }{ { - desc: "deep_bson.json.gz", + name: "simple", + value: encodetestInstance, + }, + { + name: "nested", + value: nestedInstance, + }, + { + name: "deep_bson.json.gz", value: readExtJSONFile("deep_bson.json.gz"), }, { - desc: "flat_bson.json.gz", + name: "flat_bson.json.gz", value: readExtJSONFile("flat_bson.json.gz"), }, { - desc: "full_bson.json.gz", + name: "full_bson.json.gz", value: readExtJSONFile("full_bson.json.gz"), }, } - for _, tc := range cases { - b.Run(tc.desc, func(b *testing.B) { - b.Run("BSON", func(b *testing.B) { + destinations := []struct { + name string + dst func() any + }{ + { + name: "to map", + dst: func() any { return &map[string]any{} }, + }, + { + name: "to D", + dst: func() any { return &D{} }, + }, + } + + for _, input := range inputs { + for _, dest := range destinations { + cases = append(cases, testcase{ + desc: input.name + " " + dest.name, + value: input.value, + dst: dest.dst, + }) + } + } + + b.Run("BSON", func(b *testing.B) { + for _, tc := range cases { + tc := tc // Capture range variable. + + b.Run(tc.desc, func(b *testing.B) { + b.ReportAllocs() data, err := Marshal(tc.value) if err != nil { b.Errorf("error marshalling BSON: %s", err) return } + b.SetBytes(int64(len(data))) b.ResetTimer() - var v2 map[string]interface{} - for i := 0; i < b.N; i++ { - err := Unmarshal(data, &v2) - if err != nil { - b.Errorf("error unmarshalling BSON: %s", err) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + val := tc.dst() + err := Unmarshal(data, val) + if err != nil { + b.Errorf("error unmarshalling BSON: %s", err) + } } - } + }) }) + } + }) + + b.Run("extJSON", func(b *testing.B) { + for _, tc := range cases { + tc := tc // Capture range variable. - b.Run("extJSON", func(b *testing.B) { + b.Run(tc.desc, func(b *testing.B) { + b.ReportAllocs() data, err := MarshalExtJSON(tc.value, true, false) if err != nil { b.Errorf("error marshalling extended JSON: %s", err) return } + b.SetBytes(int64(len(data))) b.ResetTimer() - var v2 map[string]interface{} - for i := 0; i < b.N; i++ { - err := UnmarshalExtJSON(data, true, &v2) - if err != nil { - b.Errorf("error unmarshalling extended JSON: %s", err) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + val := tc.dst() + err := UnmarshalExtJSON(data, true, val) + if err != nil { + b.Errorf("error unmarshalling extended JSON: %s", err) + } } - } + }) }) + } + }) + + b.Run("JSON", func(b *testing.B) { + for _, tc := range cases { + tc := tc // Capture range variable. - b.Run("JSON", func(b *testing.B) { + b.Run(tc.desc, func(b *testing.B) { + b.ReportAllocs() data, err := json.Marshal(tc.value) if err != nil { b.Errorf("error marshalling JSON: %s", err) return } + b.SetBytes(int64(len(data))) b.ResetTimer() - var v2 map[string]interface{} - for i := 0; i < b.N; i++ { - err := json.Unmarshal(data, &v2) - if err != nil { - b.Errorf("error unmarshalling JSON: %s", err) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + val := tc.dst() + err := json.Unmarshal(data, val) + if err != nil { + b.Errorf("error unmarshalling JSON: %s", err) + } } - } + }) }) - }) - } + } + }) } // The following benchmarks are copied from the Go standard library's @@ -389,13 +499,13 @@ func codeInit() { } func BenchmarkCodeUnmarshal(b *testing.B) { - b.ReportAllocs() if codeJSON == nil { b.StopTimer() codeInit() b.StartTimer() } b.Run("BSON", func(b *testing.B) { + b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { var r codeResponse @@ -407,6 +517,7 @@ func BenchmarkCodeUnmarshal(b *testing.B) { b.SetBytes(int64(len(codeBSON))) }) b.Run("JSON", func(b *testing.B) { + b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { var r codeResponse @@ -420,13 +531,13 @@ func BenchmarkCodeUnmarshal(b *testing.B) { } func BenchmarkCodeMarshal(b *testing.B) { - b.ReportAllocs() if codeJSON == nil { b.StopTimer() codeInit() b.StartTimer() } b.Run("BSON", func(b *testing.B) { + b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { if _, err := Marshal(&codeStruct); err != nil { @@ -437,6 +548,7 @@ func BenchmarkCodeMarshal(b *testing.B) { b.SetBytes(int64(len(codeBSON))) }) b.Run("JSON", func(b *testing.B) { + b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { if _, err := json.Marshal(&codeStruct); err != nil { diff --git a/bson/bson_binary_vector_spec_test.go b/bson/bson_binary_vector_spec_test.go index 56516c61b6..33b5af595d 100644 --- a/bson/bson_binary_vector_spec_test.go +++ b/bson/bson_binary_vector_spec_test.go @@ -9,7 +9,6 @@ package bson import ( "encoding/hex" "encoding/json" - "math" "os" "path" "testing" @@ -27,13 +26,13 @@ type bsonBinaryVectorTests struct { } type bsonBinaryVectorTestCase struct { - Description string `json:"description"` - Valid bool `json:"valid"` - Vector []interface{} `json:"vector"` - DtypeHex string `json:"dtype_hex"` - DtypeAlias string `json:"dtype_alias"` - Padding int `json:"padding"` - CanonicalBson string `json:"canonical_bson"` + Description string `json:"description"` + Valid bool `json:"valid"` + Vector json.RawMessage `json:"vector"` + DtypeHex string `json:"dtype_hex"` + DtypeAlias string `json:"dtype_alias"` + Padding int `json:"padding"` + CanonicalBson string `json:"canonical_bson"` } func TestBsonBinaryVectorSpec(t *testing.T) { @@ -83,21 +82,19 @@ func TestBsonBinaryVectorSpec(t *testing.T) { }) } -func convertSlice[T int8 | float32 | byte](s []interface{}) []T { +func decodeTestSlice[T int8 | float32 | byte](t *testing.T, data []byte) []T { + t.Helper() + + if len(data) == 0 { + return nil + } + var s []float64 + err := UnmarshalExtJSON(data, true, &s) + require.NoError(t, err) + v := make([]T, len(s)) for i, e := range s { - f := math.NaN() - switch val := e.(type) { - case float64: - f = val - case string: - if val == "inf" { - f = math.Inf(0) - } else if val == "-inf" { - f = math.Inf(-1) - } - } - v[i] = T(f) + v[i] = T(e) } return v } @@ -108,17 +105,17 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector case "0x03": testVector[testKey] = Vector{ dType: Int8Vector, - int8Data: convertSlice[int8](test.Vector), + int8Data: decodeTestSlice[int8](t, test.Vector), } case "0x27": testVector[testKey] = Vector{ dType: Float32Vector, - float32Data: convertSlice[float32](test.Vector), + float32Data: decodeTestSlice[float32](t, test.Vector), } case "0x10": testVector[testKey] = Vector{ dType: PackedBitVector, - bitData: convertSlice[byte](test.Vector), + bitData: decodeTestSlice[byte](t, test.Vector), bitPadding: uint8(test.Padding), } default: diff --git a/bson/marshal.go b/bson/marshal.go index 21631d8156..88a85c0ffb 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -74,7 +74,10 @@ func Marshal(val interface{}) ([]byte, error) { } }() sw.Reset() - vw := NewDocumentWriter(sw) + + vw := getDocumentWriter(sw) + defer putDocumentWriter(vw) + enc := encPool.Get().(*Encoder) defer encPool.Put(enc) enc.Reset(vw) diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 72870c10ab..52bd94fed7 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -42,7 +42,9 @@ type ValueUnmarshaler interface { // When unmarshaling BSON, if the BSON value is null and the Go value is a // pointer, the pointer is set to nil without calling UnmarshalBSONValue. func Unmarshal(data []byte, val interface{}) error { - vr := newDocumentReader(bytes.NewReader(data)) + vr := getDocumentReader(bytes.NewReader(data)) + defer putDocumentReader(vr) + if l, err := vr.peekLength(); err != nil { return err } else if int(l) != len(data) { diff --git a/bson/value_reader.go b/bson/value_reader.go index 678c47b106..d713e669a9 100644 --- a/bson/value_reader.go +++ b/bson/value_reader.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "math" + "sync" ) var _ ValueReader = &valueReader{} @@ -29,6 +30,20 @@ type vrState struct { end int64 } +var bufioReaderPool = sync.Pool{ + New: func() interface{} { + return bufio.NewReader(nil) + }, +} + +var vrPool = sync.Pool{ + New: func() interface{} { + return &valueReader{ + stack: make([]vrState, 1, 5), + } + }, +} + // valueReader is for reading BSON values. type valueReader struct { r *bufio.Reader @@ -38,6 +53,33 @@ type valueReader struct { frame int64 } +func getDocumentReader(r io.Reader) *valueReader { + vr := vrPool.Get().(*valueReader) + + vr.offset = 0 + vr.frame = 0 + + vr.stack = vr.stack[:1] + vr.stack[0] = vrState{mode: mTopLevel} + + br := bufioReaderPool.Get().(*bufio.Reader) + br.Reset(r) + vr.r = br + + return vr +} + +func putDocumentReader(vr *valueReader) { + if vr == nil { + return + } + + bufioReaderPool.Put(vr.r) + vr.r = nil + + vrPool.Put(vr) +} + // NewDocumentReader returns a ValueReader using b for the underlying BSON // representation. func NewDocumentReader(r io.Reader) ValueReader { @@ -253,14 +295,28 @@ func (vr *valueReader) appendNextElement(dst []byte) ([]byte, error) { return nil, err } - buf := make([]byte, length) - _, err = io.ReadFull(vr.r, buf) + buf, err := vr.r.Peek(int(length)) if err != nil { + if err == bufio.ErrBufferFull { + temp := make([]byte, length) + if _, err = io.ReadFull(vr.r, temp); err != nil { + return nil, err + } + dst = append(dst, temp...) + vr.offset += int64(len(temp)) + return dst, nil + } + return nil, err } + dst = append(dst, buf...) - vr.offset += int64(len(buf)) - return dst, err + if _, err = vr.r.Discard(int(length)); err != nil { + return nil, err + } + + vr.offset += int64(length) + return dst, nil } func (vr *valueReader) readValueBytes(dst []byte) (Type, []byte, error) { diff --git a/bson/value_writer.go b/bson/value_writer.go index 57334a925d..2748696fbe 100644 --- a/bson/value_writer.go +++ b/bson/value_writer.go @@ -33,6 +33,29 @@ func putValueWriter(vw *valueWriter) { } } +var documentWriterPool = sync.Pool{ + New: func() interface{} { + return newDocumentWriter(nil) + }, +} + +func getDocumentWriter(w io.Writer) *valueWriter { + vw := documentWriterPool.Get().(*valueWriter) + + vw.reset(vw.buf) + vw.buf = vw.buf[:0] + vw.w = w + + return vw +} + +func putDocumentWriter(vw *valueWriter) { + if vw != nil { + vw.w = nil // don't leak the writer + documentWriterPool.Put(vw) + } +} + // This is here so that during testing we can change it and not require // allocating a 4GB slice. var maxSize = math.MaxInt32 diff --git a/internal/cmd/compilecheck/main.go b/internal/cmd/compilecheck/main.go index 6a56997de2..1661c6e097 100644 --- a/internal/cmd/compilecheck/main.go +++ b/internal/cmd/compilecheck/main.go @@ -12,9 +12,13 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/xoptions" ) func main() { + opts := options.Client() + xoptions.SetInternalClientOptions(opts, "foo", "bar") + _, _ = mongo.Connect(options.Client()) fmt.Println(bson.D{{Key: "key", Value: "value"}}) } diff --git a/internal/driverutil/operation.go b/internal/driverutil/operation.go index e37cba5903..74142a56e8 100644 --- a/internal/driverutil/operation.go +++ b/internal/driverutil/operation.go @@ -6,6 +6,12 @@ package driverutil +import ( + "context" + "math" + "time" +) + // Operation Names should be sourced from the command reference documentation: // https://www.mongodb.com/docs/manual/reference/command/ const ( @@ -30,3 +36,34 @@ const ( UpdateOp = "update" // UpdateOp is the name for updating BulkWriteOp = "bulkWrite" // BulkWriteOp is the name for client-level bulk write ) + +// CalculateMaxTimeMS calculates the maxTimeMS value to send to the server +// based on the context deadline and the minimum round trip time. If the +// calculated maxTimeMS is likely to cause a socket timeout, then this function +// will return 0 and false. +func CalculateMaxTimeMS(ctx context.Context, rttMin time.Duration) (int64, bool) { + deadline, ok := ctx.Deadline() + if !ok { + return 0, true + } + + remainingTimeout := time.Until(deadline) + + // Always round up to the next millisecond value so we never truncate the calculated + // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). + maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) + if maxTimeMS <= 0 { + return 0, false + } + + // The server will return a "BadValue" error if maxTimeMS is greater + // than the maximum positive int32 value (about 24.9 days). If the + // user specified a timeout value greater than that, omit maxTimeMS + // and let the client-side timeout handle cancelling the op if the + // timeout is ever reached. + if maxTimeMS > math.MaxInt32 { + return 0, true + } + + return maxTimeMS, true +} diff --git a/internal/driverutil/operation_test.go b/internal/driverutil/operation_test.go new file mode 100644 index 0000000000..474c3e1aa1 --- /dev/null +++ b/internal/driverutil/operation_test.go @@ -0,0 +1,113 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driverutil + +import ( + "context" + "math" + "testing" + "time" + + "go.mongodb.org/mongo-driver/v2/internal/assert" +) + +func TestCalculateMaxTimeMS(t *testing.T) { + tests := []struct { + name string + ctx context.Context + rttMin time.Duration + wantZero bool + wantOk bool + wantPositive bool + wantExact int64 + }{ + { + name: "no deadline", + ctx: context.Background(), + rttMin: 10 * time.Millisecond, + wantZero: true, + wantOk: true, + wantPositive: false, + }, + { + name: "deadline expired", + ctx: func() context.Context { + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) //nolint:govet + return ctx + }(), + wantZero: true, + wantOk: false, + wantPositive: false, + }, + { + name: "remaining timeout < rttMin", + ctx: func() context.Context { + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(1*time.Millisecond)) //nolint:govet + return ctx + }(), + rttMin: 10 * time.Millisecond, + wantZero: true, + wantOk: false, + wantPositive: false, + }, + { + name: "normal positive result", + ctx: func() context.Context { + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) //nolint:govet + return ctx + }(), + wantZero: false, + wantOk: true, + wantPositive: true, + }, + { + name: "beyond maxInt32", + ctx: func() context.Context { + dur := time.Now().Add(time.Duration(math.MaxInt32+1000) * time.Millisecond) + ctx, _ := context.WithDeadline(context.Background(), dur) //nolint:govet + return ctx + }(), + wantZero: true, + wantOk: true, + wantPositive: false, + }, + { + name: "round up to 1ms", + ctx: func() context.Context { + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(999*time.Microsecond)) //nolint:govet + return ctx + }(), + wantOk: true, + wantExact: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := CalculateMaxTimeMS(tt.ctx, tt.rttMin) + + assert.Equal(t, tt.wantOk, got1) + + if tt.wantExact > 0 && got != tt.wantExact { + t.Errorf("CalculateMaxTimeMS() got = %v, want %v", got, tt.wantExact) + } + + if tt.wantZero && got != 0 { + t.Errorf("CalculateMaxTimeMS() got = %v, want 0", got) + } + + if !tt.wantZero && got == 0 { + t.Errorf("CalculateMaxTimeMS() got = %v, want > 0", got) + } + + if !tt.wantZero && tt.wantPositive && got <= 0 { + t.Errorf("CalculateMaxTimeMS() got = %v, want > 0", got) + } + }) + } + +} diff --git a/internal/integration/cursor_test.go b/internal/integration/cursor_test.go index 5ee9986ec2..6376e78e74 100644 --- a/internal/integration/cursor_test.go +++ b/internal/integration/cursor_test.go @@ -17,6 +17,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/failpoint" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" ) @@ -303,6 +304,75 @@ func TestCursor(t *testing.T) { batchSize = sizeVal.Int32() assert.Equal(mt, int32(4), batchSize, "expected batchSize 4, got %v", batchSize) }) + + tailableAwaitDataCursorOpts := mtest.NewOptions().MinServerVersion("4.4"). + Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single) + + mt.RunOpts("tailable awaitData cursor", tailableAwaitDataCursorOpts, func(mt *mtest.T) { + mt.Run("apply remaining timeoutMS if less than maxAwaitTimeMS", func(mt *mtest.T) { + initCollection(mt, mt.Coll) + mt.ClearEvents() + + // Create a find cursor + opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(100 * time.Millisecond) + + cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) + require.NoError(mt, err) + + _ = mt.GetStartedEvent() // Empty find from started list. + + defer cursor.Close(context.Background()) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + // Iterate twice to force a getMore + cursor.Next(ctx) + cursor.Next(ctx) + + cmd := mt.GetStartedEvent().Command + + maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS") + require.NoError(mt, err) + + got, ok := maxTimeMSRaw.AsInt64OK() + require.True(mt, ok) + + assert.LessOrEqual(mt, got, int64(50)) + }) + + mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", tailableAwaitDataCursorOpts, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + mt.ClearEvents() + + // Create a find cursor + opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond) + + cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) + require.NoError(mt, err) + + _ = mt.GetStartedEvent() // Empty find from started list. + + defer cursor.Close(context.Background()) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Iterate twice to force a getMore + cursor.Next(ctx) + cursor.Next(ctx) + + cmd := mt.GetStartedEvent().Command + + maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS") + require.NoError(mt, err) + + got, ok := maxTimeMSRaw.AsInt64OK() + require.True(mt, ok) + + assert.LessOrEqual(mt, got, int64(50)) + }) + }) } type tryNextCursor interface { diff --git a/internal/integration/unified/client_operation_execution.go b/internal/integration/unified/client_operation_execution.go index 2bb2ec9326..75948ff8a0 100644 --- a/internal/integration/unified/client_operation_execution.go +++ b/internal/integration/unified/client_operation_execution.go @@ -245,10 +245,11 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati if errors.As(err, &bwe) { res = bwe.PartialResult } - if res == nil || !res.Acknowledged { + if res == nil { return newDocumentResult(emptyCoreDocument, err), nil } rawBuilder := bsoncore.NewDocumentBuilder(). + AppendBoolean("acknowledged", res.Acknowledged). AppendInt64("deletedCount", res.DeletedCount). AppendInt64("insertedCount", res.InsertedCount). AppendInt64("matchedCount", res.MatchedCount). diff --git a/internal/integration/unified/collection_operation_execution.go b/internal/integration/unified/collection_operation_execution.go index 6c8b38145a..c3e7040256 100644 --- a/internal/integration/unified/collection_operation_execution.go +++ b/internal/integration/unified/collection_operation_execution.go @@ -10,6 +10,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "go.mongodb.org/mongo-driver/v2/bson" @@ -1485,6 +1486,20 @@ func createFindCursor(ctx context.Context, operation *operation) (*cursorResult, opts.SetSkip(int64(val.Int32())) case "sort": opts.SetSort(val.Document()) + case "timeoutMode": + return nil, newSkipTestError("timeoutMode is not supported") + case "cursorType": + switch strings.ToLower(val.StringValue()) { + case "tailable": + opts.SetCursorType(options.Tailable) + case "tailableawait": + opts.SetCursorType(options.TailableAwait) + case "nontailable": + opts.SetCursorType(options.NonTailable) + } + case "maxAwaitTimeMS": + maxAwaitTimeMS := time.Duration(val.Int32()) * time.Millisecond + opts.SetMaxAwaitTime(maxAwaitTimeMS) default: return nil, fmt.Errorf("unrecognized find option %q", key) } diff --git a/internal/options/options.go b/internal/options/options.go new file mode 100644 index 0000000000..8d5f47f422 --- /dev/null +++ b/internal/options/options.go @@ -0,0 +1,45 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package options + +// Options stores internal options. +type Options struct { + values map[string]any +} + +// WithValue sets an option value with the associated key. +func WithValue(opts Options, key string, option any) Options { + if opts.values == nil { + opts.values = make(map[string]any) + } + opts.values[key] = option + return opts +} + +// Value returns the value associated with the options for key. +func Value(opts Options, key string) any { + if opts.values == nil { + return nil + } + if val, ok := opts.values[key]; ok { + return val + } + return nil +} + +// Equal compares two Options instances for equality. +func Equal(opts1, opts2 Options) bool { + if len(opts1.values) != len(opts2.values) { + return false + } + for key, val1 := range opts1.values { + if val2, ok := opts2.values[key]; !ok || val1 != val2 { + return false + } + } + return true +} diff --git a/internal/spectest/skip.go b/internal/spectest/skip.go index 858b92138e..c3c5d04fe3 100644 --- a/internal/spectest/skip.go +++ b/internal/spectest/skip.go @@ -11,6 +11,11 @@ import "testing" // skipTests is a map of "fully-qualified test name" to "the reason for skipping // the test". var skipTests = map[string][]string{ + // TODO(GODRIVER-3518): Test flexible numeric comparisons with $$lte + "Modifies $$lte operator test to also use floating point and Int64 types (GODRIVER-3518)": { + "TestUnifiedSpec/unified-test-format/tests/valid-pass/operator-lte.json/special_lte_matching_operator", + }, + // SPEC-1403: This test checks to see if the correct error is thrown when auto // encrypting with a server < 4.2. Currently, the test will fail because a // server < 4.2 wouldn't have mongocryptd, so Client construction would fail @@ -341,6 +346,10 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/client-side-operations-timeout/tests/retryability-timeoutMS.json/operation_is_retried_multiple_times_for_non-zero_timeoutMS_-_aggregate_on_collection", "TestUnifiedSpec/client-side-operations-timeout/tests/retryability-timeoutMS.json/operation_is_retried_multiple_times_for_non-zero_timeoutMS_-_aggregate_on_database", "TestUnifiedSpec/client-side-operations-timeout/tests/gridfs-find.json/timeoutMS_applied_to_find_command", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_-_failure", }, // TODO(GODRIVER-3411): Tests require "getMore" with "maxTimeMS" settings. Not @@ -443,7 +452,6 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/change_stream_can_be_iterated_again_if_previous_iteration_times_out", "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/timeoutMS_is_refreshed_for_getMore_-_failure", "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS", - "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", }, // Unknown CSOT: @@ -579,12 +587,10 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-timeoutMS.json", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_timeoutMode_is_cursor_lifetime", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS", - "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_-_failure", - "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_maxAwaitTimeMS_if_less_than_remaining_timeout", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-non-awaitData.json/error_if_timeoutMode_is_cursor_lifetime", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-non-awaitData.json/timeoutMS_applied_to_find", @@ -814,6 +820,21 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/transactions-convenient-api/tests/unified/transaction-options.json/withTransaction_explicit_transaction_options_override_client_options", "TestUnifiedSpec/transactions-convenient-api/tests/unified/commit.json/withTransaction_commits_after_callback_returns", }, + + // GODRIVER-3473: the implementation of DRIVERS-2868 makes it clear that the + // Go Driver does not correctly implement the following validation for + // tailable awaitData cursors: + // + // Drivers MUST error if this option is set, timeoutMS is set to a + // non-zero value, and maxAwaitTimeMS is greater than or equal to + // timeoutMS. + // + // Once GODRIVER-3473 is completed, we can continue running these tests. + "When constructing tailable awaitData cusors must validate, timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to timeoutMS (GODRIVER-3473)": { + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", + "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", + }, } // CheckSkip checks if the fully-qualified test name matches a list of skipped test names for a given reason. diff --git a/internal/test/compilecheck/go.mod b/internal/test/compilecheck/go.mod index e36075c668..d1d57257f4 100644 --- a/internal/test/compilecheck/go.mod +++ b/internal/test/compilecheck/go.mod @@ -56,8 +56,9 @@ require ( go.opentelemetry.io/otel v1.24.0 // indirect go.opentelemetry.io/otel/metric v1.24.0 // indirect go.opentelemetry.io/otel/trace v1.24.0 // indirect - golang.org/x/crypto v0.31.0 // indirect - golang.org/x/sys v0.28.0 // indirect + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/net v0.38.0 // indirect + golang.org/x/sys v0.31.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/internal/test/compilecheck/go.sum b/internal/test/compilecheck/go.sum index 24ab8ad6f0..566e970ebe 100644 --- a/internal/test/compilecheck/go.sum +++ b/internal/test/compilecheck/go.sum @@ -135,8 +135,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= @@ -145,8 +145,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -160,14 +160,14 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= -golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 925325c6c5..649d6a8e3d 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -26,6 +26,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/httputil" + "go.mongodb.org/mongo-driver/v2/internal/options" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" @@ -286,14 +287,22 @@ type ClientOptions struct { // encryption. // // Deprecated: This option is for internal use only and should not be set (see GODRIVER-2149). It may be - // changed or removed in any release. + // changed in any release. This option will be removed in 3.0 and replaced with the Custom options.Options + // pattern: SetInternalClientOptions(clientOptions, "crypt", myCrypt) Crypt driver.Crypt // Deployment specifies a custom deployment to use for the new Client. // + // Deprecated: This option is for internal use only and should not be set. It may be changed in any release. + // This option will be removed in 3.0 and replaced with the Custom options.Options pattern: + // SetInternalClientOptions(clientOptions, "deployment", myDeployment) + Deployment driver.Deployment + + // Custom specifies internal options for the new Client. + // // Deprecated: This option is for internal use only and should not be set. It may be changed or removed in any // release. - Deployment driver.Deployment + Custom options.Options connString *connstring.ConnString err error diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index 907584d5f0..a8a00122e0 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -27,6 +27,7 @@ import ( "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/httputil" + "go.mongodb.org/mongo-driver/v2/internal/options" "go.mongodb.org/mongo-driver/v2/internal/ptrutil" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" @@ -156,6 +157,7 @@ func TestClientOptions(t *testing.T) { cmp.Comparer(func(r1, r2 *bson.Registry) bool { return r1 == r2 }), cmp.Comparer(func(cfg1, cfg2 *tls.Config) bool { return cfg1 == cfg2 }), cmp.Comparer(func(fp1, fp2 *event.PoolMonitor) bool { return fp1 == fp2 }), + cmp.Comparer(options.Equal), cmp.AllowUnexported(ClientOptions{}), cmpopts.IgnoreFields(http.Client{}, "Transport"), ); diff != "" { @@ -1253,6 +1255,7 @@ func TestApplyURI(t *testing.T) { cmp.Comparer(func(r1, r2 *bson.Registry) bool { return r1 == r2 }), cmp.Comparer(compareTLSConfig), cmp.Comparer(compareErrors), + cmp.Comparer(options.Equal), cmpopts.SortSlices(stringLess), cmpopts.IgnoreFields(connstring.ConnString{}, "SSLClientCertificateKeyPassword"), cmpopts.IgnoreFields(http.Client{}, "Transport"), diff --git a/testdata/bson-binary-vector/float32.json b/testdata/bson-binary-vector/float32.json index 845f504ff3..0bc88fc65a 100644 --- a/testdata/bson-binary-vector/float32.json +++ b/testdata/bson-binary-vector/float32.json @@ -32,7 +32,7 @@ { "description": "Infinity Vector FLOAT32", "valid": true, - "vector": ["-inf", 0.0, "inf"], + "vector": [{"$numberDouble": "-Infinity"}, 0.0, {"$numberDouble": "Infinity"}], "dtype_hex": "0x27", "dtype_alias": "FLOAT32", "padding": 0, diff --git a/testdata/specifications b/testdata/specifications index 7891688772..0c41c8b283 160000 --- a/testdata/specifications +++ b/testdata/specifications @@ -1 +1 @@ -Subproject commit 78916887729bf0d3088541bd9a7390770fae5f77 +Subproject commit 0c41c8b28321e093e7e156c72abe65744cc1b467 diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index f444739661..6d6cd211a5 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -381,14 +381,40 @@ func (bc *BatchCursor) getMore(ctx context.Context) { bc.err = Operation{ CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) { + // If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then use + // send remaining TimeoutMS - minRoundTripTime allowing the server an + // opportunity to respond with an empty batch. + var maxTimeMS int64 + if bc.maxAwaitTime != nil { + _, ctxDeadlineSet := ctx.Deadline() + + if ctxDeadlineSet { + rttMonitor := bc.Server().RTTMonitor() + + var ok bool + maxTimeMS, ok = driverutil.CalculateMaxTimeMS(ctx, rttMonitor.Min()) + if !ok && maxTimeMS <= 0 { + return nil, fmt.Errorf( + "calculated server-side timeout (%v ms) is less than or equal to 0 (%v): %w", + maxTimeMS, + rttMonitor.Stats(), + ErrDeadlineWouldBeExceeded) + } + } + + if !ctxDeadlineSet || bc.maxAwaitTime.Milliseconds() < maxTimeMS { + maxTimeMS = bc.maxAwaitTime.Milliseconds() + } + } + dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id) dst = bsoncore.AppendStringElement(dst, "collection", bc.collection) if numToReturn > 0 { dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn) } - if bc.maxAwaitTime != nil && *bc.maxAwaitTime > 0 { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(*bc.maxAwaitTime)/int64(time.Millisecond)) + if maxTimeMS > 0 { + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", maxTimeMS) } comment, err := codecutil.MarshalValue(bc.comment, bc.encoderFn) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 2597a5de66..50136456e4 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1724,34 +1724,16 @@ func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration return 0, nil } - deadline, ok := ctx.Deadline() - if !ok { - return 0, nil - } - - remainingTimeout := time.Until(deadline) - - // Always round up to the next millisecond value so we never truncate the calculated - // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). - maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) - if maxTimeMS <= 0 { + // Calculate maxTimeMS value to potentially be appended to the wire message. + maxTimeMS, ok := driverutil.CalculateMaxTimeMS(ctx, rttMin) + if !ok && maxTimeMS <= 0 { return 0, fmt.Errorf( - "remaining time %v until context deadline is less than or equal to min network round-trip time %v (%v): %w", - remainingTimeout, - rttMin, + "calculated server-side timeout (%v ms) is less than or equal to 0 (%v): %w", + maxTimeMS, rttStats, ErrDeadlineWouldBeExceeded) } - // The server will return a "BadValue" error if maxTimeMS is greater - // than the maximum positive int32 value (about 24.9 days). If the - // user specified a timeout value greater than that, omit maxTimeMS - // and let the client-side timeout handle cancelling the op if the - // timeout is ever reached. - if maxTimeMS > math.MaxInt32 { - return 0, nil - } - return maxTimeMS, nil } diff --git a/x/mongo/driver/xoptions/options.go b/x/mongo/driver/xoptions/options.go new file mode 100644 index 0000000000..6eb1bd0dc2 --- /dev/null +++ b/x/mongo/driver/xoptions/options.go @@ -0,0 +1,38 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package xoptions + +import ( + "fmt" + + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" +) + +// SetInternalClientOptions sets internal options for ClientOptions. +// +// Deprecated: This function is for internal use only. It may be changed or removed in any release. +func SetInternalClientOptions(opts *options.ClientOptions, key string, option any) error { + const typeErr = "unexpected type for %s" + switch key { + case "crypt": + c, ok := option.(driver.Crypt) + if !ok { + return fmt.Errorf(typeErr, key) + } + opts.Crypt = c + case "deployment": + d, ok := option.(driver.Deployment) + if !ok { + return fmt.Errorf(typeErr, key) + } + opts.Deployment = d + default: + return fmt.Errorf("unsupported option: %s", key) + } + return nil +} diff --git a/x/mongo/driver/xoptions/options_test.go b/x/mongo/driver/xoptions/options_test.go new file mode 100644 index 0000000000..b459ec8ada --- /dev/null +++ b/x/mongo/driver/xoptions/options_test.go @@ -0,0 +1,64 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package xoptions + +import ( + "testing" + + "go.mongodb.org/mongo-driver/v2/internal/require" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" +) + +func TestSetInternalClientOptions(t *testing.T) { + t.Parallel() + + t.Run("set crypt", func(t *testing.T) { + t.Parallel() + + c := driver.NewCrypt(&driver.CryptOptions{}) + opts := options.Client() + err := SetInternalClientOptions(opts, "crypt", c) + require.NoError(t, err, "error setting crypt: %v", err) + require.Equal(t, c, opts.Crypt, "expected %v, got %v", c, opts.Crypt) + }) + + t.Run("set crypt - wrong type", func(t *testing.T) { + t.Parallel() + + opts := options.Client() + err := SetInternalClientOptions(opts, "crypt", &drivertest.MockDeployment{}) + require.EqualError(t, err, "unexpected type for crypt") + }) + + t.Run("set deployment", func(t *testing.T) { + t.Parallel() + + d := &drivertest.MockDeployment{} + opts := options.Client() + err := SetInternalClientOptions(opts, "deployment", d) + require.NoError(t, err, "error setting deployment: %v", err) + require.Equal(t, d, opts.Deployment, "expected %v, got %v", d, opts.Deployment) + }) + + t.Run("set deployment - wrong type", func(t *testing.T) { + t.Parallel() + + opts := options.Client() + err := SetInternalClientOptions(opts, "deployment", driver.NewCrypt(&driver.CryptOptions{})) + require.EqualError(t, err, "unexpected type for deployment") + }) + + t.Run("set unsupported option", func(t *testing.T) { + t.Parallel() + + opts := options.Client() + err := SetInternalClientOptions(opts, "unsupported", "unsupported") + require.EqualError(t, err, "unsupported option: unsupported") + }) +}