diff --git a/.golangci.yml b/.golangci.yml index b3c66c4d7d..89f67d25ae 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,16 +1,9 @@ -run: - timeout: 5m - +version: "2" linters: - disable-all: true - # TODO(GODRIVER-2156): Enable all commented-out linters. + default: none enable: - errcheck - # - errorlint - - exportloopref - gocritic - - goimports - - gosimple - gosec - govet - ineffassign @@ -21,95 +14,129 @@ linters: - prealloc - revive - staticcheck - - typecheck - - unused - unconvert - unparam + - unused + settings: + errcheck: + exclude-functions: + - .errcheck-excludes + govet: + disable: + - cgocall + - composites + paralleltest: + # Ignore missing calls to `t.Parallel()` and only report incorrect uses of + # `t.Parallel()`. + ignore-missing: true + staticcheck: + checks: + - all + # Disable deprecation warnings for now. + - -SA1012 + # Disable "do not pass a nil Context" to allow testing nil contexts in + # tests. + - -SA1019 + exclusions: + generated: lax + rules: + # Ignore some linters for example code that is intentionally simplified. + - linters: + - errcheck + - revive + path: examples/ + # Disable "unused" linter for code files that depend on the + # "mongocrypt.MongoCrypt" type because the linter build doesn't work + # correctly with CGO enabled. As a result, all calls to a + # "mongocrypt.MongoCrypt" API appear to always panic (see + # mongocrypt_not_enabled.go), leading to confusing messages about unused + # code. + - linters: + - unused + path: x/mongo/driver/crypt.go|mongo/(crypt_retrievers|mongocryptd).go + # Ignore "TLS MinVersion too low", "TLS InsecureSkipVerify set true", and + # "Use of weak random number generator (math/rand instead of crypto/rand)" + # in tests. Disable gosec entirely for test files to reduce noise. + - linters: + - gosec + path: _test\.go + # Ignore missing comments for exported variable/function/type for code in + # the "internal" and "benchmark" directories. + - path: (internal\/|benchmark\/) + text: exported (.+) should have comment( \(or a comment on this block\))? or be unexported + # Ignore missing package comments for directories that aren't frequently + # used by external users. + - path: (internal\/|benchmark\/|x\/|cmd\/|mongo\/integration\/) + text: should have a package comment -linters-settings: - errcheck: - exclude-functions: .errcheck-excludes - govet: - disable: - - cgocall - - composites - paralleltest: - # Ignore missing calls to `t.Parallel()` and only report incorrect uses of `t.Parallel()`. - ignore-missing: true - staticcheck: - checks: [ - "all", - "-SA1019", # Disable deprecation warnings for now. - "-SA1012", # Disable "do not pass a nil Context" to allow testing nil contexts in tests. - ] - -issues: - exclude-dirs-use-default: false - exclude-dirs: - - (^|/)testdata($|/) - - (^|/)etc($|/) - # Disable all linters for copied third-party code. - - internal/rand - - internal/aws - - internal/assert - exclude-use-default: false - exclude: - # Add all default excluded issues except issues related to exported types/functions not having - # comments; we want those warnings. The defaults are copied from the "--exclude-use-default" - # documentation on https://golangci-lint.run/usage/configuration/#command-line-options - ## Defaults ## - # EXC0001 errcheck: Almost all programs ignore errors on these functions and in most cases it's ok - - Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*print(f|ln)?|os\.(Un)?Setenv). is not checked - # EXC0003 golint: False positive when tests are defined in package 'test' - - func name will be used as test\.Test.* by other packages, and that stutters; consider calling this - # EXC0004 govet: Common false positives - - (possible misuse of unsafe.Pointer|should have signature) - # EXC0005 staticcheck: Developers tend to write in C-style with an explicit 'break' in a 'switch', so it's ok to ignore - - ineffective break statement. Did you mean to break out of the outer loop - # EXC0006 gosec: Too many false-positives on 'unsafe' usage - - Use of unsafe calls should be audited - # EXC0007 gosec: Too many false-positives for parametrized shell calls - - Subprocess launch(ed with variable|ing should be audited) - # EXC0008 gosec: Duplicated errcheck checks - - (G104|G307) - # EXC0009 gosec: Too many issues in popular repos - - (Expect directory permissions to be 0750 or less|Expect file permissions to be 0600 or less) - # EXC0010 gosec: False positive is triggered by 'src, err := ioutil.ReadFile(filename)' - - Potential file inclusion via variable - ## End Defaults ## + # Add all default excluded issues except issues related to exported + # types/functions not having comments; we want those warnings. The + # defaults are copied from the "--exclude-use-default" documentation on + # https://golangci-lint.run/usage/configuration/#command-line-options + # + ## Defaults ## + # + # EXC0001 errcheck: Almost all programs ignore errors on these functions + # and in most cases it's ok + - path: (.+)\.go$ + text: Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*print(f|ln)?|os\.(Un)?Setenv). is not checked + # EXC0003 golint: False positive when tests are defined in package 'test' + - path: (.+)\.go$ + text: func name will be used as test\.Test.* by other packages, and that stutters; consider calling this + # EXC0004 govet: Common false positives + - path: (.+)\.go$ + text: (possible misuse of unsafe.Pointer|should have signature) + # EXC0005 staticcheck: Developers tend to write in C-style with an explicit 'break' in a 'switch', so it's ok to ignore + - path: (.+)\.go$ + text: ineffective break statement. Did you mean to break out of the outer loop + # EXC0006 gosec: Too many false-positives on 'unsafe' usage + - path: (.+)\.go$ + text: Use of unsafe calls should be audited + # EXC0007 gosec: Too many false-positives for parametrized shell calls + - path: (.+)\.go$ + text: Subprocess launch(ed with variable|ing should be audited) + # EXC0008 gosec: Duplicated errcheck checks + - path: (.+)\.go$ + text: (G104|G307) + # EXC0009 gosec: Too many issues in popular repos + - path: (.+)\.go$ + text: (Expect directory permissions to be 0750 or less|Expect file permissions to be 0600 or less) + # EXC0010 gosec: False positive is triggered by + # 'src, err := ioutil.ReadFile(filename)' + - path: (.+)\.go$ + text: Potential file inclusion via variable + ## End Defaults ## - # Ignore capitalization warning for this weird field name. - - "var-naming: struct field CqCssWxW should be CqCSSWxW" - # Ignore warnings for common "wiremessage.Read..." usage because the safest way to use that API - # is by assigning possibly unused returned byte buffers. - - "SA4006: this value of `wm` is never used" - - "SA4006: this value of `rem` is never used" - - "ineffectual assignment to wm" - - "ineffectual assignment to rem" + # Ignore capitalization warning for this weird field name. + - path: (.+)\.go$ + text: "var-naming: struct field CqCssWxW should be CqCSSWxW" - exclude-rules: - # Ignore some linters for example code that is intentionally simplified. - - path: examples/ - linters: - - revive - - errcheck - # Disable "unused" linter for code files that depend on the "mongocrypt.MongoCrypt" type because - # the linter build doesn't work correctly with CGO enabled. As a result, all calls to a - # "mongocrypt.MongoCrypt" API appear to always panic (see mongocrypt_not_enabled.go), leading - # to confusing messages about unused code. - - path: x/mongo/driver/crypt.go|mongo/(crypt_retrievers|mongocryptd).go - linters: - - unused - # Ignore "TLS MinVersion too low", "TLS InsecureSkipVerify set true", and "Use of weak random - # number generator (math/rand instead of crypto/rand)" in tests. - - path: _test\.go - text: G401|G402|G404 - linters: - - gosec - # Ignore missing comments for exported variable/function/type for code in the "internal" and - # "benchmark" directories. - - path: (internal\/|benchmark\/) - text: exported (.+) should have comment( \(or a comment on this block\))? or be unexported - # Ignore missing package comments for directories that aren't frequently used by external users. - - path: (internal\/|benchmark\/|x\/|cmd\/|mongo\/integration\/) - text: should have a package comment + # Ignore warnings for common "wiremessage.Read..." usage because the + # safest way to use that API is by assigning possibly unused returned byte + # buffers. + - path: (.+)\.go$ + text: "SA4006: this value of `wm` is never used" + - path: (.+)\.go$ + text: "SA4006: this value of `rem` is never used" + - path: (.+)\.go$ + text: ineffectual assignment to wm + - path: (.+)\.go$ + text: ineffectual assignment to rem + paths: + - (^|/)testdata($|/) + - (^|/)etc($|/) + # Disable all linters for copied third-party code. + - internal/rand + - internal/aws + - internal/assert +formatters: + enable: + - goimports + exclusions: + generated: lax + paths: + - (^|/)testdata($|/) + - (^|/)etc($|/) + - internal/rand + - internal/aws + - internal/assert diff --git a/bson/buffered_byte_src.go b/bson/buffered_byte_src.go index eb19e3cb3f..84fb0306c7 100644 --- a/bson/buffered_byte_src.go +++ b/bson/buffered_byte_src.go @@ -9,6 +9,8 @@ package bson import ( "bytes" "io" + + "go.mongodb.org/mongo-driver/v2/internal/mathutil" ) // bufferedByteSrc implements the low-level byteSrc interface by reading @@ -115,7 +117,12 @@ func (b *bufferedByteSrc) regexLength() (int32, error) { // Total length = first C-string length (pattern) + second C-string length // (options) + 2 null terminators - return int32(i + j + 2), nil + length, err := mathutil.SafeConvertNumeric[int32](i + j + 2) + if err != nil { + return 0, err + } + + return length, nil } func (*bufferedByteSrc) streamable() bool { diff --git a/bson/decimal.go b/bson/decimal.go index 6241733a19..1257c72a54 100644 --- a/bson/decimal.go +++ b/bson/decimal.go @@ -74,11 +74,13 @@ func (d Decimal128) BigInt() (*big.Int, int, error) { if high>>61&3 == 3 { // Bits: 1*sign 2*ignored 14*exponent 111*significand. // Implicit 0b100 prefix in significand. + // nolint:gosec // G115: bitmasked to 14 bits, safe conversion exp = int(high >> 47 & (1<<14 - 1)) // Spec says all of these values are out of range. high, low = 0, 0 } else { // Bits: 1*sign 14*exponent 113*significand + // nolint:gosec // G115: bitmasked to 14 bits, safe conversion exp = int(high >> 49 & (1<<14 - 1)) high &= (1<<49 - 1) } @@ -314,6 +316,7 @@ func ParseDecimal128FromBigInt(bi *big.Int, exp int) (Decimal128, bool) { l = l<<8 | uint64(b[i]) } + // nolint:gosec // G115: exp is within MinDecimal128Exp to MaxDecimal128Exp range h |= uint64(exp-MinDecimal128Exp) & uint64(1<<14-1) << 49 if bi.Sign() == -1 { h |= 1 << 63 diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 67d464cb88..7af97afabf 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -14,6 +14,7 @@ import ( "reflect" "sync" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -120,17 +121,30 @@ func fitsIn32Bits(i int64) bool { func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Int8, reflect.Int16, reflect.Int32: - return vw.WriteInt32(int32(val.Int())) + i64 := val.Int() + i32, err := mathutil.SafeConvertNumeric[int32](i64) + if err != nil { + return ValueEncoderError{Name: "IntEncodeValue", Kinds: []reflect.Kind{val.Kind()}, Received: val} + } + return vw.WriteInt32(i32) case reflect.Int: i64 := val.Int() if fitsIn32Bits(i64) { - return vw.WriteInt32(int32(i64)) + i32, err := mathutil.SafeConvertNumeric[int32](i64) + if err != nil { + return ValueEncoderError{Name: "IntEncodeValue", Kinds: []reflect.Kind{reflect.Int}, Received: val} + } + return vw.WriteInt32(i32) } return vw.WriteInt64(i64) case reflect.Int64: i64 := val.Int() if ec.minSize && fitsIn32Bits(i64) { - return vw.WriteInt32(int32(i64)) + i32, err := mathutil.SafeConvertNumeric[int32](i64) + if err != nil { + return ValueEncoderError{Name: "IntEncodeValue", Kinds: []reflect.Kind{reflect.Int64}, Received: val} + } + return vw.WriteInt32(i32) } return vw.WriteInt64(i64) } @@ -369,7 +383,8 @@ func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { t := val.Type() if !val.IsValid() || t != tVector { - return ValueEncoderError{Name: "VectorEncodeValue", + return ValueEncoderError{ + Name: "VectorEncodeValue", Types: []reflect.Type{tVector}, Received: val, } diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index 6aa74f292c..db0871525e 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -201,7 +201,7 @@ func getterEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) erro return vw.WriteNull() } vv := reflect.ValueOf(x) - encoder, err := ec.Registry.LookupEncoder(vv.Type()) + encoder, err := ec.LookupEncoder(vv.Type()) if err != nil { return err } diff --git a/bson/objectid.go b/bson/objectid.go index 024281eb7f..0d69993c5b 100644 --- a/bson/objectid.go +++ b/bson/objectid.go @@ -20,6 +20,8 @@ import ( "io" "sync/atomic" "time" + + "go.mongodb.org/mongo-driver/v2/internal/mathutil" ) // ErrInvalidHex indicates that a hex string cannot be converted to an ObjectID. @@ -31,11 +33,15 @@ type ObjectID [12]byte // NilObjectID is the zero value for ObjectID. var NilObjectID ObjectID -var objectIDCounter = readRandomUint32() -var processUnique = processUniqueBytes() +var ( + objectIDCounter = readRandomUint32() + processUnique = processUniqueBytes() +) -var _ encoding.TextMarshaler = ObjectID{} -var _ encoding.TextUnmarshaler = &ObjectID{} +var ( + _ encoding.TextMarshaler = ObjectID{} + _ encoding.TextUnmarshaler = &ObjectID{} +) // NewObjectID generates a new ObjectID. func NewObjectID() ObjectID { @@ -46,7 +52,11 @@ func NewObjectID() ObjectID { func NewObjectIDFromTimestamp(timestamp time.Time) ObjectID { var b [12]byte - binary.BigEndian.PutUint32(b[0:4], uint32(timestamp.Unix())) + secs, err := mathutil.SafeConvertNumeric[uint32](timestamp.Unix()) + if err != nil { + secs = 0 + } + binary.BigEndian.PutUint32(b[0:4], secs) copy(b[4:9], processUnique[:]) putUint24(b[9:12], atomic.AddUint32(&objectIDCounter, 1)) diff --git a/bson/uint_codec.go b/bson/uint_codec.go index 0cdcc635d1..25c7294699 100644 --- a/bson/uint_codec.go +++ b/bson/uint_codec.go @@ -27,6 +27,7 @@ var _ typeDecoder = &uintCodec{} func (uic *uintCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Uint8, reflect.Uint16: + // nolint:gosec // G115: uint8 and uint16 fit within int32 return vw.WriteInt32(int32(val.Uint())) case reflect.Uint, reflect.Uint32, reflect.Uint64: u64 := val.Uint() @@ -35,6 +36,7 @@ func (uic *uintCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect. useMinSize := ec.minSize || (uic.encodeToMinSize && val.Kind() != reflect.Uint64) if u64 <= math.MaxInt32 && useMinSize { + // nolint:gosec // G115: checked against MaxInt32 return vw.WriteInt32(int32(u64)) } if u64 > math.MaxInt64 { diff --git a/bson/value_reader.go b/bson/value_reader.go index e5bcc1985f..7099640b1f 100644 --- a/bson/value_reader.go +++ b/bson/value_reader.go @@ -14,6 +14,8 @@ import ( "io" "math" "sync" + + "go.mongodb.org/mongo-driver/v2/internal/binaryutil" ) type byteSrc interface { @@ -916,7 +918,7 @@ func (vr *valueReader) peekLength() (int32, error) { if err != nil { return 0, err } - return int32(binary.LittleEndian.Uint32(buf)), nil + return binaryutil.ReadI32Unsafe(buf), nil } func (vr *valueReader) readLength() (int32, error) { @@ -929,7 +931,11 @@ func (vr *valueReader) readi32() (int32, error) { return 0, err } - return int32(binary.LittleEndian.Uint32(raw)), nil + value, _, ok := binaryutil.ReadI32(raw) + if !ok { + return 0, fmt.Errorf("insufficient bytes to read int32") + } + return value, nil } func (vr *valueReader) readu32() (uint32, error) { @@ -947,6 +953,9 @@ func (vr *valueReader) readi64() (int64, error) { return 0, err } + // BSON stores signed integers using two's complement. + // This uint64->int64 conversion is intentional bit reinterpretation per BSON spec. + // nolint:gosec // G115: BSON spec requires reinterpreting bits, not validating range return int64(binary.LittleEndian.Uint64(raw)), nil } diff --git a/bson/value_writer.go b/bson/value_writer.go index 9dd8912d08..da963b7a7a 100644 --- a/bson/value_writer.go +++ b/bson/value_writer.go @@ -15,6 +15,7 @@ import ( "strings" "sync" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -138,7 +139,12 @@ func (vw *valueWriter) push(m mode) { } func (vw *valueWriter) reserveLength() { - vw.stack[vw.frame].start = int32(len(vw.buf)) + start, err := mathutil.SafeConvertNumeric[int32](len(vw.buf)) + if err != nil { + panic(fmt.Errorf("valueWriter buffer length %d overflows int32: %w", len(vw.buf), err)) + } + + vw.stack[vw.frame].start = start vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00) } diff --git a/etc/golangci-lint.sh b/etc/golangci-lint.sh index fede529ed0..5633f1a029 100755 --- a/etc/golangci-lint.sh +++ b/etc/golangci-lint.sh @@ -3,7 +3,7 @@ set -ex # Keep this in sync with go version used in static-analysis Evergreen build variant. GO_VERSION=1.23.0 -GOLANGCI_LINT_VERSION=1.60.1 +GOLANGCI_LINT_VERSION=2.6.2 # Unset the cross-compiler overrides while downloading binaries. GOOS_ORIG=${GOOS:-} diff --git a/internal/binaryutil/binaryutil.go b/internal/binaryutil/binaryutil.go new file mode 100644 index 0000000000..300fb26858 --- /dev/null +++ b/internal/binaryutil/binaryutil.go @@ -0,0 +1,90 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package binaryutil + +// TODO(GODRIVER-3707): Consolidate remaining duplicate binary utility functions +// from bsoncore and wiremessage packages. + +// ReadI64 reads an 8-byte little-endian int64 from src returning the value, +// remaining bytes, and ok flag. +func ReadI64(src []byte) (int64, []byte, bool) { + if len(src) < 8 { + return 0, src, false + } + + _ = src[7] // bounds check hint to compiler + + // Do arithmetic in uint64, then convert to int64 + value := uint64(src[0]) | + uint64(src[1])<<8 | + uint64(src[2])<<16 | + uint64(src[3])<<24 | + uint64(src[4])<<32 | + uint64(src[5])<<40 | + uint64(src[6])<<48 | + uint64(src[7])<<56 // MSB contains sign bit (bit 63) + + return int64(value), src[8:], true +} + +// AppendI64 appends an int64 to dst in little-endian byte order. +func AppendI64(dst []byte, x int64) []byte { + return append(dst, + byte(x), + byte(x>>8), + byte(x>>16), + byte(x>>24), + byte(x>>32), + byte(x>>40), + byte(x>>48), + byte(x>>56), + ) +} + +// AppendI32 appends an int32 to dst in little-endian byte order. +func AppendI32(dst []byte, x int32) []byte { + return append(dst, + byte(x), + byte(x>>8), + byte(x>>16), + byte(x>>24), + ) +} + +// ReadI32 reads a 32-bit little-endian int32 from src returning the value, +// remaining bytes, and ok flag. +func ReadI32(src []byte) (int32, []byte, bool) { + if len(src) < 4 { + return 0, src, false + } + + _ = src[3] // bounds check hint to compiler + + value := int32(src[0]) | + int32(src[1])<<8 | + int32(src[2])<<16 | + int32(src[3])<<24 + + return value, src[4:], true +} + +// ReadI32Unsafe reads a 32-bit little-endian int32 from src without length +// checks. The caller must ensure src has at least 4 bytes. +func ReadI32Unsafe(src []byte) int32 { + _ = src[3] // bounds check hint to compiler + + return int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24 +} + +// PutI32 writes a little-endian int32 into dst starting at offset. Caller must +// ensure capacity. +func PutI32(dst []byte, offset int, value int32) { + dst[offset] = byte(value) + dst[offset+1] = byte(value >> 8) + dst[offset+2] = byte(value >> 16) + dst[offset+3] = byte(value >> 24) +} diff --git a/internal/binaryutil/binaryutil_test.go b/internal/binaryutil/binaryutil_test.go new file mode 100644 index 0000000000..708265d629 --- /dev/null +++ b/internal/binaryutil/binaryutil_test.go @@ -0,0 +1,388 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package binaryutil + +import ( + "math" + "testing" + + "go.mongodb.org/mongo-driver/v2/internal/assert" +) + +func TestReadi64(t *testing.T) { + testCases := []struct { + desc string + src []byte + want int64 + wantRem []byte + wantOK bool + }{ + { + desc: "0", + src: []byte{0, 0, 0, 0, 0, 0, 0, 0}, // little-endian int64(0) + want: 0, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "1", + src: []byte{1, 0, 0, 0, 0, 0, 0, 0}, // little-endian int64(1) + want: 1, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "-1", + src: []byte{255, 255, 255, 255, 255, 255, 255, 255}, // little-endian int64(-1) + want: -1, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "max", + src: []byte{255, 255, 255, 255, 255, 255, 255, 127}, // little-endian max int64 + want: math.MaxInt64, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "min", + src: []byte{0, 0, 0, 0, 0, 0, 0, 128}, // little-endian min int64 + want: math.MinInt64, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "non-empty remaining", + src: []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3}, // little-endian int64(1) + extra bytes + want: 1, + wantRem: []byte{0, 1, 2, 3}, + wantOK: true, + }, + { + desc: "not enough bytes", + src: []byte{0, 1, 2, 3, 4, 5, 6}, // only 7 bytes, need 8 + want: 0, + wantRem: []byte{0, 1, 2, 3, 4, 5, 6}, + wantOK: false, + }, + { + desc: "nil", + src: nil, // nil slice + want: 0, + wantRem: nil, + wantOK: false, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + x, rem, ok := ReadI64(tc.src) + assert.Equal(t, tc.want, x, "int64 result does not match") + assert.Equal(t, tc.wantRem, rem, "remaining bytes do not match") + assert.Equal(t, tc.wantOK, ok, "OK does not match") + }) + } +} + +func TestAppendI64(t *testing.T) { + testCases := []struct { + desc string + dst []byte + x int64 + want []byte + }{ + { + desc: "0", + dst: []byte{}, + x: 0, + want: []byte{0, 0, 0, 0, 0, 0, 0, 0}, // little-endian int64(0) + }, + { + desc: "1", + dst: []byte{}, + x: 1, + want: []byte{1, 0, 0, 0, 0, 0, 0, 0}, // little-endian int64(1) + }, + { + desc: "-1", + dst: []byte{}, + x: -1, + want: []byte{255, 255, 255, 255, 255, 255, 255, 255}, // little-endian int64(-1) + }, + { + desc: "max", + dst: []byte{}, + x: math.MaxInt64, + want: []byte{255, 255, 255, 255, 255, 255, 255, 127}, // little-endian max int64 + }, + { + desc: "min", + dst: []byte{}, + x: math.MinInt64, + want: []byte{0, 0, 0, 0, 0, 0, 0, 128}, // little-endian min int64 + }, + { + desc: "non-empty dst", + dst: []byte{0, 1, 2, 3}, + x: 1, + want: []byte{0, 1, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0}, // existing bytes + little-endian int64(1) + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + b := AppendI64(tc.dst, tc.x) + assert.Equal(t, tc.want, b, "bytes do not match") + }) + } +} + +func TestAppendI32(t *testing.T) { + testCases := []struct { + desc string + dst []byte + x int32 + want []byte + }{ + { + desc: "0", + x: 0, + want: []byte{0, 0, 0, 0}, // little-endian int32(0) + }, + { + desc: "1", + x: 1, + want: []byte{1, 0, 0, 0}, // little-endian int32(1) + }, + { + desc: "-1", + x: -1, + want: []byte{255, 255, 255, 255}, // little-endian int32(-1) + }, + { + desc: "max", + x: math.MaxInt32, + want: []byte{255, 255, 255, 127}, // little-endian max int32 + }, + { + desc: "min", + x: math.MinInt32, + want: []byte{0, 0, 0, 128}, // little-endian min int32 + }, + { + desc: "non-empty dst", + dst: []byte{0, 1, 2, 3}, + x: 1, + want: []byte{0, 1, 2, 3, 1, 0, 0, 0}, // existing bytes + little-endian int32(1) + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + b := AppendI32(tc.dst, tc.x) + assert.Equal(t, tc.want, b, "bytes do not match") + }) + } +} + +func TestReadI32(t *testing.T) { + testCases := []struct { + desc string + src []byte + want int32 + wantRem []byte + wantOK bool + }{ + { + desc: "0", + src: []byte{0, 0, 0, 0}, // little-endian int32(0) + want: 0, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "1", + src: []byte{1, 0, 0, 0}, // little-endian int32(1) + want: 1, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "-1", + src: []byte{255, 255, 255, 255}, // little-endian int32(-1) + want: -1, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "max", + src: []byte{255, 255, 255, 127}, // little-endian max int32 + want: math.MaxInt32, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "min", + src: []byte{0, 0, 0, 128}, // little-endian min int32 + want: math.MinInt32, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "non-empty remaining", + src: []byte{1, 0, 0, 0, 0, 1, 2, 3}, // little-endian int32(1) + extra bytes + want: 1, + wantRem: []byte{0, 1, 2, 3}, + wantOK: true, + }, + { + desc: "not enough bytes", + src: []byte{0, 1, 2}, // only 3 bytes, need 4 + want: 0, + wantRem: []byte{0, 1, 2}, + wantOK: false, + }, + { + desc: "nil", + src: nil, // nil slice + want: 0, + wantRem: nil, + wantOK: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + x, rem, ok := ReadI32(tc.src) + assert.Equal(t, tc.want, x, "int32 result does not match") + assert.Equal(t, tc.wantRem, rem, "remaining bytes do not match") + assert.Equal(t, tc.wantOK, ok, "OK does not match") + }) + } +} + +func TestReadI32Unsafe(t *testing.T) { + testCases := []struct { + desc string + src []byte + want int32 + }{ + { + desc: "0", + src: []byte{0, 0, 0, 0}, // little-endian int32(0) + want: 0, + }, + { + desc: "1", + src: []byte{1, 0, 0, 0}, // little-endian int32(1) + want: 1, + }, + { + desc: "-1", + src: []byte{255, 255, 255, 255}, // little-endian int32(-1) + want: -1, + }, + { + desc: "max", + src: []byte{255, 255, 255, 127}, // little-endian max int32 + want: math.MaxInt32, + }, + { + desc: "min", + src: []byte{0, 0, 0, 128}, // little-endian min int32 + want: math.MinInt32, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + x := ReadI32Unsafe(tc.src) + assert.Equal(t, tc.want, x, "int32 result does not match") + }) + } +} + +func TestPutI32(t *testing.T) { + testCases := []struct { + desc string + dst []byte + offset int + value int32 + want []byte + }{ + { + desc: "0", + dst: make([]byte, 4), + offset: 0, + value: 0, + want: []byte{0, 0, 0, 0}, // little-endian int32(0) + }, + { + desc: "1", + dst: make([]byte, 4), + offset: 0, + value: 1, + want: []byte{1, 0, 0, 0}, // little-endian int32(1) + }, + { + desc: "-1", + dst: make([]byte, 4), + offset: 0, + value: -1, + want: []byte{255, 255, 255, 255}, // little-endian int32(-1) + }, + { + desc: "max", + dst: make([]byte, 4), + offset: 0, + value: math.MaxInt32, + want: []byte{255, 255, 255, 127}, // little-endian max int32 + }, + { + desc: "min", + dst: make([]byte, 4), + offset: 0, + value: math.MinInt32, + want: []byte{0, 0, 0, 128}, // little-endian min int32 + }, + { + desc: "with offset", + dst: []byte{99, 99, 0, 0, 0, 0, 99, 99}, + offset: 2, + value: 1, + want: []byte{99, 99, 1, 0, 0, 0, 99, 99}, // little-endian int32(1) at offset 2 + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + PutI32(tc.dst, tc.offset, tc.value) + assert.Equal(t, tc.want, tc.dst, "bytes do not match") + }) + } +} diff --git a/internal/binaryutil/doc.go b/internal/binaryutil/doc.go new file mode 100644 index 0000000000..0f8bbf3f3f --- /dev/null +++ b/internal/binaryutil/doc.go @@ -0,0 +1,13 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// Package binaryutil provides utility functions for reading and writing +// binary integer data in little-endian byte order. +// +// This package centralizes binary encoding/decoding operations used throughout +// the MongoDB Go driver, particularly for BSON encoding/decoding and wire +// protocol message handling. +package binaryutil diff --git a/internal/credproviders/static_provider.go b/internal/credproviders/static_provider.go index bbb0e8033a..a99b223f02 100644 --- a/internal/credproviders/static_provider.go +++ b/internal/credproviders/static_provider.go @@ -45,7 +45,7 @@ func verify(v credentials.Value) error { func (s *StaticProvider) Retrieve() (credentials.Value, error) { if !s.verified { s.err = verify(s.Value) - s.Value.ProviderName = staticProviderName + s.ProviderName = staticProviderName s.verified = true } return s.Value, s.err diff --git a/internal/decimal128/decimal128.go b/internal/decimal128/decimal128.go index 2767e4457b..41548215de 100644 --- a/internal/decimal128/decimal128.go +++ b/internal/decimal128/decimal128.go @@ -30,6 +30,7 @@ func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) { d := cr<<32 + l&(1<<32-1) dq := d / div64 dr := d % div64 + //nolint:gosec // G115: dr is result of modulo, always fits in uint32 return (aq<<32 | bq), (cq<<32 | dq), uint32(dr) } @@ -54,11 +55,13 @@ func String(h, l uint64) string { if h>>61&3 == 3 { // Bits: 1*sign 2*ignored 14*exponent 111*significand. // Implicit 0b100 prefix in significand. + //nolint:gosec // G115: 14-bit value always fits in int exp = int(h >> 47 & (1<<14 - 1)) // Spec says all of these values are out of range. high, low = 0, 0 } else { // Bits: 1*sign 14*exponent 113*significand + //nolint:gosec // G115: 14-bit value always fits in int exp = int(h >> 49 & (1<<14 - 1)) high = h & (1<<49 - 1) } diff --git a/internal/driverutil/description.go b/internal/driverutil/description.go index 926de0bdf9..b63dd48b70 100644 --- a/internal/driverutil/description.go +++ b/internal/driverutil/description.go @@ -14,6 +14,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/bsonutil" "go.mongodb.org/mongo-driver/v2/internal/handshake" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/internal/ptrutil" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/tag" @@ -314,21 +315,36 @@ func NewServerDescription(addr address.Address, response bson.Raw) description.S desc.LastError = fmt.Errorf("expected 'maxBsonObjectSize' to be an integer but it's a BSON %s", element.Value().Type) return desc } - desc.MaxDocumentSize = uint32(i64) + size, err := mathutil.SafeConvertNumeric[uint32](i64) + if err != nil { + desc.LastError = fmt.Errorf("maxBsonObjectSize value out of range: %d", i64) + return desc + } + desc.MaxDocumentSize = size case "maxMessageSizeBytes": i64, ok := element.Value().AsInt64OK() if !ok { desc.LastError = fmt.Errorf("expected 'maxMessageSizeBytes' to be an integer but it's a BSON %s", element.Value().Type) return desc } - desc.MaxMessageSize = uint32(i64) + size, err := mathutil.SafeConvertNumeric[uint32](i64) + if err != nil { + desc.LastError = fmt.Errorf("maxMessageSizeBytes value out of range: %d", i64) + return desc + } + desc.MaxMessageSize = size case "maxWriteBatchSize": i64, ok := element.Value().AsInt64OK() if !ok { desc.LastError = fmt.Errorf("expected 'maxWriteBatchSize' to be an integer but it's a BSON %s", element.Value().Type) return desc } - desc.MaxBatchCount = uint32(i64) + count, err := mathutil.SafeConvertNumeric[uint32](i64) + if err != nil { + desc.LastError = fmt.Errorf("maxWriteBatchSize value out of range: %d", i64) + return desc + } + desc.MaxBatchCount = count case "me": me, ok := element.Value().StringValueOK() if !ok { @@ -338,18 +354,28 @@ func NewServerDescription(addr address.Address, response bson.Raw) description.S desc.CanonicalAddr = address.Address(me).Canonicalize() case "maxWireVersion": verMax, ok := element.Value().AsInt64OK() - versionRange.Max = int32(verMax) if !ok { - desc.LastError = fmt.Errorf("expected 'maxWireVersion' to be an integer but it's a BSON %s", element.Value().Type) + desc.LastError = fmt.Errorf("invalid maxWireVersion value") return desc } + max, err := mathutil.SafeConvertNumeric[int32](verMax) + if err != nil { + desc.LastError = fmt.Errorf("invalid maxWireVersion value: %w", err) + return desc + } + versionRange.Max = max case "minWireVersion": verMin, ok := element.Value().AsInt64OK() - versionRange.Min = int32(verMin) if !ok { - desc.LastError = fmt.Errorf("expected 'minWireVersion' to be an integer but it's a BSON %s", element.Value().Type) + desc.LastError = fmt.Errorf("invalid minWireVersion value") return desc } + min, err := mathutil.SafeConvertNumeric[int32](verMin) + if err != nil { + desc.LastError = fmt.Errorf("invalid minWireVersion value: %w", err) + return desc + } + versionRange.Min = min case "msg": msg, ok = element.Value().StringValueOK() if !ok { @@ -416,7 +442,12 @@ func NewServerDescription(addr address.Address, response bson.Raw) description.S desc.LastError = fmt.Errorf("expected 'setVersion' to be an integer but it's a BSON %s", element.Value().Type) return desc } - desc.SetVersion = uint32(i64) + version, err := mathutil.SafeConvertNumeric[uint32](i64) + if err != nil { + desc.LastError = fmt.Errorf("setVersion value out of range: %d", i64) + return desc + } + desc.SetVersion = version case "tags": m, err := decodeStringMap(element, "tags") if err != nil { diff --git a/internal/errutil/join_go1.19.go b/internal/errutil/join_go1.19.go index 569a0216b5..becdf9774c 100644 --- a/internal/errutil/join_go1.19.go +++ b/internal/errutil/join_go1.19.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build !go1.20 -// +build !go1.20 package errutil diff --git a/internal/errutil/join_go1.20.go b/internal/errutil/join_go1.20.go index 69b9ad2231..831991c793 100644 --- a/internal/errutil/join_go1.20.go +++ b/internal/errutil/join_go1.20.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build go1.20 -// +build go1.20 package errutil diff --git a/internal/integration/client_side_encryption_prose_test.go b/internal/integration/client_side_encryption_prose_test.go index 18d55c4cf5..d2f84a1c2b 100644 --- a/internal/integration/client_side_encryption_prose_test.go +++ b/internal/integration/client_side_encryption_prose_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package integration diff --git a/internal/integration/client_side_encryption_spec_test.go b/internal/integration/client_side_encryption_spec_test.go index 369f631a7b..d61d486ac7 100644 --- a/internal/integration/client_side_encryption_spec_test.go +++ b/internal/integration/client_side_encryption_spec_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package integration diff --git a/internal/integration/client_side_encryption_test.go b/internal/integration/client_side_encryption_test.go index 6be3b1c6e3..bb1dba65d2 100644 --- a/internal/integration/client_side_encryption_test.go +++ b/internal/integration/client_side_encryption_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package integration diff --git a/internal/integration/csot_cse_prose_test.go b/internal/integration/csot_cse_prose_test.go index 960db3e2a5..b3dfe8094a 100644 --- a/internal/integration/csot_cse_prose_test.go +++ b/internal/integration/csot_cse_prose_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package integration diff --git a/internal/integration/cursor_test.go b/internal/integration/cursor_test.go index 94acd222c5..500058bf6b 100644 --- a/internal/integration/cursor_test.go +++ b/internal/integration/cursor_test.go @@ -215,10 +215,7 @@ func TestCursor_RemainingBatchLength(t *testing.T) { defer cursor.Close(context.Background()) mt.ClearEvents() - for { - if cursor.TryNext(context.Background()) { - break - } + for !cursor.TryNext(context.Background()) { assert.Nil(mt, cursor.Err(), "cursor error: %v", err) assert.Equal(mt, diff --git a/internal/integration/errors_test.go b/internal/integration/errors_test.go index a33b8df93a..2d12ace87a 100644 --- a/internal/integration/errors_test.go +++ b/internal/integration/errors_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build go1.13 -// +build go1.13 package integration diff --git a/internal/integration/mtest/csfle_enabled.go b/internal/integration/mtest/csfle_enabled.go index 588e9ad6de..ba940400eb 100644 --- a/internal/integration/mtest/csfle_enabled.go +++ b/internal/integration/mtest/csfle_enabled.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package mtest diff --git a/internal/integration/mtest/csfle_not_enabled.go b/internal/integration/mtest/csfle_not_enabled.go index 289cf5ce6d..f7c3565fcf 100644 --- a/internal/integration/mtest/csfle_not_enabled.go +++ b/internal/integration/mtest/csfle_not_enabled.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build !cse -// +build !cse package mtest diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index 3924b58604..0942f7cc83 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -20,6 +20,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/csfle" "go.mongodb.org/mongo-driver/v2/internal/failpoint" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/internal/mongoutil" "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo" @@ -600,7 +601,7 @@ func (t *T) createTestClient() { } // Setup command monitor - var customMonitor = clientOpts.Monitor + customMonitor := clientOpts.Monitor clientOpts.SetMonitor(&event.CommandMonitor{ Started: func(ctx context.Context, cse *event.CommandStartedEvent) { if customMonitor != nil && customMonitor.Started != nil { @@ -862,16 +863,5 @@ func (t *T) verifyConstraints() error { } func (t *T) interfaceToInt32(i any) (int32, error) { - switch conv := i.(type) { - case int: - return int32(conv), nil - case int32: - return conv, nil - case int64: - return int32(conv), nil - case float64: - return int32(conv), nil - } - - return 0, fmt.Errorf("type %T cannot be converted to int32", i) + return mathutil.SafeConvertNumeric[int32](i) } diff --git a/internal/integration/mtest/proxy_dialer.go b/internal/integration/mtest/proxy_dialer.go index 0d980c406c..1e59584749 100644 --- a/internal/integration/mtest/proxy_dialer.go +++ b/internal/integration/mtest/proxy_dialer.go @@ -16,6 +16,7 @@ import ( "time" "go.mongodb.org/mongo-driver/v2/internal/handshake" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -168,7 +169,13 @@ func (pc *proxyConn) Read(buffer []byte) (int, error) { // buffer to the end of a four-byte slice and using UpdateLength to set the length bytes. idx, wm := bsoncore.ReserveLength(nil) wm = append(wm, buffer...) - wm = bsoncore.UpdateLength(wm, idx, int32(len(wm[idx:]))) + + wmLen, err := mathutil.SafeConvertNumeric[int32](len(wm)) + if err != nil { + return 0, fmt.Errorf("wire message size %d exceeds maximum int32 size: %w", len(wm), err) + } + + wm = bsoncore.UpdateLength(wm, idx, wmLen) if err := pc.dialer.storeReceivedMessage(wm, pc.RemoteAddr().String()); err != nil { wrapped := fmt.Errorf("error storing received message: %w", err) diff --git a/internal/integration/sdam_error_handling_test.go b/internal/integration/sdam_error_handling_test.go index 737cb67830..d9aa5f2846 100644 --- a/internal/integration/sdam_error_handling_test.go +++ b/internal/integration/sdam_error_handling_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build go1.13 -// +build go1.13 package integration diff --git a/internal/integration/unified/client_entity.go b/internal/integration/unified/client_entity.go index bc981793df..1ad0632bb1 100644 --- a/internal/integration/unified/client_entity.go +++ b/internal/integration/unified/client_entity.go @@ -20,6 +20,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" "go.mongodb.org/mongo-driver/v2/internal/integtest" "go.mongodb.org/mongo-driver/v2/internal/logger" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" @@ -91,7 +92,12 @@ func awaitMinimumPoolSize(ctx context.Context, entity *clientEntity, minPoolSize case <-awaitCtx.Done(): return fmt.Errorf("timed out waiting for client to reach minPoolSize") case <-ticker.C: - if uint64(entity.eventsCount[connectionReadyEvent]) >= minPoolSize { + creCount, err := mathutil.SafeConvertNumeric[uint64](int(entity.eventsCount[connectionReadyEvent])) + if err != nil { + return fmt.Errorf("connectionReadyEvent count %d exceeds maximum uint64 size: %w", entity.eventsCount[connectionReadyEvent], err) + } + + if creCount >= minPoolSize { return nil } } @@ -261,7 +267,7 @@ func (c *clientEntity) disconnect(ctx context.Context) error { return nil } - if err := c.Client.Disconnect(ctx); err != nil { + if err := c.Disconnect(ctx); err != nil { return err } @@ -665,11 +671,26 @@ func setClientOptionsFromURIOptions(clientOpts *options.ClientOptions, uriOpts b case "maxidletimems": clientOpts.SetMaxConnIdleTime(time.Duration(value.(int32)) * time.Millisecond) case "minpoolsize": - clientOpts.SetMinPoolSize(uint64(value.(int32))) + minPoolSize, err := mathutil.SafeConvertNumeric[uint64](int(value.(int32))) + if err != nil { + return fmt.Errorf("minPoolSize value %d is out of range: %w", value.(int32), err) + } + + clientOpts.SetMinPoolSize(minPoolSize) case "maxpoolsize": - clientOpts.SetMaxPoolSize(uint64(value.(int32))) + maxPoolSize, err := mathutil.SafeConvertNumeric[uint64](int(value.(int32))) + if err != nil { + return fmt.Errorf("maxPoolSize value %d is out of range: %w", value.(int32), err) + } + + clientOpts.SetMaxPoolSize(maxPoolSize) case "maxconnecting": - clientOpts.SetMaxConnecting(uint64(value.(int32))) + maxConnecting, err := mathutil.SafeConvertNumeric[uint64](int(value.(int32))) + if err != nil { + return fmt.Errorf("maxConnecting value %d is out of range: %w", value.(int32), err) + } + + clientOpts.SetMaxConnecting(maxConnecting) case "readconcernlevel": clientOpts.SetReadConcern(&readconcern.ReadConcern{Level: value.(string)}) case "retryreads": diff --git a/internal/integration/unified/collection_operation_execution.go b/internal/integration/unified/collection_operation_execution.go index 1e03f0eaed..f980b0c4ae 100644 --- a/internal/integration/unified/collection_operation_execution.go +++ b/internal/integration/unified/collection_operation_execution.go @@ -15,6 +15,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/bsonutil" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/internal/mongoutil" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" @@ -1144,8 +1145,12 @@ func executeInsertMany(ctx context.Context, operation *operation) (*operationRes // We return InsertedIDs as []any but the CRUD spec documents it as a map[int64]any, so // comparisons will fail if we include it in the result document. This is marked as an optional field and is // always surrounded in an $$unsetOrMatches assertion, so we leave it out of the document. + insertedCount, err := mathutil.SafeConvertNumeric[int32](len(res.InsertedIDs)) + if err != nil { + return nil, err + } raw = bsoncore.NewDocumentBuilder(). - AppendInt32("insertedCount", int32(len(res.InsertedIDs))). + AppendInt32("insertedCount", insertedCount). AppendInt32("deletedCount", 0). AppendInt32("matchedCount", 0). AppendInt32("modifiedCount", 0). diff --git a/internal/integration/unified/entity.go b/internal/integration/unified/entity.go index b1b827a124..c1a44299ef 100644 --- a/internal/integration/unified/entity.go +++ b/internal/integration/unified/entity.go @@ -18,15 +18,14 @@ import ( "time" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) -var ( - // ErrEntityMapOpen is returned when a slice entity is accessed while the EntityMap is open - ErrEntityMapOpen = errors.New("slices cannot be accessed while EntityMap is open") -) +// ErrEntityMapOpen is returned when a slice entity is accessed while the EntityMap is open +var ErrEntityMapOpen = errors.New("slices cannot be accessed while EntityMap is open") var ( tlsCAFile = os.Getenv("CSFLE_TLS_CA_FILE") @@ -96,17 +95,23 @@ func (eo *entityOptions) setHeartbeatFrequencyMS(freq time.Duration) { } if _, ok := eo.URIOptions["heartbeatFrequencyMS"]; !ok { + freqMS, err := mathutil.SafeConvertNumeric[int32](int64(freq.Milliseconds())) + if err != nil { + panic(fmt.Sprintf("heartbeatFrequencyMS value %d overflows int32", freq.Milliseconds())) + } + // The UST values for heartbeatFrequencyMS are given as int32, // so we need to cast the frequency as int32 before setting it // on the URIOptions map. - eo.URIOptions["heartbeatFrequencyMS"] = int32(freq.Milliseconds()) + eo.URIOptions["heartbeatFrequencyMS"] = freqMS } } // newCollectionEntityOptions constructs an entity options object for a // collection. func newCollectionEntityOptions(id string, databaseID string, collectionName string, - opts *dbOrCollectionOptions) *entityOptions { + opts *dbOrCollectionOptions, +) *entityOptions { options := &entityOptions{ ID: id, DatabaseID: databaseID, @@ -598,7 +603,6 @@ func getKmsCredential(kmsDocument bson.Raw, credentialName string, envVar string return "", fmt.Errorf("unable to get environment value for %v. Please set the CSFLE environment variable: %v", credentialName, envVar) } return os.Getenv(envVar), nil - } func (em *EntityMap) addClientEncryptionEntity(entityOptions *entityOptions) error { diff --git a/internal/integration/unified/error.go b/internal/integration/unified/error.go index ca4e985433..87c0bd4464 100644 --- a/internal/integration/unified/error.go +++ b/internal/integration/unified/error.go @@ -13,6 +13,7 @@ import ( "strings" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -199,28 +200,38 @@ func extractErrorDetails(err error) (errorDetails, bool) { details.raw = converted.Raw case mongo.WriteException: if converted.WriteConcernError != nil { - details.codes = append(details.codes, int32(converted.WriteConcernError.Code)) + if code, err := mathutil.SafeConvertNumeric[int32](converted.WriteConcernError.Code); err == nil { + details.codes = append(details.codes, code) + } details.codeNames = append(details.codeNames, converted.WriteConcernError.Name) } for _, we := range converted.WriteErrors { - details.codes = append(details.codes, int32(we.Code)) + if code, err := mathutil.SafeConvertNumeric[int32](we.Code); err == nil { + details.codes = append(details.codes, code) + } } details.labels = converted.Labels details.raw = converted.Raw case mongo.BulkWriteException: if converted.WriteConcernError != nil { - details.codes = append(details.codes, int32(converted.WriteConcernError.Code)) + if code, err := mathutil.SafeConvertNumeric[int32](converted.WriteConcernError.Code); err == nil { + details.codes = append(details.codes, code) + } details.codeNames = append(details.codeNames, converted.WriteConcernError.Name) } for _, we := range converted.WriteErrors { - details.codes = append(details.codes, int32(we.Code)) + if code, err := mathutil.SafeConvertNumeric[int32](we.Code); err == nil { + details.codes = append(details.codes, code) + } details.raw = we.Raw } details.labels = converted.Labels case mongo.ClientBulkWriteException: if converted.WriteError != nil { details.raw = converted.WriteError.Raw - details.codes = append(details.codes, int32(converted.WriteError.Code)) + if code, err := mathutil.SafeConvertNumeric[int32](converted.WriteError.Code); err == nil { + details.codes = append(details.codes, code) + } } default: return errorDetails{}, false diff --git a/internal/integration/unified/event_verification.go b/internal/integration/unified/event_verification.go index 0521f0653e..46a88fab79 100644 --- a/internal/integration/unified/event_verification.go +++ b/internal/integration/unified/event_verification.go @@ -429,22 +429,22 @@ func stringifyEventsForClient(client *clientEntity) string { str.WriteString("\n\nStarted Events\n\n") for _, evt := range client.startedEvents() { - str.WriteString(fmt.Sprintf("[%s] %s\n", evt.ConnectionID, evt.Command)) + fmt.Fprintf(str, "[%s] %s\n", evt.ConnectionID, evt.Command) } str.WriteString("\nSucceeded Events\n\n") for _, evt := range client.succeededEvents() { - str.WriteString(fmt.Sprintf("[%s] CommandName: %s, Reply: %s\n", evt.ConnectionID, evt.CommandName, evt.Reply)) + fmt.Fprintf(str, "[%s] CommandName: %s, Reply: %s\n", evt.ConnectionID, evt.CommandName, evt.Reply) } str.WriteString("\nFailed Events\n\n") for _, evt := range client.failedEvents() { - str.WriteString(fmt.Sprintf("[%s] CommandName: %s, Failure: %s\n", evt.ConnectionID, evt.CommandName, evt.Failure)) + fmt.Fprintf(str, "[%s] CommandName: %s, Failure: %s\n", evt.ConnectionID, evt.CommandName, evt.Failure) } str.WriteString("\nPool Events\n\n") for _, evt := range client.poolEvents() { - str.WriteString(fmt.Sprintf("[%s] Event Type: %q\n", evt.Address, evt.Type)) + fmt.Fprintf(str, "[%s] Event Type: %q\n", evt.Address, evt.Type) } return str.String() diff --git a/internal/integration/unified/testrunner_operation.go b/internal/integration/unified/testrunner_operation.go index bb1f9ecac6..a1b4d01655 100644 --- a/internal/integration/unified/testrunner_operation.go +++ b/internal/integration/unified/testrunner_operation.go @@ -14,6 +14,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" @@ -205,7 +206,11 @@ func executeTestRunnerOperation(ctx context.Context, op *operation, loopDone <-c return err } - expected := int32(lookupInteger(args, "connections")) + expected, err := mathutil.SafeConvertNumeric[int32](lookupInteger(args, "connections")) + if err != nil { + return fmt.Errorf("'connections' argument is out of int32 range: %w", err) + } + actual := client.numberConnectionsCheckedOut() if expected != actual { return fmt.Errorf("expected %d connections to be checked out, got %d", expected, actual) diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index 99906d448c..af4495cb76 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -69,7 +69,7 @@ func decodeTestData(dc bson.DecodeContext, vr bson.ValueReader, val reflect.Valu switch vr.Type() { case bson.TypeArray: docsVal := val.FieldByName("Documents") - decoder, err := dc.Registry.LookupDecoder(docsVal.Type()) + decoder, err := dc.LookupDecoder(docsVal.Type()) if err != nil { return err } @@ -77,7 +77,7 @@ func decodeTestData(dc bson.DecodeContext, vr bson.ValueReader, val reflect.Valu return decoder.DecodeValue(dc, vr, docsVal) case bson.TypeEmbeddedDocument: gridfsDataVal := val.FieldByName("GridFSData") - decoder, err := dc.Registry.LookupDecoder(gridfsDataVal.Type()) + decoder, err := dc.LookupDecoder(gridfsDataVal.Type()) if err != nil { return err } diff --git a/internal/israce/norace.go b/internal/israce/norace.go index 5c4422a678..b5d9662b1f 100644 --- a/internal/israce/norace.go +++ b/internal/israce/norace.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build !race -// +build !race // Package israce reports if the Go race detector is enabled. package israce diff --git a/internal/israce/race.go b/internal/israce/race.go index bd252147e7..c24e4d1d7f 100644 --- a/internal/israce/race.go +++ b/internal/israce/race.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build race -// +build race // Package israce reports if the Go race detector is enabled. package israce diff --git a/internal/logger/logger.go b/internal/logger/logger.go index c9b700b2d7..92ac264303 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -16,6 +16,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/bsoncoreutil" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -28,8 +29,10 @@ const DefaultMaxDocumentLength = 1000 // toward the max document length. const TruncationSuffix = "..." -const logSinkPathEnvVar = "MONGODB_LOG_PATH" -const maxDocumentLengthEnvVar = "MONGODB_LOG_MAX_DOCUMENT_LENGTH" +const ( + logSinkPathEnvVar = "MONGODB_LOG_PATH" + maxDocumentLengthEnvVar = "MONGODB_LOG_MAX_DOCUMENT_LENGTH" +) // LogSink represents a logging implementation, this interface should be 1-1 // with the exported "LogSink" interface in the mongo/options package. @@ -185,7 +188,7 @@ func selectLogSink(sink LogSink) (LogSink, *os.File, error) { } if path != "" { - logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666) + logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o666) if err != nil { return nil, nil, fmt.Errorf("unable to open log file: %w", err) } @@ -241,7 +244,13 @@ func FormatDocument(msg bson.Raw, width uint) string { return "{}" } - str, truncated := bsoncore.Document(msg).StringN(int(width)) + widthi, err := mathutil.SafeConvertNumeric[int](width) + if err != nil { + // Propagate warning about width being too large to format. + return "[WARNING] width too large to format document for logging, exceeds max int" + } + + str, truncated := bsoncore.Document(msg).StringN(widthi) if truncated { str += TruncationSuffix @@ -253,7 +262,14 @@ func FormatDocument(msg bson.Raw, width uint) string { // FormatString formats a String for logging. The string is truncated // to the given width. func FormatString(str string, width uint) string { - strTrunc := bsoncoreutil.Truncate(str, int(width)) + var strTrunc string + widthi, err := mathutil.SafeConvertNumeric[int](width) + if err != nil { + // Propagate warning about width being too large to format. + return "[WARNING] width too large to format string for logging, exceeds max int" + } + + strTrunc = bsoncoreutil.Truncate(str, widthi) // Checks if the string was truncating by comparing the lengths of the two strings. if len(strTrunc) < len(str) { diff --git a/internal/mathutil/mathutil.go b/internal/mathutil/mathutil.go new file mode 100644 index 0000000000..c71e22f6df --- /dev/null +++ b/internal/mathutil/mathutil.go @@ -0,0 +1,161 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mathutil + +import ( + "errors" + "fmt" + "math" +) + +var ( + ErrOverflow = errors.New("numeric overflow") + ErrUnderflow = errors.New("numeric underflow") + ErrUnsupported = errors.New("unsupported numeric type") +) + +func overflowError(from any, to any) error { + return fmt.Errorf("%w: %v (%T) to %v (%T)", ErrOverflow, from, from, to, to) +} + +func underflowError(from any, to any) error { + return fmt.Errorf("%w: %v (%T) to %v (%T)", ErrUnderflow, from, from, to, to) +} + +func unsupportedError(from any, to any) error { + return fmt.Errorf("%w: %v (%T) to %v (%T)", ErrUnsupported, from, from, to, to) +} + +type Numeric interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | + ~uint64 | ~float32 | ~float64 +} + +func i64ToT[T Numeric](i64 int64) (T, error) { + var zero T + + switch any(zero).(type) { + case int: + if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) { + return zero, overflowError(i64, zero) + } + case int32: + if i64 > int64(math.MaxInt32) || i64 < int64(math.MinInt32) { + return zero, overflowError(i64, zero) + } + case uint32: + if i64 < 0 || i64 > int64(math.MaxUint32) { + return zero, overflowError(i64, zero) + } + case uint64: + if i64 < 0 { + return zero, underflowError(i64, zero) + } + default: + return zero, unsupportedError(i64, zero) + } + + return T(i64), nil +} + +func iToT[T Numeric](i int) (T, error) { + var zero T + + switch any(zero).(type) { + case int32: + if i > int(math.MaxInt32) || i < int(math.MinInt32) { + return zero, overflowError(i, zero) + } + case int64: + return T(i), nil + case uint: + if i < 0 { + return zero, overflowError(i, zero) + } + case uint32: + if i < 0 || i > int(math.MaxUint32) { + return zero, overflowError(i, zero) + } + case uint64: + if i < 0 { + return zero, underflowError(i, zero) + } + default: + return zero, unsupportedError(i, zero) + } + + return T(i), nil +} + +func u64ToT[T Numeric](u64 uint64) (T, error) { + var zero T + maxUint := ^uint(0) + + switch any(zero).(type) { + case int: + if u64 > uint64(math.MaxInt) { + return zero, overflowError(u64, zero) + } + case int32: + if u64 > uint64(math.MaxInt32) { + return zero, overflowError(u64, zero) + } + case int64: + if u64 > uint64(math.MaxInt64) { + return zero, overflowError(u64, zero) + } + case uint: + if u64 > uint64(maxUint) { + return zero, overflowError(u64, zero) + } + case uint32: + if u64 > uint64(math.MaxUint32) { + return zero, overflowError(u64, zero) + } + case uint64: + default: + return zero, unsupportedError(u64, zero) + } + + return T(u64), nil +} + +func uToT[T Numeric](u uint) (T, error) { + return u64ToT[T](uint64(u)) +} + +func f64ToT[T Numeric](f64 float64) (T, error) { + var zero T + + switch any(zero).(type) { + case uint64: + if f64 < 0 || f64 > float64(math.MaxUint64) { + return zero, overflowError(f64, zero) + } + default: + return zero, unsupportedError(f64, zero) + } + + return T(f64), nil +} + +func SafeConvertNumeric[T Numeric](number any) (T, error) { + switch v := number.(type) { + case int: + return iToT[T](v) + case int64: + return i64ToT[T](v) + case uint: + return uToT[T](v) + case uint64: + return u64ToT[T](v) + case float64: + return f64ToT[T](v) + } + + return *new(T), unsupportedError(number, *new(T)) +} diff --git a/internal/mathutil/mathutil_test.go b/internal/mathutil/mathutil_test.go new file mode 100644 index 0000000000..e1e6d95dc3 --- /dev/null +++ b/internal/mathutil/mathutil_test.go @@ -0,0 +1,181 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mathutil + +import ( + "errors" + "math" + "reflect" + "testing" +) + +var ( + typeInt = reflect.TypeOf(int(0)) + typeInt32 = reflect.TypeOf(int32(0)) + typeUint32 = reflect.TypeOf(uint32(0)) + typeUint64 = reflect.TypeOf(uint64(0)) + typeInt64 = reflect.TypeOf(int64(0)) + typeUint = reflect.TypeOf(uint(0)) + maxInt = int(^uint(0) >> 1) + overflowU = uint(maxInt) + 1 + overflowI32 = uint(math.MaxInt32) + 1 +) + +func TestSafeConvertNumeric(t *testing.T) { + testCases := []struct { + name string + target reflect.Type + input any + expectVal any + expectErr error + }{ + // int64 sources + {name: "int64 to int32", target: typeInt32, input: int64(123), expectVal: int32(123)}, + {name: "int64 to int32 OF", target: typeInt32, input: int64(math.MaxInt32) + 1, expectErr: ErrOverflow}, + {name: "int64 to uint32", target: typeUint32, input: int64(789), expectVal: uint32(789)}, + {name: "int64 to uint32 OF", target: typeUint32, input: int64(math.MaxUint32) + 1, expectErr: ErrOverflow}, + {name: "int64 to uint32 UF", target: typeUint32, input: int64(-1), expectErr: ErrOverflow}, + {name: "int64 to uint64", target: typeUint64, input: int64(131415), expectVal: uint64(131415)}, + {name: "int64 to uint64 UF", target: typeUint64, input: int64(-1), expectErr: ErrUnderflow}, + + // int sources + {name: "int to int32", target: typeInt32, input: int(42), expectVal: int32(42)}, + {name: "int to int32 OF", target: typeInt32, input: int(math.MaxInt32) + 1, expectErr: ErrOverflow}, + {name: "int to uint32", target: typeUint32, input: int(101112), expectVal: uint32(101112)}, + {name: "int to uint32 OF", target: typeUint32, input: int(math.MaxUint32) + 1, expectErr: ErrOverflow}, + {name: "int to uint32 UF", target: typeUint32, input: int(-1), expectErr: ErrOverflow}, + {name: "int to uint64", target: typeUint64, input: int(161718), expectVal: uint64(161718)}, + {name: "int to uint64 UF", target: typeUint64, input: int(-1), expectErr: ErrUnderflow}, + {name: "int to uint", target: typeUint, input: int(202122), expectVal: uint(202122)}, + {name: "int to uint UF", target: typeUint, input: int(-1), expectErr: ErrOverflow}, + + // uint sources + {name: "uint to int", target: typeInt, input: uint(123), expectVal: int(123)}, + {name: "uint to int OF", target: typeInt, input: overflowU, expectErr: ErrOverflow}, + {name: "uint to int32", target: typeInt32, input: uint(321), expectVal: int32(321)}, + {name: "uint to int32 OF", target: typeInt32, input: overflowI32, expectErr: ErrOverflow}, + {name: "uint to int64", target: typeInt64, input: uint(654321), expectVal: int64(654321)}, + {name: "uint to uint", target: typeUint, input: uint(777), expectVal: uint(777)}, + {name: "uint to uint32", target: typeUint32, input: uint(888), expectVal: uint32(888)}, + {name: "uint to uint64", target: typeUint64, input: uint(999), expectVal: uint64(999)}, + + // float64 sources + {name: "float64 to uint64", target: typeUint64, input: float64(123), expectVal: uint64(123)}, + {name: "float64 to uint64 OF", target: typeUint64, input: math.Nextafter(float64(math.MaxUint64), math.Inf(1)), expectErr: ErrOverflow}, + {name: "float64 to uint64 UF", target: typeUint64, input: float64(-1), expectErr: ErrOverflow}, + {name: "float64 unsupported target", target: typeInt32, input: float64(1), expectErr: ErrUnsupported}, + + // unsupported cases + {name: "unsupported input type", target: typeInt32, input: "not-a-number", expectErr: ErrUnsupported}, + {name: "unsupported target type", target: typeInt64, input: int64(1), expectErr: ErrUnsupported}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + var ( + got any + err error + ) + + switch tc.target { + case typeInt32: + got, err = SafeConvertNumeric[int32](tc.input) + case typeUint32: + got, err = SafeConvertNumeric[uint32](tc.input) + case typeUint64: + got, err = SafeConvertNumeric[uint64](tc.input) + case typeInt64: + got, err = SafeConvertNumeric[int64](tc.input) + case typeUint: + got, err = SafeConvertNumeric[uint](tc.input) + case typeInt: + got, err = SafeConvertNumeric[int](tc.input) + default: + t.Fatalf("unexpected target type: %v", tc.target) + } + + if tc.expectErr != nil { + if !errors.Is(err, tc.expectErr) { + t.Fatalf("expected error %v, got %v", tc.expectErr, err) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.expectVal { + t.Fatalf("expected %v, got %v", tc.expectVal, got) + } + }) + } +} + +func BenchmarkSafeConvertNumeric(b *testing.B) { + benchmarks := []struct { + name string + target reflect.Type + input any + }{ + // int64 sources + {name: "int64 to int32", target: typeInt32, input: int64(123)}, + {name: "int64 to uint32", target: typeUint32, input: int64(789)}, + {name: "int64 to uint64", target: typeUint64, input: int64(131415)}, + + // int sources + {name: "int to int32", target: typeInt32, input: int(456)}, + {name: "int to uint32", target: typeUint32, input: int(101112)}, + {name: "int to uint64", target: typeUint64, input: int(161718)}, + {name: "int to uint", target: typeUint, input: int(202122)}, + + // uint sources + {name: "uint to int", target: typeInt, input: uint(123)}, + {name: "uint to int32", target: typeInt32, input: uint(321)}, + {name: "uint to uint32", target: typeUint32, input: uint(888)}, + {name: "uint to uint64", target: typeUint64, input: uint(999)}, + {name: "uint to uint", target: typeUint, input: uint(202122)}, + + // float64 sources + {name: "float64 to uint64", target: typeUint64, input: float64(123)}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + b.ReportAllocs() + switch bm.target { + case typeInt32: + for i := 0; i < b.N; i++ { + _, _ = SafeConvertNumeric[int32](bm.input) + } + case typeUint32: + for i := 0; i < b.N; i++ { + _, _ = SafeConvertNumeric[uint32](bm.input) + } + case typeUint64: + for i := 0; i < b.N; i++ { + _, _ = SafeConvertNumeric[uint64](bm.input) + } + case typeInt64: + for i := 0; i < b.N; i++ { + _, _ = SafeConvertNumeric[int64](bm.input) + } + case typeUint: + for i := 0; i < b.N; i++ { + _, _ = SafeConvertNumeric[uint](bm.input) + } + case typeInt: + for i := 0; i < b.N; i++ { + _, _ = SafeConvertNumeric[int](bm.input) + } + default: + b.Fatalf("unexpected target type: %v", bm.target) + } + }) + } +} diff --git a/mongo/change_stream.go b/mongo/change_stream.go index d5ad8058ec..53ade7b20d 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -403,7 +403,7 @@ func (cs *ChangeStream) storeResumeToken() error { func (cs *ChangeStream) buildPipelineSlice(pipeline any) error { val := reflect.ValueOf(pipeline) - if !val.IsValid() || !(val.Kind() == reflect.Slice) { + if !val.IsValid() || (val.Kind() != reflect.Slice) { cs.err = errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") return cs.err } diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index cb9d8cb4d8..78f2e1e679 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -15,6 +15,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/driverutil" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" @@ -407,8 +408,19 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota return 0, dst[:l], nil } - dst = fn.updateLength(dst, opsIdx, int32(len(dst[opsIdx:]))) - nsDst = fn.updateLength(nsDst, nsIdx, int32(len(nsDst[nsIdx:]))) + dstLenI32, err := mathutil.SafeConvertNumeric[int32](len(dst[opsIdx:])) + if err != nil { + return 0, nil, err + } + + dst = fn.updateLength(dst, opsIdx, dstLenI32) + + nsDstLenI32, err := mathutil.SafeConvertNumeric[int32](len(nsDst[nsIdx:])) + if err != nil { + return 0, nil, err + } + + nsDst = fn.updateLength(nsDst, nsIdx, nsDstLenI32) dst = append(dst, nsDst...) mb.retryMode = driver.RetryNone @@ -600,7 +612,12 @@ type clientInsertDoc struct { func (d *clientInsertDoc) marshal(bsonOpts *options.BSONOptions, registry *bson.Registry) (any, bsoncore.Document, error) { uidx, doc := bsoncore.AppendDocumentStart(nil) - doc = bsoncore.AppendInt32Element(doc, "insert", int32(d.namespace)) + namespaceI32, err := mathutil.SafeConvertNumeric[int32](d.namespace) + if err != nil { + return nil, nil, err + } + + doc = bsoncore.AppendInt32Element(doc, "insert", namespaceI32) f, err := marshal(d.document, bsonOpts, registry) if err != nil { return nil, nil, err @@ -631,7 +648,12 @@ type clientUpdateDoc struct { func (d *clientUpdateDoc) marshal(bsonOpts *options.BSONOptions, registry *bson.Registry) (bsoncore.Document, error) { uidx, doc := bsoncore.AppendDocumentStart(nil) - doc = bsoncore.AppendInt32Element(doc, "update", int32(d.namespace)) + namespaceI32, err := mathutil.SafeConvertNumeric[int32](d.namespace) + if err != nil { + return nil, err + } + + doc = bsoncore.AppendInt32Element(doc, "update", namespaceI32) if d.filter == nil { return nil, fmt.Errorf("update filter cannot be nil") @@ -702,7 +724,12 @@ type clientDeleteDoc struct { func (d *clientDeleteDoc) marshal(bsonOpts *options.BSONOptions, registry *bson.Registry) (bsoncore.Document, error) { didx, doc := bsoncore.AppendDocumentStart(nil) - doc = bsoncore.AppendInt32Element(doc, "delete", int32(d.namespace)) + namespaceI32, err := mathutil.SafeConvertNumeric[int32](d.namespace) + if err != nil { + return nil, err + } + + doc = bsoncore.AppendInt32Element(doc, "delete", namespaceI32) if d.filter == nil { return nil, fmt.Errorf("delete filter cannot be nil") diff --git a/mongo/client_encryption_test.go b/mongo/client_encryption_test.go index 1c49b3f707..ea967a0f0f 100644 --- a/mongo/client_encryption_test.go +++ b/mongo/client_encryption_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package mongo diff --git a/mongo/collection.go b/mongo/collection.go index af1960d970..663194f9b5 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -16,6 +16,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/csfle" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/internal/mongoutil" "go.mongodb.org/mongo-driver/v2/internal/optionsutil" "go.mongodb.org/mongo-driver/v2/internal/serverselector" @@ -191,8 +192,8 @@ func (coll *Collection) Database() *Database { // // The opts parameter can be used to specify options for the operation (see the options.BulkWriteOptions documentation.) func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, - opts ...options.Lister[options.BulkWriteOptions]) (*BulkWriteResult, error) { - + opts ...options.Lister[options.BulkWriteOptions], +) (*BulkWriteResult, error) { if len(models) == 0 { return nil, fmt.Errorf("invalid models: %w", ErrEmptySlice) } @@ -263,7 +264,6 @@ func (coll *Collection) insert( documents []any, opts ...options.Lister[options.InsertManyOptions], ) ([]any, error) { - if ctx == nil { ctx = context.Background() } @@ -374,8 +374,8 @@ func (coll *Collection) insert( // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/insert/. func (coll *Collection) InsertOne(ctx context.Context, document any, - opts ...options.Lister[options.InsertOneOptions]) (*InsertOneResult, error) { - + opts ...options.Lister[options.InsertOneOptions], +) (*InsertOneResult, error) { args, err := mongoutil.NewOptions(opts...) if err != nil { return nil, err @@ -431,7 +431,6 @@ func (coll *Collection) InsertMany( documents any, opts ...options.Lister[options.InsertManyOptions], ) (*InsertManyResult, error) { - dv := reflect.ValueOf(documents) if dv.Kind() != reflect.Slice { return nil, fmt.Errorf("invalid documents: %w", ErrNotSlice) @@ -483,7 +482,6 @@ func (coll *Collection) delete( expectedRr returnResult, args *options.DeleteManyOptions, ) (*DeleteResult, error) { - if ctx == nil { ctx = context.Background() } @@ -644,7 +642,6 @@ func (coll *Collection) updateOrReplace( sort any, args *options.UpdateManyOptions, ) (*UpdateResult, error) { - if ctx == nil { ctx = context.Background() } @@ -1023,7 +1020,7 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption op.AllowDiskUse(*args.AllowDiskUse) } // ignore batchSize of 0 with $out - if args.BatchSize != nil && !(*args.BatchSize == 0 && hasOutputStage) { + if args.BatchSize != nil && (*args.BatchSize != 0 || !hasOutputStage) { op.BatchSize(*args.BatchSize) cursorOpts.BatchSize = *args.BatchSize } @@ -1117,7 +1114,8 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption // // The opts parameter can be used to specify options for the operation (see the options.CountOptions documentation). func (coll *Collection) CountDocuments(ctx context.Context, filter any, - opts ...options.Lister[options.CountOptions]) (int64, error) { + opts ...options.Lister[options.CountOptions], +) (int64, error) { if ctx == nil { ctx = context.Background() } @@ -1385,7 +1383,8 @@ func (coll *Collection) Distinct( // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/. func (coll *Collection) Find(ctx context.Context, filter any, - opts ...options.Lister[options.FindOptions]) (*Cursor, error) { + opts ...options.Lister[options.FindOptions], +) (*Cursor, error) { args, err := mongoutil.NewOptions(opts...) if err != nil { return nil, err @@ -1404,7 +1403,6 @@ func (coll *Collection) find( omitMaxTimeMS bool, args *options.FindOptions, ) (cur *Cursor, err error) { - if ctx == nil { ctx = context.Background() } @@ -1502,7 +1500,13 @@ func (coll *Collection) find( limit = -1 * limit op.SingleBatch(true) } - cursorOpts.Limit = int32(limit) + + var convErr error + cursorOpts.Limit, convErr = mathutil.SafeConvertNumeric[int32](limit) + if convErr != nil { + return nil, convErr + } + op.Limit(limit) } if args.Max != nil { @@ -1607,8 +1611,8 @@ func newFindArgsFromFindOneArgs(args *options.FindOneOptions) *options.FindOptio // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/. func (coll *Collection) FindOne(ctx context.Context, filter any, - opts ...options.Lister[options.FindOneOptions]) *SingleResult { - + opts ...options.Lister[options.FindOneOptions], +) *SingleResult { if ctx == nil { ctx = context.Background() } @@ -1698,8 +1702,8 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd func (coll *Collection) FindOneAndDelete( ctx context.Context, filter any, - opts ...options.Lister[options.FindOneAndDeleteOptions]) *SingleResult { - + opts ...options.Lister[options.FindOneAndDeleteOptions], +) *SingleResult { f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} @@ -1782,7 +1786,6 @@ func (coll *Collection) FindOneAndReplace( replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions], ) *SingleResult { - f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} @@ -1884,8 +1887,8 @@ func (coll *Collection) FindOneAndUpdate( ctx context.Context, filter any, update any, - opts ...options.Lister[options.FindOneAndUpdateOptions]) *SingleResult { - + opts ...options.Lister[options.FindOneAndUpdateOptions], +) *SingleResult { if ctx == nil { ctx = context.Background() } @@ -1994,8 +1997,8 @@ func (coll *Collection) FindOneAndUpdate( // The opts parameter can be used to specify options for change stream creation (see the options.ChangeStreamOptions // documentation). func (coll *Collection) Watch(ctx context.Context, pipeline any, - opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) { - + opts ...options.Lister[options.ChangeStreamOptions], +) (*ChangeStream, error) { csConfig := changeStreamConfig{ readConcern: coll.readConcern, readPreference: coll.readPreference, @@ -2139,7 +2142,12 @@ func toDocument(co *options.Collation) bson.Raw { doc = bsoncore.AppendStringElement(doc, "caseFirst", co.CaseFirst) } if co.Strength != 0 { - doc = bsoncore.AppendInt32Element(doc, "strength", int32(co.Strength)) + strength, err := mathutil.SafeConvertNumeric[int32](co.Strength) + if err != nil { + panic(fmt.Errorf("collation strength %d overflows int32: %w", co.Strength, err)) + } + + doc = bsoncore.AppendInt32Element(doc, "strength", strength) } if co.NumericOrdering { doc = bsoncore.AppendBooleanElement(doc, "numericOrdering", true) diff --git a/mongo/errors.go b/mongo/errors.go index 234445ab86..b360d2214f 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -364,10 +364,12 @@ func hasErrorCode(srvErr ServerError, code int) bool { return false } -var _ ServerError = CommandError{} -var _ ServerError = WriteError{} -var _ ServerError = WriteException{} -var _ ServerError = BulkWriteException{} +var ( + _ ServerError = CommandError{} + _ ServerError = WriteError{} + _ ServerError = WriteException{} + _ ServerError = BulkWriteException{} +) var _ error = ClientBulkWriteException{} diff --git a/mongo/gridfs_download_stream.go b/mongo/gridfs_download_stream.go index c7967b748f..1897648939 100644 --- a/mongo/gridfs_download_stream.go +++ b/mongo/gridfs_download_stream.go @@ -14,6 +14,7 @@ import ( "time" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" ) // ErrMissingChunk indicates that the number of chunks read from the server is @@ -260,7 +261,12 @@ func (ds *GridFSDownloadStream) fillBuffer(ctx context.Context) error { var chunkIndexInt32 int32 if chunkIndexInt64, ok := chunkIndex.Int64OK(); ok { - chunkIndexInt32 = int32(chunkIndexInt64) + var convErr error + + chunkIndexInt32, convErr = mathutil.SafeConvertNumeric[int32](chunkIndexInt64) + if convErr != nil { + return convErr + } } else { chunkIndexInt32 = chunkIndex.Int32() } @@ -278,7 +284,10 @@ func (ds *GridFSDownloadStream) fillBuffer(ctx context.Context) error { _, dataBytes := data.Binary() copied := copy(ds.buffer, dataBytes) - bytesLen := int32(len(dataBytes)) + bytesLen, err := mathutil.SafeConvertNumeric[int32](len(dataBytes)) + if err != nil { + return err + } if ds.expectedChunk == ds.numChunks { // final chunk can be fewer than ds.chunkSize bytes bytesDownloaded := int64(ds.chunkSize) * (int64(ds.expectedChunk) - int64(1)) diff --git a/mongo/gridfs_upload_stream.go b/mongo/gridfs_upload_stream.go index ba7883d978..84c6502e97 100644 --- a/mongo/gridfs_upload_stream.go +++ b/mongo/gridfs_upload_stream.go @@ -7,14 +7,13 @@ package mongo import ( - "errors" - "context" - "time" - + "errors" "math" + "time" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" ) // uploadBufferSize is the size in bytes of one stream batch. Chunks will be written to the db after the sum of chunk @@ -99,10 +98,7 @@ func (us *GridFSUploadStream) Write(p []byte) (int, error) { } origLen := len(p) - for { - if len(p) == 0 { - break - } + for len(p) != 0 { n := copy(us.buffer[us.bufferIndex:], p) // copy as much as possible p = p[n:] @@ -163,10 +159,14 @@ func (us *GridFSUploadStream) uploadChunks(ctx context.Context, uploadPartial bo endIndex = us.bufferIndex } chunkData := us.buffer[i:endIndex] + chunkIndex, err := mathutil.SafeConvertNumeric[int32](us.chunkIndex) + if err != nil { + return err + } docs[us.chunkIndex-begChunkIndex] = bson.D{ {"_id", bson.NewObjectID()}, {"files_id", us.FileID}, - {"n", int32(us.chunkIndex)}, + {"n", chunkIndex}, {"data", bson.Binary{Subtype: 0x00, Data: chunkData}}, } us.chunkIndex++ diff --git a/mongo/mongo.go b/mongo/mongo.go index 703115fdc7..1117226977 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -18,6 +18,7 @@ import ( "strings" "go.mongodb.org/mongo-driver/v2/internal/codecutil" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" @@ -179,7 +180,12 @@ func ensureID( // Remove and re-write the BSON document length header. const int32Len = 4 doc = append(doc, olddoc[int32Len:]...) - doc = bsoncore.UpdateLength(doc, 0, int32(len(doc))) + docLength, err := mathutil.SafeConvertNumeric[int32](len(doc)) + if err != nil { + return nil, nil, fmt.Errorf("document length %d overflows int32: %w", len(doc), err) + } + + doc = bsoncore.UpdateLength(doc, 0, docLength) return doc, oid, nil } diff --git a/mongo/read_write_concern_spec_test.go b/mongo/read_write_concern_spec_test.go index 5118cc511d..05b3f5b689 100644 --- a/mongo/read_write_concern_spec_test.go +++ b/mongo/read_write_concern_spec_test.go @@ -246,16 +246,16 @@ func writeConcernFromRaw(t *testing.T, wcRaw bson.Raw) writeConcern { switch val.Type { case bson.TypeInt32: w := int(val.Int32()) - wc.WriteConcern.W = w + wc.W = w case bson.TypeString: - wc.WriteConcern.W = val.StringValue() + wc.W = val.StringValue() default: t.Fatalf("unexpected type for w: %v", val.Type) } case "journal": wc.jSet = true j := val.Boolean() - wc.WriteConcern.Journal = &j + wc.Journal = &j case "wtimeoutMS": // Do nothing, this field is deprecated t.Skip("the wtimeoutMS write concern option is not supported") default: diff --git a/x/bsonx/bsoncore/array.go b/x/bsonx/bsoncore/array.go index bfedbc8661..5f661f66f7 100644 --- a/x/bsonx/bsoncore/array.go +++ b/x/bsonx/bsoncore/array.go @@ -11,6 +11,8 @@ import ( "io" "strconv" "strings" + + "go.mongodb.org/mongo-driver/v2/internal/mathutil" ) // NewArrayLengthError creates and returns an error for when the length of an array exceeds the @@ -64,7 +66,13 @@ func (a Array) DebugString() string { var ok bool for length > 1 { elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + + elemLenI32, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + panic("array element length exceeds max int32") + } + + length -= elemLenI32 if !ok { buf.WriteString(fmt.Sprintf("", length)) break @@ -135,7 +143,11 @@ func (a Array) StringN(n int) (string, bool) { } elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + elemLen, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + return "", false + } + length -= elemLen // Exit on malformed element. if !ok || length < 0 { return "", false @@ -183,13 +195,17 @@ func (a Array) Validate() error { var keyNum int64 for length > 1 { elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + elemLen, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + return NewArrayLengthError(len(elem), len(a)) + } + length -= elemLen if !ok { return NewInsufficientBytesError(a, rem) } // validate element - err := elem.Validate() + err = elem.Validate() if err != nil { return err } diff --git a/x/bsonx/bsoncore/bsoncore.go b/x/bsonx/bsoncore/bsoncore.go index 13219ee547..8b7b42c676 100644 --- a/x/bsonx/bsoncore/bsoncore.go +++ b/x/bsonx/bsoncore/bsoncore.go @@ -14,6 +14,9 @@ import ( "strconv" "strings" "time" + + "go.mongodb.org/mongo-driver/v2/internal/binaryutil" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" ) const ( @@ -206,7 +209,11 @@ func AppendDocumentEnd(dst []byte, index int32) ([]byte, error) { return dst, fmt.Errorf("not enough bytes available after index to write length") } dst = append(dst, 0x00) - dst = UpdateLength(dst, index, int32(len(dst[index:]))) + length, err := mathutil.SafeConvertNumeric[int32](len(dst[index:])) + if err != nil { + return dst, fmt.Errorf("document length %d exceeds int32", len(dst[index:])) + } + dst = UpdateLength(dst, index, length) return dst, nil } @@ -227,7 +234,12 @@ func BuildDocument(dst []byte, elems ...[]byte) []byte { dst = append(dst, elem...) } dst = append(dst, 0x00) - dst = UpdateLength(dst, idx, int32(len(dst[idx:]))) + length, err := mathutil.SafeConvertNumeric[int32](len(dst[idx:])) + if err != nil { + panic(fmt.Errorf("document length %d overflows int32: %w", len(dst[idx:]), err)) + } + + dst = UpdateLength(dst, idx, length) return dst } @@ -279,7 +291,12 @@ func BuildArray(dst []byte, values ...Value) []byte { dst = AppendValueElement(dst, strconv.Itoa(pos), val) } dst = append(dst, 0x00) - dst = UpdateLength(dst, idx, int32(len(dst[idx:]))) + length, err := mathutil.SafeConvertNumeric[int32](len(dst[idx:])) + if err != nil { + panic(fmt.Errorf("array length %d overflows int32: %w", len(dst[idx:]), err)) + } + + dst = UpdateLength(dst, idx, length) return dst } @@ -297,7 +314,12 @@ func AppendBinary(dst []byte, subtype byte, b []byte) []byte { if subtype == 0x02 { return appendBinarySubtype2(dst, subtype, b) } - dst = append(appendLength(dst, int32(len(b))), subtype) + length, err := mathutil.SafeConvertNumeric[int32](len(b)) + if err != nil { + panic(fmt.Errorf("binary length %d overflows int32: %w", len(b), err)) + } + + dst = append(appendLength(dst, length), subtype) return append(dst, b...) } @@ -385,7 +407,7 @@ func ReadBoolean(src []byte) (bool, []byte, bool) { } // AppendDateTime will append dt to dst and return the extended buffer. -func AppendDateTime(dst []byte, dt int64) []byte { return appendi64(dst, dt) } +func AppendDateTime(dst []byte, dt int64) []byte { return binaryutil.AppendI64(dst, dt) } // AppendDateTimeElement will append a BSON datetime element using key and dt to dst // and return the extended buffer. @@ -395,7 +417,9 @@ func AppendDateTimeElement(dst []byte, key string, dt int64) []byte { // ReadDateTime will read an int64 datetime from src. If there are not enough bytes it // will return false. -func ReadDateTime(src []byte) (int64, []byte, bool) { return readi64(src) } +func ReadDateTime(src []byte) (int64, []byte, bool) { + return binaryutil.ReadI64(src) +} // AppendTime will append time as a BSON DateTime to dst and return the extended buffer. func AppendTime(dst []byte, t time.Time) []byte { @@ -411,7 +435,7 @@ func AppendTimeElement(dst []byte, key string, t time.Time) []byte { // ReadTime will read an time.Time datetime from src. If there are not enough bytes it // will return false. func ReadTime(src []byte) (time.Time, []byte, bool) { - dt, rem, ok := readi64(src) + dt, rem, ok := binaryutil.ReadI64(src) return time.Unix(dt/1e3, dt%1e3*1e6), rem, ok } @@ -501,7 +525,11 @@ func ReadSymbol(src []byte) (symbol string, rem []byte, ok bool) { return readst // AppendCodeWithScope will append code and scope to dst and return the extended buffer. func AppendCodeWithScope(dst []byte, code string, scope []byte) []byte { - length := int32(4 + 4 + len(code) + 1 + len(scope)) // length of cws, length of code, code, 0x00, scope + lengthVal := 4 + 4 + len(code) + 1 + len(scope) // length of cws, length of code, code, 0x00, scope + length, err := mathutil.SafeConvertNumeric[int32](lengthVal) + if err != nil { + panic(fmt.Errorf("code with scope length %d overflows int32: %w", lengthVal, err)) + } dst = appendLength(dst, length) return append(appendstring(dst, code), scope...) @@ -535,7 +563,9 @@ func ReadCodeWithScope(src []byte) (code string, scope []byte, rem []byte, ok bo } // AppendInt32 will append i32 to dst and return the extended buffer. -func AppendInt32(dst []byte, i32 int32) []byte { return appendi32(dst, i32) } +func AppendInt32(dst []byte, i32 int32) []byte { + return binaryutil.AppendI32(dst, i32) +} // AppendInt32Element will append a BSON int32 element using key and i32 to dst // and return the extended buffer. @@ -545,7 +575,9 @@ func AppendInt32Element(dst []byte, key string, i32 int32) []byte { // ReadInt32 will read an int32 from src. If there are not enough bytes it // will return false. -func ReadInt32(src []byte) (int32, []byte, bool) { return readi32(src) } +func ReadInt32(src []byte) (int32, []byte, bool) { + return binaryutil.ReadI32(src) +} // AppendTimestamp will append t and i to dst and return the extended buffer. func AppendTimestamp(dst []byte, t, i uint32) []byte { @@ -573,7 +605,7 @@ func ReadTimestamp(src []byte) (t, i uint32, rem []byte, ok bool) { } // AppendInt64 will append i64 to dst and return the extended buffer. -func AppendInt64(dst []byte, i64 int64) []byte { return appendi64(dst, i64) } +func AppendInt64(dst []byte, i64 int64) []byte { return binaryutil.AppendI64(dst, i64) } // AppendInt64Element will append a BSON int64 element using key and i64 to dst // and return the extended buffer. @@ -583,7 +615,9 @@ func AppendInt64Element(dst []byte, key string, i64 int64) []byte { // ReadInt64 will read an int64 from src. If there are not enough bytes it // will return false. -func ReadInt64(src []byte) (int64, []byte, bool) { return readi64(src) } +func ReadInt64(src []byte) (int64, []byte, bool) { + return binaryutil.ReadI64(src) +} // AppendDecimal128 will append high and low parts of a d128 to dst and return the extended buffer. func AppendDecimal128(dst []byte, high, low uint64) []byte { @@ -681,7 +715,13 @@ func valueLength(src []byte, t Type) (int32, bool) { ok = false break } - length = int32(int64(regex) + 1 + int64(pattern) + 1) + sum := int64(regex) + 1 + int64(pattern) + 1 + length64, err := mathutil.SafeConvertNumeric[int32](sum) + if err != nil { + ok = false + break + } + length = length64 default: ok = false } @@ -702,54 +742,36 @@ func readValue(src []byte, t Type) ([]byte, []byte, bool) { // and the []byte with reserved space. func ReserveLength(dst []byte) (int32, []byte) { index := len(dst) - return int32(index), append(dst, 0x00, 0x00, 0x00, 0x00) + index32, err := mathutil.SafeConvertNumeric[int32](index) + if err != nil { + panic(fmt.Errorf("reserve length index %d overflows int32: %w", index, err)) + } + return index32, append(dst, 0x00, 0x00, 0x00, 0x00) } // UpdateLength updates the length at index with length and returns the []byte. func UpdateLength(dst []byte, index, length int32) []byte { - binary.LittleEndian.PutUint32(dst[index:], uint32(length)) + if length < 0 { + panic("UpdateLength: negative length") + } + + binaryutil.PutI32(dst, int(index), length) return dst } -func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) } - -func appendi32(dst []byte, i32 int32) []byte { - b := []byte{0, 0, 0, 0} - binary.LittleEndian.PutUint32(b, uint32(i32)) - return append(dst, b...) -} +func appendLength(dst []byte, l int32) []byte { return binaryutil.AppendI32(dst, l) } // ReadLength reads an int32 length from src and returns the length and the remaining bytes. If // there aren't enough bytes to read a valid length, src is returned unomdified and the returned // bool will be false. func ReadLength(src []byte) (int32, []byte, bool) { - ln, src, ok := readi32(src) + ln, src, ok := binaryutil.ReadI32(src) if ln < 0 { return ln, src, false } return ln, src, ok } -func readi32(src []byte) (int32, []byte, bool) { - if len(src) < 4 { - return 0, src, false - } - return int32(binary.LittleEndian.Uint32(src)), src[4:], true -} - -func appendi64(dst []byte, i64 int64) []byte { - b := []byte{0, 0, 0, 0, 0, 0, 0, 0} - binary.LittleEndian.PutUint64(b, uint64(i64)) - return append(dst, b...) -} - -func readi64(src []byte) (int64, []byte, bool) { - if len(src) < 8 { - return 0, src, false - } - return int64(binary.LittleEndian.Uint64(src)), src[8:], true -} - func appendu32(dst []byte, u32 uint32) []byte { b := []byte{0, 0, 0, 0} binary.LittleEndian.PutUint32(b, u32) @@ -796,7 +818,11 @@ func readcstringbytes(src []byte) ([]byte, []byte, bool) { } func appendstring(dst []byte, s string) []byte { - l := int32(len(s) + 1) + lengthVal := len(s) + 1 + l, err := mathutil.SafeConvertNumeric[int32](lengthVal) + if err != nil { + panic(fmt.Errorf("string length %d overflows int32: %w", lengthVal, err)) + } dst = appendLength(dst, l) dst = append(dst, s...) return append(dst, 0x00) @@ -831,9 +857,20 @@ func readLengthBytes(src []byte) ([]byte, []byte, bool) { } func appendBinarySubtype2(dst []byte, subtype byte, b []byte) []byte { - dst = appendLength(dst, int32(len(b)+4)) // The bytes we'll encode need to be 4 larger for the length bytes + lengthVal := len(b) + 4 // The bytes we'll encode need to be 4 larger for the length bytes + length, err := mathutil.SafeConvertNumeric[int32](lengthVal) + if err != nil { + panic(fmt.Errorf("binary subtype 0x02 length %d overflows int32: %w", lengthVal, err)) + } + + dst = appendLength(dst, length) dst = append(dst, subtype) - dst = appendLength(dst, int32(len(b))) + bLength, err := mathutil.SafeConvertNumeric[int32](len(b)) + if err != nil { + panic(fmt.Errorf("binary subtype 0x02 data length %d overflows int32: %w", len(b), err)) + } + + dst = appendLength(dst, bLength) return append(dst, b...) } diff --git a/x/bsonx/bsoncore/document.go b/x/bsonx/bsoncore/document.go index 03e78e5997..8048a89c45 100644 --- a/x/bsonx/bsoncore/document.go +++ b/x/bsonx/bsoncore/document.go @@ -12,6 +12,9 @@ import ( "io" "strconv" "strings" + + "go.mongodb.org/mongo-driver/v2/internal/binaryutil" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" ) // ValidationError is an error type returned when attempting to validate a document or array. @@ -112,7 +115,7 @@ func newBufferFromReader(r io.Reader) ([]byte, error) { return nil, err } - length, _, _ := readi32(lengthBytes[:]) // ignore ok since we always have enough bytes to read a length + length, _, _ := binaryutil.ReadI32(lengthBytes[:]) // ignore ok since we always have enough bytes to read a length if length < 0 { return nil, ErrInvalidLength } @@ -156,11 +159,16 @@ func (d Document) LookupErr(key ...string) (Value, error) { var elem Element for length > 1 { elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + elemLen, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + return Value{}, fmt.Errorf("element length %d overflows int32: %w", len(elem), err) + } + length -= elemLen if !ok { return Value{}, NewInsufficientBytesError(d, rem) } // We use `KeyBytes` rather than `Key` to avoid a needless string alloc. + // nolint:gosec // G602: key length is validated at function entry if string(elem.KeyBytes()) != key[0] { continue } @@ -216,7 +224,11 @@ func indexErr(b []byte, index uint) (Element, error) { var elem Element for length > 1 { elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + elemLen, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + return nil, fmt.Errorf("element length %d overflows int32: %w", len(elem), err) + } + length -= elemLen if !ok { return nil, NewInsufficientBytesError(b, rem) } @@ -246,7 +258,12 @@ func (d Document) DebugString() string { var ok bool for length > 1 { elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + elemLen, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + buf.WriteString(fmt.Sprintf("", len(elem))) + break + } + length -= elemLen if !ok { buf.WriteString(fmt.Sprintf("", length)) break @@ -314,7 +331,11 @@ func (d Document) StringN(n int) (string, bool) { } elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + elemLen, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + return "", false + } + length -= elemLen // Exit on malformed element. if !ok || length < 0 { return "", false @@ -351,7 +372,11 @@ func (d Document) Elements() ([]Element, error) { var elems []Element for length > 1 { elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + elemLen, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + return elems, fmt.Errorf("element length %d overflows int32: %w", len(elem), err) + } + length -= elemLen if !ok { return elems, NewInsufficientBytesError(d, rem) } @@ -382,7 +407,11 @@ func values(b []byte) ([]Value, error) { var vals []Value for length > 1 { elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + elemLen, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + return vals, fmt.Errorf("element length %d overflows int32: %w", len(elem), err) + } + length -= elemLen if !ok { return vals, NewInsufficientBytesError(b, rem) } @@ -412,11 +441,15 @@ func (d Document) Validate() error { for length > 1 { elem, rem, ok = ReadElement(rem) - length -= int32(len(elem)) + elemLen, err := mathutil.SafeConvertNumeric[int32](len(elem)) + if err != nil { + return fmt.Errorf("element length %d overflows int32: %w", len(elem), err) + } + length -= elemLen if !ok { return NewInsufficientBytesError(d, rem) } - err := elem.Validate() + err = elem.Validate() if err != nil { return err } diff --git a/x/bsonx/bsoncore/value.go b/x/bsonx/bsoncore/value.go index fec33029e0..3f901a8b9b 100644 --- a/x/bsonx/bsoncore/value.go +++ b/x/bsonx/bsoncore/value.go @@ -20,6 +20,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/bsoncoreutil" "go.mongodb.org/mongo-driver/v2/internal/decimal128" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" ) // ElementTypeError specifies that a method to obtain a BSON value an incorrect type was called on a bson.Value. @@ -83,7 +84,11 @@ func (v Value) AsInt32() int32 { if !ok { panic(NewInsufficientBytesError(v.Data, v.Data)) } - i32 = int32(i64) + var convErr error + i32, convErr = mathutil.SafeConvertNumeric[int32](i64) + if convErr != nil { + panic(fmt.Sprintf("bsoncore.Value.AsInt32: int64 value %d overflows int32", i64)) + } case TypeDecimal128: panic(ElementTypeError{"bsoncore.Value.AsInt32", v.Type}) } @@ -97,13 +102,15 @@ func (v Value) AsInt32OK() (int32, bool) { return 0, false } var i32 int32 + var convErr error switch v.Type { case TypeDouble: f64, _, ok := ReadDouble(v.Data) if !ok { return 0, false } - i32 = int32(f64) + + i32, convErr = mathutil.SafeConvertNumeric[int32](int64(f64)) case TypeInt32: var ok bool i32, _, ok = ReadInt32(v.Data) @@ -115,10 +122,16 @@ func (v Value) AsInt32OK() (int32, bool) { if !ok { return 0, false } - i32 = int32(i64) + + i32, convErr = mathutil.SafeConvertNumeric[int32](i64) case TypeDecimal128: return 0, false } + + if convErr != nil { + return 0, false + } + return i32, true } diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 0ae7571d23..5f915efdd4 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -5,8 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build gssapi && (windows || linux || darwin) -// +build gssapi -// +build windows linux darwin package auth diff --git a/x/mongo/driver/auth/gssapi_not_enabled.go b/x/mongo/driver/auth/gssapi_not_enabled.go index e50553c7a1..2f45412100 100644 --- a/x/mongo/driver/auth/gssapi_not_enabled.go +++ b/x/mongo/driver/auth/gssapi_not_enabled.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build !gssapi -// +build !gssapi package auth diff --git a/x/mongo/driver/auth/gssapi_not_supported.go b/x/mongo/driver/auth/gssapi_not_supported.go index 12046ff67c..b53f2df31b 100644 --- a/x/mongo/driver/auth/gssapi_not_supported.go +++ b/x/mongo/driver/auth/gssapi_not_supported.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build gssapi && !windows && !linux && !darwin -// +build gssapi,!windows,!linux,!darwin package auth diff --git a/x/mongo/driver/auth/gssapi_test.go b/x/mongo/driver/auth/gssapi_test.go index eedfe428e9..8a89c54678 100644 --- a/x/mongo/driver/auth/gssapi_test.go +++ b/x/mongo/driver/auth/gssapi_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build gssapi -// +build gssapi package auth diff --git a/x/mongo/driver/auth/sasl.go b/x/mongo/driver/auth/sasl.go index 659fe45e2d..5908dda5d5 100644 --- a/x/mongo/driver/auth/sasl.go +++ b/x/mongo/driver/auth/sasl.go @@ -11,6 +11,7 @@ import ( "fmt" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation" @@ -127,9 +128,14 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *driver.AuthConfig, return nil } + cidI32, err := mathutil.SafeConvertNumeric[int32](cid) + if err != nil { + return fmt.Errorf("conversation ID %d is too large to encode: %w", cid, err) + } + doc := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "saslContinue", 1), - bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)), + bsoncore.AppendInt32Element(nil, "conversationId", cidI32), bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), ) saslContinueCmd := operation.NewCommand(doc). diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index 00674fbd2d..b5533db1d0 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -19,6 +19,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/codecutil" "go.mongodb.org/mongo-driver/v2/internal/csot" "go.mongodb.org/mongo-driver/v2/internal/driverutil" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet" @@ -213,7 +214,14 @@ func NewBatchCursor( } if firstBatch != nil { - bc.numReturned = int32(firstBatch.Count()) + numReturned, err := mathutil.SafeConvertNumeric[int32](int64(firstBatch.Count())) + if err != nil { + // This should never happen unless the server is returning more than 2 + // billion documents in a single batch. + return nil, fmt.Errorf("batch size %d exceeds int32 limit: %w", firstBatch.Count(), err) + } + + bc.numReturned = numReturned } bc.currentBatch = firstBatch @@ -446,8 +454,15 @@ func (bc *BatchCursor) getMore(ctx context.Context) { bc.currentBatch.List = batch bc.currentBatch.Reset() + numReturned, err := mathutil.SafeConvertNumeric[int32](int64(bc.currentBatch.Count())) + if err != nil { + // This should never happen unless the server is returning more than 2 + // billion documents in a single batch. + return fmt.Errorf("batch size %d exceeds int32 limit: %w", bc.currentBatch.Count(), err) + } + // Required for legacy operations which don't support limit. - bc.numReturned += int32(bc.currentBatch.Count()) + bc.numReturned += numReturned pbrt, err := response.LookupErr("cursor", "postBatchResumeToken") if err != nil { @@ -559,9 +574,11 @@ type loadBalancedCursorDeployment struct { conn *mnet.Connection } -var _ Deployment = (*loadBalancedCursorDeployment)(nil) -var _ Server = (*loadBalancedCursorDeployment)(nil) -var _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil) +var ( + _ Deployment = (*loadBalancedCursorDeployment)(nil) + _ Server = (*loadBalancedCursorDeployment)(nil) + _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil) +) func (lbcd *loadBalancedCursorDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) { return lbcd, nil diff --git a/x/mongo/driver/batches.go b/x/mongo/driver/batches.go index 51a32bc962..fbc8cec8b5 100644 --- a/x/mongo/driver/batches.go +++ b/x/mongo/driver/batches.go @@ -7,9 +7,11 @@ package driver import ( + "fmt" "io" "strconv" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" ) @@ -56,7 +58,13 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, totalSize int) (int, if n == 0 { return 0, dst[:l], nil } - dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) + + dlen, err := mathutil.SafeConvertNumeric[int32](len(dst[idx:])) + if err != nil { + return 0, nil, fmt.Errorf("batch sequence size %d exceeds maximum int32 size: %w", len(dst[idx:]), err) + } + + dst = bsoncore.UpdateLength(dst, idx, dlen) return n, dst, nil } diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index ccf4534edc..9b5ee967e8 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -15,6 +15,7 @@ import ( "github.com/golang/snappy" "github.com/klauspost/compress/zstd" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" ) @@ -162,9 +163,17 @@ func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { l, err := snappy.DecodedLen(in) if err != nil { return nil, fmt.Errorf("decoding compressed length %w", err) - } else if int32(l) != opts.UncompressedSize { + } + + li32, err := mathutil.SafeConvertNumeric[int32](l) + if err != nil { + return nil, fmt.Errorf("decompression size %v overflows int32: %w", l, err) + } + + if li32 != opts.UncompressedSize { return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) } + out := make([]byte, opts.UncompressedSize) return snappy.Decode(out, in) case wiremessage.CompressorZLib: diff --git a/x/mongo/driver/connstring/connstring_spec_test.go b/x/mongo/driver/connstring/connstring_spec_test.go index b6243a471e..21aff46b64 100644 --- a/x/mongo/driver/connstring/connstring_spec_test.go +++ b/x/mongo/driver/connstring/connstring_spec_test.go @@ -112,7 +112,7 @@ func runTest(t *testing.T, test testCase, warningsError bool) { // URI options, but don't with some of the older things, we do a switch on the filename // here. We are trying to not break existing user applications that have unrecognized // options. - if test.Valid && !(test.Warning && warningsError) { + if test.Valid && (!test.Warning || !warningsError) { require.NoError(t, err) } else { require.Error(t, err) diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index 4e1f2c78c5..79ce8961af 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -10,6 +10,7 @@ import ( "context" "errors" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" @@ -105,7 +106,13 @@ func MakeReply(doc bsoncore.Document) []byte { dst = wiremessage.AppendReplyStartingFrom(dst, 0) dst = wiremessage.AppendReplyNumberReturned(dst, 1) dst = append(dst, doc...) - return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) + + wmLen, err := mathutil.SafeConvertNumeric[int32](len(dst[idx:])) + if err != nil { + panic("reply size exceeds int32") + } + + return bsoncore.UpdateLength(dst, idx, wmLen) } // GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message. diff --git a/x/mongo/driver/drivertest/opmsg_deployment.go b/x/mongo/driver/drivertest/opmsg_deployment.go index 3887f1dcfd..acbcf4bc0b 100644 --- a/x/mongo/driver/drivertest/opmsg_deployment.go +++ b/x/mongo/driver/drivertest/opmsg_deployment.go @@ -9,11 +9,13 @@ package drivertest import ( "context" "errors" + "fmt" "time" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/csot" "go.mongodb.org/mongo-driver/v2/internal/driverutil" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" @@ -56,8 +58,10 @@ type connection struct { responses []bson.D // responses to send when ReadWireMessage is called } -var _ mnet.ReadWriteCloser = &connection{} -var _ mnet.Describer = &connection{} +var ( + _ mnet.ReadWriteCloser = &connection{} + _ mnet.Describer = &connection{} +) // Write is a no-op. func (c *connection) Write(context.Context, []byte) error { @@ -86,7 +90,13 @@ func (c *connection) Read(_ context.Context) ([]byte, error) { dst = wiremessage.AppendMsgSectionType(dst, wiremessage.SingleDocument) resBytes, _ := bson.Marshal(nextRes) dst = append(dst, resBytes...) - dst = bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))) + + wmLen, err := mathutil.SafeConvertNumeric[int32](len(dst[wmindex:])) + if err != nil { + return nil, fmt.Errorf("response size %d exceeds int32: %w", len(dst[wmindex:]), err) + } + + dst = bsoncore.UpdateLength(dst, wmindex, wmLen) return dst, nil } @@ -132,11 +142,13 @@ type MockDeployment struct { updates chan description.Topology } -var _ driver.Deployment = &MockDeployment{} -var _ driver.Server = &MockDeployment{} -var _ driver.Connector = &MockDeployment{} -var _ driver.Disconnector = &MockDeployment{} -var _ driver.Subscriber = &MockDeployment{} +var ( + _ driver.Deployment = &MockDeployment{} + _ driver.Server = &MockDeployment{} + _ driver.Connector = &MockDeployment{} + _ driver.Disconnector = &MockDeployment{} + _ driver.Subscriber = &MockDeployment{} +) // SelectServer implements the Deployment interface. This method does not use the // description.SelectedServer provided and instead returns itself. The Connections returned from the diff --git a/x/mongo/driver/integration/main_test.go b/x/mongo/driver/integration/main_test.go index f1c24bfd0d..97f85de244 100644 --- a/x/mongo/driver/integration/main_test.go +++ b/x/mongo/driver/integration/main_test.go @@ -133,7 +133,7 @@ func dropCollection(t *testing.T, dbname, colname string) { err := operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "drop", colname))). Database(dbname).ServerSelector(&serverselector.Write{}).Deployment(integtest.Topology(t)). Execute(context.Background()) - if de, ok := err.(driver.Error); err != nil && !(ok && de.NamespaceNotFound()) { + if de, ok := err.(driver.Error); err != nil && (!ok || !de.NamespaceNotFound()) { require.NoError(t, err) } } diff --git a/x/mongo/driver/mongocrypt/binary.go b/x/mongo/driver/mongocrypt/binary.go index 4e4b51d74b..116b2b5e1a 100644 --- a/x/mongo/driver/mongocrypt/binary.go +++ b/x/mongo/driver/mongocrypt/binary.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package mongocrypt diff --git a/x/mongo/driver/mongocrypt/binary_test.go b/x/mongo/driver/mongocrypt/binary_test.go index 6f122745d4..c6c86455dc 100644 --- a/x/mongo/driver/mongocrypt/binary_test.go +++ b/x/mongo/driver/mongocrypt/binary_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package mongocrypt diff --git a/x/mongo/driver/mongocrypt/errors.go b/x/mongo/driver/mongocrypt/errors.go index 3401e73849..d80bacfe12 100644 --- a/x/mongo/driver/mongocrypt/errors.go +++ b/x/mongo/driver/mongocrypt/errors.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package mongocrypt diff --git a/x/mongo/driver/mongocrypt/errors_not_enabled.go b/x/mongo/driver/mongocrypt/errors_not_enabled.go index 706a0f9e75..970de44625 100644 --- a/x/mongo/driver/mongocrypt/errors_not_enabled.go +++ b/x/mongo/driver/mongocrypt/errors_not_enabled.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build !cse -// +build !cse package mongocrypt diff --git a/x/mongo/driver/mongocrypt/mongocrypt.go b/x/mongo/driver/mongocrypt/mongocrypt.go index 91b950c371..c2c47a6334 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt.go +++ b/x/mongo/driver/mongocrypt/mongocrypt.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package mongocrypt diff --git a/x/mongo/driver/mongocrypt/mongocrypt_context.go b/x/mongo/driver/mongocrypt/mongocrypt_context.go index 5a34516533..b7f2c325b7 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt_context.go +++ b/x/mongo/driver/mongocrypt/mongocrypt_context.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package mongocrypt diff --git a/x/mongo/driver/mongocrypt/mongocrypt_context_not_enabled.go b/x/mongo/driver/mongocrypt/mongocrypt_context_not_enabled.go index a04272781b..14bc764d18 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt_context_not_enabled.go +++ b/x/mongo/driver/mongocrypt/mongocrypt_context_not_enabled.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build !cse -// +build !cse package mongocrypt diff --git a/x/mongo/driver/mongocrypt/mongocrypt_kms_context.go b/x/mongo/driver/mongocrypt/mongocrypt_kms_context.go index 49baa37f2e..fb3c354331 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt_kms_context.go +++ b/x/mongo/driver/mongocrypt/mongocrypt_kms_context.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package mongocrypt diff --git a/x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go b/x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go index 7968897648..acb24e4b07 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go +++ b/x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build !cse -// +build !cse package mongocrypt diff --git a/x/mongo/driver/mongocrypt/mongocrypt_not_enabled.go b/x/mongo/driver/mongocrypt/mongocrypt_not_enabled.go index 6e21e64917..1b78b05ccd 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt_not_enabled.go +++ b/x/mongo/driver/mongocrypt/mongocrypt_not_enabled.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build !cse -// +build !cse // Package mongocrypt is intended for internal use only. It is made available to // facilitate use cases that require access to internal MongoDB driver diff --git a/x/mongo/driver/mongocrypt/mongocrypt_test.go b/x/mongo/driver/mongocrypt/mongocrypt_test.go index 964991d537..3a1369d38c 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt_test.go +++ b/x/mongo/driver/mongocrypt/mongocrypt_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build cse -// +build cse package mongocrypt diff --git a/x/mongo/driver/ocsp/ocsp_test.go b/x/mongo/driver/ocsp/ocsp_test.go index 58a89315bb..0137865a2c 100644 --- a/x/mongo/driver/ocsp/ocsp_test.go +++ b/x/mongo/driver/ocsp/ocsp_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build go1.13 -// +build go1.13 package ocsp diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 3e720eba63..9dad790df0 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -24,6 +24,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/driverutil" "go.mongodb.org/mongo-driver/v2/internal/handshake" "go.mongodb.org/mongo-driver/v2/internal/logger" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/internal/serverselector" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" @@ -475,7 +476,7 @@ func (op Operation) getServerAndConnection( server, err := op.selectServer(ctx, requestID, deprioritized) if err != nil { if op.Client != nil && - !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() { + (!op.Client.Committing && !op.Client.Aborting) && op.Client.TransactionRunning() { err = Error{ Message: err.Error(), Labels: []string{TransientTransactionError}, @@ -747,7 +748,6 @@ func (op Operation) Execute(ctx context.Context) error { var moreToCome bool var startedInfo startedInformation *wm, moreToCome, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) - if err != nil { return err } @@ -792,7 +792,7 @@ func (op Operation) Execute(ctx context.Context) error { serverConnID: startedInfo.serverConnID, redacted: startedInfo.redacted, serviceID: startedInfo.serviceID, - serverAddress: desc.Server.Addr, + serverAddress: desc.Addr, } startedTime := time.Now() @@ -845,7 +845,7 @@ func (op Operation) Execute(ctx context.Context) error { retryableErr := tt.Retryable(connDesc.Kind, connDesc.WireVersion) preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9 inTransaction := op.Client != nil && - !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() + (!op.Client.Committing && !op.Client.Aborting) && op.Client.TransactionRunning() // If retry is enabled and the operation isn't in a transaction, add a RetryableWriteError label for // retryable errors from pre-4.4 servers if retryableErr && preRetryWriteLabelVersion && retryEnabled && !inTransaction { @@ -900,10 +900,16 @@ func (op Operation) Execute(ctx context.Context) error { } } if op.Client != nil && op.Client.Committing && tt.WriteConcernError != nil { + // MongoDB only supports 64-bit platforms where int is 32 bits, and BSON + // encodes the code field as an int32. Drivers currently deserialize, so + // representing errorCodes as 32-bit integers keeps us aligned with what + // the server actually sends. + code, _ := mathutil.SafeConvertNumeric[int32](tt.WriteConcernError.Code) + // When running commitTransaction we return WriteConcernErrors as an Error. err := Error{ Name: tt.WriteConcernError.Name, - Code: int32(tt.WriteConcernError.Code), + Code: code, Message: tt.WriteConcernError.Message, Labels: tt.Labels, Raw: tt.Raw, @@ -961,7 +967,7 @@ func (op Operation) Execute(ctx context.Context) error { retryableErr = tt.RetryableWrite(connDesc.WireVersion) preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9 inTransaction := op.Client != nil && - !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() + (!op.Client.Committing && !op.Client.Aborting) && op.Client.TransactionRunning() // If retryWrites is enabled and the operation isn't in a transaction, add a RetryableWriteError label // for network errors and retryable errors from pre-4.4 servers if retryEnabled && !inTransaction && @@ -1081,7 +1087,7 @@ func (op Operation) retryable(desc description.Server) bool { return true } if retryWritesSupported(desc) && - op.Client != nil && !(op.Client.TransactionInProgress() || op.Client.TransactionStarting()) && + op.Client != nil && (!op.Client.TransactionInProgress() && !op.Client.TransactionStarting()) && op.WriteConcern.Acknowledged() { return true } @@ -1089,7 +1095,7 @@ func (op Operation) retryable(desc description.Server) bool { if op.Client != nil && (op.Client.Committing || op.Client.Aborting) { return true } - if op.Client == nil || !(op.Client.TransactionInProgress() || op.Client.TransactionStarting()) { + if op.Client == nil || (!op.Client.TransactionInProgress() && !op.Client.TransactionStarting()) { return true } } @@ -1335,10 +1341,8 @@ func (op Operation) createMsgWireMessage( if err != nil { return dst, nil, err } - retryWrite := false - if op.retryable(conn.Description()) && op.RetryMode != nil && op.RetryMode.Enabled() { - retryWrite = true - } + retryWrite := op.retryable(conn.Description()) && op.RetryMode != nil && op.RetryMode.Enabled() + dst, err = op.addSession(dst, desc, retryWrite) if err != nil { return dst, nil, err @@ -1478,7 +1482,13 @@ func (op Operation) createWireMessage( moreToCome = true } info.requestID = requestID - return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), moreToCome, info, nil + + wmLen, err := mathutil.SafeConvertNumeric[int32](len(dst[wmindex:])) + if err != nil { + return nil, false, info, fmt.Errorf("wire message length %d exceeds maximum int32: %w", len(dst[wmindex:]), err) + } + + return bsoncore.UpdateLength(dst, wmindex, wmLen), moreToCome, info, nil } func (op Operation) addEncryptCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) (int, []byte, error) { @@ -1842,7 +1852,7 @@ func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bo // TODO if supplied readPreference was "overwritten" with primary in description.selectForReplicaSet. if desc.Server.Kind == description.ServerKindStandalone || (isOpQuery && desc.Server.Kind != description.ServerKindMongos) || - op.Type == Write || (op.IsOutputAggregate && desc.Server.WireVersion.Max < 13) { + op.Type == Write || (op.IsOutputAggregate && desc.WireVersion.Max < 13) { // Don't send read preference for: // 1. all standalones // 2. non-mongos when using OP_QUERY @@ -2000,7 +2010,14 @@ func (Operation) decodeOpReply(wm []byte) opReply { reply.err = ErrCursorNotFound return reply } - if reply.numReturned != int32(len(reply.documents)) { + + replyLen, err := mathutil.SafeConvertNumeric[int32](len(reply.documents)) + if err != nil { + reply.err = fmt.Errorf("number of documents %d exceeds maximum int32: %w", len(reply.documents), err) + return reply + } + + if reply.numReturned != replyLen { reply.err = ErrReplyDocumentMismatch return reply } diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index dd232f0623..41ccfd8ac8 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -576,7 +576,7 @@ func (h *Hello) handshakeCommand(dst []byte, desc description.SelectedServer) ([ func (h *Hello) command(dst []byte, desc description.SelectedServer) ([]byte, error) { // Use "hello" if topology is LoadBalanced, API version is declared or server // has responded with "helloOk". Otherwise, use legacy hello. - if h.loadBalanced || h.serverAPI != nil || desc.Server.HelloOK { + if h.loadBalanced || h.serverAPI != nil || desc.HelloOK { dst = bsoncore.AppendInt32Element(dst, "hello", 1) } else { dst = bsoncore.AppendInt32Element(dst, handshake.LegacyHello, 1) diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index cd89b707b4..616f71bc29 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -422,9 +422,10 @@ func (c *Client) StartTransaction(opts *TransactionOptions) error { // CheckCommitTransaction checks to see if allowed to commit transaction and returns // an error if not allowed. func (c *Client) CheckCommitTransaction() error { - if c.TransactionState == None { + switch c.TransactionState { + case None: return ErrNoTransactStarted - } else if c.TransactionState == Aborted { + case Aborted: return ErrCommitAfterAbort } return nil @@ -462,12 +463,12 @@ func (c *Client) UpdateCommitTransactionWriteConcern() { // CheckAbortTransaction checks to see if allowed to abort transaction and returns // an error if not allowed. func (c *Client) CheckAbortTransaction() error { - switch { - case c.TransactionState == None: + switch c.TransactionState { + case None: return ErrNoTransactStarted - case c.TransactionState == Committed: + case Committed: return ErrAbortAfterCommit - case c.TransactionState == Aborted: + case Aborted: return ErrAbortTwice } return nil @@ -506,13 +507,14 @@ func (c *Client) ApplyCommand(desc description.Server) error { // Do not change state if committing after already committed return nil } - if c.TransactionState == Starting { + switch c.TransactionState { + case Starting: c.TransactionState = InProgress // If this is in a transaction and the server is a mongos, pin it if desc.Kind == description.ServerKindMongos { c.PinnedServerAddr = &desc.Addr } - } else if c.TransactionState == Committed || c.TransactionState == Aborted { + case Committed, Aborted: c.TransactionState = None return c.clearTransactionOpts() } diff --git a/x/mongo/driver/topology/CMAP_spec_test.go b/x/mongo/driver/topology/CMAP_spec_test.go index a0579ff552..29f6276375 100644 --- a/x/mongo/driver/topology/CMAP_spec_test.go +++ b/x/mongo/driver/topology/CMAP_spec_test.go @@ -427,10 +427,8 @@ func runOperationInThread(t *testing.T, operation map[string]any, testInfo *test t.Fatalf("unable to find thread to wait for: %v", threadName) } - for { - if atomic.LoadInt32(&thread.JobsCompleted) == atomic.LoadInt32(&thread.JobsAssigned) { - break - } + for atomic.LoadInt32(&thread.JobsCompleted) != atomic.LoadInt32(&thread.JobsAssigned) { + } case "waitForEvent": var targetCount int diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index b1bf1d13f1..6d6c42b52e 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -13,6 +13,7 @@ import ( "errors" "fmt" "io" + "math" "net" "strings" "sync" @@ -20,6 +21,7 @@ import ( "time" "go.mongodb.org/mongo-driver/v2/internal/driverutil" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" @@ -226,7 +228,6 @@ func (c *connection) connect(ctx context.Context) (err error) { HTTPClient: c.config.httpClient, } tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) - if err != nil { return ConnectionError{Wrapped: err, init: true, message: fmt.Sprintf("failed to configure TLS for %s", c.addr)} } @@ -427,7 +428,12 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) { // read the length as an int32 - size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:])) + rawSize := binary.LittleEndian.Uint32(wmSizeBytes[:]) + if rawSize > uint32(math.MaxInt32) { + return 0, fmt.Errorf("message length exceeds int32 max: %d", rawSize) + } + + size := int32(rawSize) if size < 4 { return 0, fmt.Errorf("malformed message length: %d", size) @@ -475,7 +481,12 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, // reading messages from an exhaust cursor. n, err := io.ReadFull(c.nc, sizeBuf[:]) if err != nil { - if l := int32(n); l == 0 && isCSOTTimeout(err) { + nI32, convErr := mathutil.SafeConvertNumeric[int32](n) + if convErr != nil { + return nil, "incomplete read of message header", convErr + } + + if l := nI32; l == 0 && isCSOTTimeout(err) { c.awaitRemainingBytes = &l } return nil, "incomplete read of message header", err @@ -490,7 +501,12 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, n, err = io.ReadFull(c.nc, dst[4:]) if err != nil { - remainingBytes := size - 4 - int32(n) + nI32, convErr := mathutil.SafeConvertNumeric[int32](n) + if convErr != nil { + return dst, "incomplete read of full message", convErr + } + + remainingBytes := size - 4 - nI32 if remainingBytes > 0 && isCSOTTimeout(err) { c.awaitRemainingBytes = &remainingBytes } @@ -586,15 +602,17 @@ func (c *connection) SetOIDCTokenGenID(genID uint64) { // *connection to a Handshaker. type initConnection struct{ *connection } -var _ mnet.ReadWriteCloser = initConnection{} -var _ mnet.Describer = initConnection{} -var _ mnet.Streamer = initConnection{} +var ( + _ mnet.ReadWriteCloser = initConnection{} + _ mnet.Describer = initConnection{} + _ mnet.Streamer = initConnection{} +) func (c initConnection) Description() description.Server { if c.connection == nil { return description.Server{} } - return c.connection.desc + return c.desc } func (c initConnection) Close() error { return nil } func (c initConnection) ID() string { return c.id } @@ -606,18 +624,23 @@ func (c initConnection) LocalAddress() address.Address { } return address.Address(c.nc.LocalAddr().String()) } + func (c initConnection) Write(ctx context.Context, wm []byte) error { return c.writeWireMessage(ctx, wm) } + func (c initConnection) Read(ctx context.Context) ([]byte, error) { return c.readWireMessage(ctx) } + func (c initConnection) SetStreaming(streaming bool) { c.setStreaming(streaming) } + func (c initConnection) CurrentlyStreaming() bool { return c.getCurrentlyStreaming() } + func (c initConnection) SupportsStreaming() bool { return c.canStream } @@ -639,11 +662,13 @@ type Connection struct { mu sync.RWMutex } -var _ mnet.ReadWriteCloser = (*Connection)(nil) -var _ mnet.Describer = (*Connection)(nil) -var _ mnet.Compressor = (*Connection)(nil) -var _ mnet.Pinner = (*Connection)(nil) -var _ driver.Expirable = (*Connection)(nil) +var ( + _ mnet.ReadWriteCloser = (*Connection)(nil) + _ mnet.Describer = (*Connection)(nil) + _ mnet.Compressor = (*Connection)(nil) + _ mnet.Pinner = (*Connection)(nil) + _ driver.Expirable = (*Connection)(nil) +) // WriteWireMessage handles writing a wire message to the underlying connection. func (c *Connection) Write(ctx context.Context, wm []byte) error { @@ -684,7 +709,13 @@ func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) { } idx, dst := wiremessage.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed) dst = wiremessage.AppendCompressedOriginalOpCode(dst, origcode) - dst = wiremessage.AppendCompressedUncompressedSize(dst, int32(len(rem))) + + remI32, err := mathutil.SafeConvertNumeric[int32](len(rem)) + if err != nil { + return nil, err + } + + dst = wiremessage.AppendCompressedUncompressedSize(dst, remI32) dst = wiremessage.AppendCompressedCompressorID(dst, c.connection.compressor) opts := driver.CompressionOpts{ Compressor: c.connection.compressor, @@ -696,7 +727,11 @@ func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) { return nil, err } dst = wiremessage.AppendCompressedCompressedMessage(dst, compressed) - return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil + length, err := mathutil.SafeConvertNumeric[int32](len(dst[idx:])) + if err != nil { + return nil, err + } + return bsoncore.UpdateLength(dst, idx, length), nil } // Description returns the server description of the server this connection is connected to. diff --git a/x/mongo/driver/topology/connection_errors_test.go b/x/mongo/driver/topology/connection_errors_test.go index 0a7259035d..a95afde989 100644 --- a/x/mongo/driver/topology/connection_errors_test.go +++ b/x/mongo/driver/topology/connection_errors_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build go1.13 -// +build go1.13 package topology diff --git a/x/mongo/driver/topology/errors.go b/x/mongo/driver/topology/errors.go index dee1941542..ffb1180083 100644 --- a/x/mongo/driver/topology/errors.go +++ b/x/mongo/driver/topology/errors.go @@ -16,6 +16,7 @@ import ( "strings" "time" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" ) @@ -118,13 +119,24 @@ func (w WaitQueueTimeoutError) Error() string { msg := fmt.Sprintf("%s; total connections: %d, maxPoolSize: %d, ", errorMsg, w.totalConnections, w.maxPoolSize) if pinnedConnections := w.pinnedConnections; pinnedConnections != nil { - openConnectionCount := uint64(w.totalConnections) - + var totcalConnectionsWarning string + + totalConnections, err := mathutil.SafeConvertNumeric[uint64](w.totalConnections) + if err != nil { + totcalConnectionsWarning = fmt.Sprintf("[WARNING]: totalConnections is negative (%d); this may indicate a bug in the driver. ", + w.totalConnections) + totalConnections = 0 + } + + openConnectionCount := totalConnections - pinnedConnections.cursorConnections - pinnedConnections.transactionConnections - msg += fmt.Sprintf("connections in use by cursors: %d, connections in use by transactions: %d, connections in use by other operations: %d, ", + + msg += fmt.Sprintf("connections in use by cursors: %d, connections in use by transactions: %d, connections in use by other operations: %d%s, ", pinnedConnections.cursorConnections, pinnedConnections.transactionConnections, openConnectionCount, + totcalConnectionsWarning, ) } msg += fmt.Sprintf("idle connections: %d, wait duration: %s", w.availableConnections, w.waitDuration.String()) diff --git a/x/mongo/driver/topology/fsm.go b/x/mongo/driver/topology/fsm.go index 26e2fc320b..7e96b3c341 100644 --- a/x/mongo/driver/topology/fsm.go +++ b/x/mongo/driver/topology/fsm.go @@ -125,7 +125,7 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser SetName: f.SetName, } - f.Topology.SessionTimeoutMinutes = serverTimeoutMinutes + f.SessionTimeoutMinutes = serverTimeoutMinutes if _, ok := f.findServer(s.Addr); !ok { return f.Topology, s @@ -157,7 +157,7 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser SupportedWireVersions.Min, MinSupportedMongoDBVersion, ) - f.Topology.CompatibilityErr = f.compatibilityErr + f.CompatibilityErr = f.compatibilityErr return f.Topology, s } @@ -169,7 +169,7 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser server.WireVersion.Min, SupportedWireVersions.Max, ) - f.Topology.CompatibilityErr = f.compatibilityErr + f.CompatibilityErr = f.compatibilityErr return f.Topology, s } } diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 23656864b3..bc81cd5919 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "io" + "math" "net" "sync" "sync/atomic" @@ -18,6 +19,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/logger" + "go.mongodb.org/mongo-driver/v2/internal/mathutil" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" ) @@ -158,7 +160,6 @@ func logPoolMessage(pool *pool, msg string, keysAndValues ...any) { ServerHost: host, ServerPort: port, }, keysAndValues...)...) - } type reason struct { @@ -241,7 +242,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool { var ctx context.Context ctx, pool.cancelBackgroundCtx = context.WithCancel(context.Background()) - for i := 0; i < int(pool.maxConnecting); i++ { + for i := uint64(0); i < pool.maxConnecting; i++ { pool.backgroundDone.Add(1) go pool.createConnections(ctx, pool.backgroundDone) } @@ -1357,7 +1358,16 @@ func (p *pool) maintain(ctx context.Context, wg *sync.WaitGroup) { // the number of connections requested to max 10 at a time to prevent overshooting // minPoolSize in case other checkOut() calls are requesting new connections, too. total := p.totalConnectionCount() - n := int(p.minSize) - total - len(wantConns) + + // Since this is a forced mod 10 operation, we can safely ignore overflows. + minSize, err := mathutil.SafeConvertNumeric[int](p.minSize) + if err != nil { + // Ignore overflow here because this is only used to drive pool growth + // hints. + minSize = math.MaxInt + } + + n := minSize - total - len(wantConns) if n > 10 { n = 10 } diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index ae3009573b..7626ff1920 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build go1.13 -// +build go1.13 package topology diff --git a/x/mongo/driver/topology/tls_connection_source_1_16.go b/x/mongo/driver/topology/tls_connection_source_1_16.go index 387f2ec04d..dad53d010d 100644 --- a/x/mongo/driver/topology/tls_connection_source_1_16.go +++ b/x/mongo/driver/topology/tls_connection_source_1_16.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build !go1.17 -// +build !go1.17 package topology diff --git a/x/mongo/driver/topology/tls_connection_source_1_17.go b/x/mongo/driver/topology/tls_connection_source_1_17.go index c9822e0609..8306c6c437 100644 --- a/x/mongo/driver/topology/tls_connection_source_1_17.go +++ b/x/mongo/driver/topology/tls_connection_source_1_17.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build go1.17 -// +build go1.17 package topology diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index 5343934ccf..a4ce875837 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -772,7 +772,7 @@ func (t *Topology) pollSRVRecords(hosts string) { return } topoKind := t.Description().Kind - if !(topoKind == description.Unknown || topoKind == description.TopologyKindSharded) { + if topoKind != description.Unknown && topoKind != description.TopologyKindSharded { break } diff --git a/x/mongo/driver/topology/topology_errors_test.go b/x/mongo/driver/topology/topology_errors_test.go index 3b4306e606..0fd45da036 100644 --- a/x/mongo/driver/topology/topology_errors_test.go +++ b/x/mongo/driver/topology/topology_errors_test.go @@ -5,7 +5,6 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 //go:build go1.13 -// +build go1.13 package topology diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go index f0b1b5e533..4f0a937377 100644 --- a/x/mongo/driver/wiremessage/wiremessage.go +++ b/x/mongo/driver/wiremessage/wiremessage.go @@ -16,9 +16,11 @@ package wiremessage import ( "bytes" "encoding/binary" + "math" "strings" "sync/atomic" + "go.mongodb.org/mongo-driver/v2/internal/binaryutil" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -219,18 +221,18 @@ const ( // starts in dst and the updated slice. func AppendHeaderStart(dst []byte, reqid, respto int32, opcode OpCode) (index int32, b []byte) { index, dst = bsoncore.ReserveLength(dst) - dst = appendi32(dst, reqid) - dst = appendi32(dst, respto) - dst = appendi32(dst, int32(opcode)) + dst = binaryutil.AppendI32(dst, reqid) + dst = binaryutil.AppendI32(dst, respto) + dst = binaryutil.AppendI32(dst, int32(opcode)) return index, dst } // AppendHeader appends a header to dst. func AppendHeader(dst []byte, length, reqid, respto int32, opcode OpCode) []byte { - dst = appendi32(dst, length) - dst = appendi32(dst, reqid) - dst = appendi32(dst, respto) - dst = appendi32(dst, int32(opcode)) + dst = binaryutil.AppendI32(dst, length) + dst = binaryutil.AppendI32(dst, reqid) + dst = binaryutil.AppendI32(dst, respto) + dst = binaryutil.AppendI32(dst, int32(opcode)) return dst } @@ -240,26 +242,31 @@ func ReadHeader(src []byte) (length, requestID, responseTo int32, opcode OpCode, return 0, 0, 0, 0, src, false } - length = readi32unsafe(src) - requestID = readi32unsafe(src[4:]) - responseTo = readi32unsafe(src[8:]) - opcode = OpCode(readi32unsafe(src[12:])) + length = binaryutil.ReadI32Unsafe(src) + requestID = binaryutil.ReadI32Unsafe(src[4:]) + responseTo = binaryutil.ReadI32Unsafe(src[8:]) + opcode = OpCode(binaryutil.ReadI32Unsafe(src[12:])) return length, requestID, responseTo, opcode, src[16:], true } // AppendQueryFlags appends the flags for an OP_QUERY wire message. func AppendQueryFlags(dst []byte, flags QueryFlag) []byte { - return appendi32(dst, int32(flags)) + return binaryutil.AppendI32(dst, int32(flags)) } // AppendMsgFlags appends the flags for an OP_MSG wire message. func AppendMsgFlags(dst []byte, flags MsgFlag) []byte { - return appendi32(dst, int32(flags)) + if flags > MsgFlag(math.MaxInt32) { + panic("AppendMsgFlags: flag value exceeds int32 range") + } + + // nolint:gosec // G602: key length is validated at function entry + return binaryutil.AppendI32(dst, int32(flags)) } // AppendReplyFlags appends the flags for an OP_REPLY wire message. func AppendReplyFlags(dst []byte, flags ReplyFlag) []byte { - return appendi32(dst, int32(flags)) + return binaryutil.AppendI32(dst, int32(flags)) } // AppendMsgSectionType appends the section type to dst. @@ -274,37 +281,39 @@ func AppendQueryFullCollectionName(dst []byte, ns string) []byte { // AppendQueryNumberToSkip appends the number to skip to dst. func AppendQueryNumberToSkip(dst []byte, skip int32) []byte { - return appendi32(dst, skip) + return binaryutil.AppendI32(dst, skip) } // AppendQueryNumberToReturn appends the number to return to dst. func AppendQueryNumberToReturn(dst []byte, nor int32) []byte { - return appendi32(dst, nor) + return binaryutil.AppendI32(dst, nor) } // AppendReplyCursorID appends the cursor ID to dst. func AppendReplyCursorID(dst []byte, id int64) []byte { - return appendi64(dst, id) + return binaryutil.AppendI64(dst, id) } // AppendReplyStartingFrom appends the starting from field to dst. func AppendReplyStartingFrom(dst []byte, sf int32) []byte { - return appendi32(dst, sf) + return binaryutil.AppendI32(dst, sf) } // AppendReplyNumberReturned appends the number returned to dst. func AppendReplyNumberReturned(dst []byte, nr int32) []byte { - return appendi32(dst, nr) + return binaryutil.AppendI32(dst, nr) } // AppendCompressedOriginalOpCode appends the original opcode to dst. func AppendCompressedOriginalOpCode(dst []byte, opcode OpCode) []byte { - return appendi32(dst, int32(opcode)) + return binaryutil.AppendI32(dst, int32(opcode)) } // AppendCompressedUncompressedSize appends the uncompressed size of a // compressed wiremessage to dst. -func AppendCompressedUncompressedSize(dst []byte, size int32) []byte { return appendi32(dst, size) } +func AppendCompressedUncompressedSize(dst []byte, size int32) []byte { + return binaryutil.AppendI32(dst, size) +} // AppendCompressedCompressorID appends the ID of the compressor to dst. func AppendCompressedCompressorID(dst []byte, id CompressorID) []byte { @@ -316,7 +325,7 @@ func AppendCompressedCompressedMessage(dst []byte, msg []byte) []byte { return a // AppendGetMoreZero appends the zero field to dst. func AppendGetMoreZero(dst []byte) []byte { - return appendi32(dst, 0) + return binaryutil.AppendI32(dst, 0) } // AppendGetMoreFullCollectionName appends the fullCollectionName field to dst. @@ -326,43 +335,50 @@ func AppendGetMoreFullCollectionName(dst []byte, ns string) []byte { // AppendGetMoreNumberToReturn appends the numberToReturn field to dst. func AppendGetMoreNumberToReturn(dst []byte, numToReturn int32) []byte { - return appendi32(dst, numToReturn) + return binaryutil.AppendI32(dst, numToReturn) } // AppendGetMoreCursorID appends the cursorID field to dst. func AppendGetMoreCursorID(dst []byte, cursorID int64) []byte { - return appendi64(dst, cursorID) + return binaryutil.AppendI64(dst, cursorID) } // AppendKillCursorsZero appends the zero field to dst. func AppendKillCursorsZero(dst []byte) []byte { - return appendi32(dst, 0) + return binaryutil.AppendI32(dst, 0) } // AppendKillCursorsNumberIDs appends the numberOfCursorIDs field to dst. func AppendKillCursorsNumberIDs(dst []byte, numIDs int32) []byte { - return appendi32(dst, numIDs) + return binaryutil.AppendI32(dst, numIDs) } // AppendKillCursorsCursorIDs appends each the cursorIDs field to dst. func AppendKillCursorsCursorIDs(dst []byte, cursors []int64) []byte { for _, cursor := range cursors { - dst = appendi64(dst, cursor) + dst = binaryutil.AppendI64(dst, cursor) } return dst } // ReadMsgFlags reads the OP_MSG flags from src. func ReadMsgFlags(src []byte) (flags MsgFlag, rem []byte, ok bool) { - i32, rem, ok := readi32(src) + i32, rem, ok := binaryutil.ReadI32(src) + if i32 < 0 { + return 0, rem, false + } + return MsgFlag(i32), rem, ok } // IsMsgMoreToCome returns if the provided wire message is an OP_MSG with the more to come flag set. func IsMsgMoreToCome(wm []byte) bool { return len(wm) >= 20 && - OpCode(readi32unsafe(wm[12:16])) == OpMsg && - MsgFlag(readi32unsafe(wm[16:20]))&MoreToCome == MoreToCome + OpCode(binaryutil.ReadI32Unsafe(wm[12:16])) == OpMsg && + func() bool { + flagValue := binaryutil.ReadI32Unsafe(wm[16:20]) + return flagValue >= 0 && MsgFlag(flagValue)&MoreToCome == MoreToCome + }() } // ReadMsgSectionType reads the section type from src. @@ -405,7 +421,7 @@ func ReadMsgSectionDocumentSequence(src []byte) (identifier string, docs []bsonc // ReadMsgSectionRawDocumentSequence reads an identifier and document sequence from src and returns the raw document // sequence data. func ReadMsgSectionRawDocumentSequence(src []byte) (identifier string, data []byte, rem []byte, ok bool) { - length, rem, ok := readi32(src) + length, rem, ok := binaryutil.ReadI32(src) if !ok || int(length) > len(src) || length < 4 { return "", nil, src, false } @@ -423,9 +439,8 @@ func ReadMsgSectionRawDocumentSequence(src []byte) (identifier string, data []by } // ReadMsgChecksum reads a checksum from src. -func ReadMsgChecksum(src []byte) (checksum uint32, rem []byte, ok bool) { - i32, rem, ok := readi32(src) - return uint32(i32), rem, ok +func ReadMsgChecksum(src []byte) (uint32, []byte, bool) { + return readu32(src) } // ReadQueryFlags reads OP_QUERY flags from src. @@ -433,7 +448,7 @@ func ReadMsgChecksum(src []byte) (checksum uint32, rem []byte, ok bool) { // Deprecated: Construct wiremessages with OpMsg and use the ReadMsg* functions // instead. func ReadQueryFlags(src []byte) (flags QueryFlag, rem []byte, ok bool) { - i32, rem, ok := readi32(src) + i32, rem, ok := binaryutil.ReadI32(src) return QueryFlag(i32), rem, ok } @@ -450,7 +465,7 @@ func ReadQueryFullCollectionName(src []byte) (collname string, rem []byte, ok bo // Deprecated: Construct wiremessages with OpMsg and use the ReadMsg* functions // instead. func ReadQueryNumberToSkip(src []byte) (nts int32, rem []byte, ok bool) { - return readi32(src) + return binaryutil.ReadI32(src) } // ReadQueryNumberToReturn reads the number to return from src. @@ -458,7 +473,7 @@ func ReadQueryNumberToSkip(src []byte) (nts int32, rem []byte, ok bool) { // Deprecated: Construct wiremessages with OpMsg and use the ReadMsg* functions // instead. func ReadQueryNumberToReturn(src []byte) (ntr int32, rem []byte, ok bool) { - return readi32(src) + return binaryutil.ReadI32(src) } // ReadQueryQuery reads the query from src. @@ -478,24 +493,25 @@ func ReadQueryReturnFieldsSelector(src []byte) (rfs bsoncore.Document, rem []byt } // ReadReplyFlags reads OP_REPLY flags from src. -func ReadReplyFlags(src []byte) (flags ReplyFlag, rem []byte, ok bool) { - i32, rem, ok := readi32(src) +func ReadReplyFlags(src []byte) (ReplyFlag, []byte, bool) { + i32, rem, ok := binaryutil.ReadI32(src) + return ReplyFlag(i32), rem, ok } // ReadReplyCursorID reads a cursor ID from src. func ReadReplyCursorID(src []byte) (cursorID int64, rem []byte, ok bool) { - return readi64(src) + return binaryutil.ReadI64(src) } // ReadReplyStartingFrom reads the starting from src. func ReadReplyStartingFrom(src []byte) (startingFrom int32, rem []byte, ok bool) { - return readi32(src) + return binaryutil.ReadI32(src) } // ReadReplyNumberReturned reads the numbered returned from src. func ReadReplyNumberReturned(src []byte) (numberReturned int32, rem []byte, ok bool) { - return readi32(src) + return binaryutil.ReadI32(src) } // ReadReplyDocuments reads as many documents as possible from src @@ -521,14 +537,14 @@ func ReadReplyDocument(src []byte) (doc bsoncore.Document, rem []byte, ok bool) // ReadCompressedOriginalOpCode reads the original opcode from src. func ReadCompressedOriginalOpCode(src []byte) (opcode OpCode, rem []byte, ok bool) { - i32, rem, ok := readi32(src) + i32, rem, ok := binaryutil.ReadI32(src) return OpCode(i32), rem, ok } // ReadCompressedUncompressedSize reads the uncompressed size of a // compressed wiremessage to dst. func ReadCompressedUncompressedSize(src []byte) (size int32, rem []byte, ok bool) { - return readi32(src) + return binaryutil.ReadI32(src) } // ReadCompressedCompressorID reads the ID of the compressor to dst. @@ -541,12 +557,12 @@ func ReadCompressedCompressorID(src []byte) (id CompressorID, rem []byte, ok boo // ReadKillCursorsZero reads the zero field from src. func ReadKillCursorsZero(src []byte) (zero int32, rem []byte, ok bool) { - return readi32(src) + return binaryutil.ReadI32(src) } // ReadKillCursorsNumberIDs reads the numberOfCursorIDs field from src. func ReadKillCursorsNumberIDs(src []byte) (numIDs int32, rem []byte, ok bool) { - return readi32(src) + return binaryutil.ReadI32(src) } // ReadKillCursorsCursorIDs reads numIDs cursor IDs from src. @@ -554,7 +570,7 @@ func ReadKillCursorsCursorIDs(src []byte, numIDs int32) (cursorIDs []int64, rem var i int32 var id int64 for i = 0; i < numIDs; i++ { - id, src, ok = readi64(src) + id, src, ok = binaryutil.ReadI64(src) if !ok { return cursorIDs, src, false } @@ -564,39 +580,17 @@ func ReadKillCursorsCursorIDs(src []byte, numIDs int32) (cursorIDs []int64, rem return cursorIDs, src, true } -func appendi32(dst []byte, x int32) []byte { - b := []byte{0, 0, 0, 0} - binary.LittleEndian.PutUint32(b, uint32(x)) - return append(dst, b...) -} - -func appendi64(dst []byte, x int64) []byte { - b := []byte{0, 0, 0, 0, 0, 0, 0, 0} - binary.LittleEndian.PutUint64(b, uint64(x)) - return append(dst, b...) -} - func appendCString(b []byte, str string) []byte { b = append(b, str...) return append(b, 0x00) } -func readi32(src []byte) (int32, []byte, bool) { +func readu32(src []byte) (uint32, []byte, bool) { if len(src) < 4 { return 0, src, false } - return readi32unsafe(src), src[4:], true -} -func readi32unsafe(src []byte) int32 { - return int32(binary.LittleEndian.Uint32(src)) -} - -func readi64(src []byte) (int64, []byte, bool) { - if len(src) < 8 { - return 0, src, false - } - return int64(binary.LittleEndian.Uint64(src)), src[8:], true + return binary.LittleEndian.Uint32(src), src[4:], true } func readcstring(src []byte) (string, []byte, bool) { diff --git a/x/mongo/driver/wiremessage/wiremessage_test.go b/x/mongo/driver/wiremessage/wiremessage_test.go index 9645d86037..9d59bbea9e 100644 --- a/x/mongo/driver/wiremessage/wiremessage_test.go +++ b/x/mongo/driver/wiremessage/wiremessage_test.go @@ -7,7 +7,6 @@ package wiremessage import ( - "math" "testing" "go.mongodb.org/mongo-driver/v2/internal/assert" @@ -215,274 +214,3 @@ func TestReadMsgSectionDocumentSequence(t *testing.T) { }) } } - -func TestAppendi32(t *testing.T) { - testCases := []struct { - desc string - dst []byte - x int32 - want []byte - }{ - { - desc: "0", - x: 0, - want: []byte{0, 0, 0, 0}, - }, - { - desc: "1", - x: 1, - want: []byte{1, 0, 0, 0}, - }, - { - desc: "-1", - x: -1, - want: []byte{255, 255, 255, 255}, - }, - { - desc: "max", - x: math.MaxInt32, - want: []byte{255, 255, 255, 127}, - }, - { - desc: "min", - x: math.MinInt32, - want: []byte{0, 0, 0, 128}, - }, - { - desc: "non-empty dst", - dst: []byte{0, 1, 2, 3}, - x: 1, - want: []byte{0, 1, 2, 3, 1, 0, 0, 0}, - }, - } - - for _, tc := range testCases { - tc := tc // Capture range variable. - - t.Run(tc.desc, func(t *testing.T) { - t.Parallel() - - b := appendi32(tc.dst, tc.x) - assert.Equal(t, tc.want, b, "bytes do not match") - }) - } -} - -func TestAppendi64(t *testing.T) { - testCases := []struct { - desc string - dst []byte - x int64 - want []byte - }{ - { - desc: "0", - x: 0, - want: []byte{0, 0, 0, 0, 0, 0, 0, 0}, - }, - { - desc: "1", - x: 1, - want: []byte{1, 0, 0, 0, 0, 0, 0, 0}, - }, - { - desc: "-1", - x: -1, - want: []byte{255, 255, 255, 255, 255, 255, 255, 255}, - }, - { - desc: "max", - x: math.MaxInt64, - want: []byte{255, 255, 255, 255, 255, 255, 255, 127}, - }, - { - desc: "min", - x: math.MinInt64, - want: []byte{0, 0, 0, 0, 0, 0, 0, 128}, - }, - { - desc: "non-empty dst", - dst: []byte{0, 1, 2, 3}, - x: 1, - want: []byte{0, 1, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0}, - }, - } - - for _, tc := range testCases { - tc := tc // Capture range variable. - - t.Run(tc.desc, func(t *testing.T) { - t.Parallel() - - b := appendi64(tc.dst, tc.x) - assert.Equal(t, tc.want, b, "bytes do not match") - }) - } -} - -func TestReadi32(t *testing.T) { - testCases := []struct { - desc string - src []byte - want int32 - wantRem []byte - wantOK bool - }{ - { - desc: "0", - src: []byte{0, 0, 0, 0}, - want: 0, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "1", - src: []byte{1, 0, 0, 0}, - want: 1, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "-1", - src: []byte{255, 255, 255, 255}, - want: -1, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "max", - src: []byte{255, 255, 255, 127}, - want: math.MaxInt32, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "min", - src: []byte{0, 0, 0, 128}, - want: math.MinInt32, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "non-empty remaining", - src: []byte{1, 0, 0, 0, 0, 1, 2, 3}, - want: 1, - wantRem: []byte{0, 1, 2, 3}, - wantOK: true, - }, - { - desc: "not enough bytes", - src: []byte{0, 1, 2}, - want: 0, - wantRem: []byte{0, 1, 2}, - wantOK: false, - }, - { - desc: "nil", - src: nil, - want: 0, - wantRem: nil, - wantOK: false, - }, - } - - for _, tc := range testCases { - tc := tc // Capture range variable. - - t.Run(tc.desc, func(t *testing.T) { - t.Parallel() - - x, rem, ok := readi32(tc.src) - assert.Equal(t, tc.want, x, "int32 result does not match") - assert.Equal(t, tc.wantRem, rem, "remaining bytes do not match") - assert.Equal(t, tc.wantOK, ok, "OK does not match") - }) - } -} - -func TestReadi64(t *testing.T) { - testCases := []struct { - desc string - src []byte - want int64 - wantRem []byte - wantOK bool - }{ - { - desc: "0", - src: []byte{0, 0, 0, 0, 0, 0, 0, 0}, - want: 0, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "1", - src: []byte{1, 0, 0, 0, 0, 0, 0, 0}, - want: 1, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "-1", - src: []byte{255, 255, 255, 255, 255, 255, 255, 255}, - want: -1, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "max", - src: []byte{255, 255, 255, 255, 255, 255, 255, 127}, - want: math.MaxInt64, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "min", - src: []byte{0, 0, 0, 0, 0, 0, 0, 128}, - want: math.MinInt64, - wantRem: []byte{}, - wantOK: true, - }, - { - desc: "non-empty remaining", - src: []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3}, - want: 1, - wantRem: []byte{0, 1, 2, 3}, - wantOK: true, - }, - { - desc: "not enough bytes", - src: []byte{0, 1, 2, 3, 4, 5, 6}, - want: 0, - wantRem: []byte{0, 1, 2, 3, 4, 5, 6}, - wantOK: false, - }, - { - desc: "not enough bytes", - src: []byte{0, 1, 2, 3, 4, 5, 6}, - want: 0, - wantRem: []byte{0, 1, 2, 3, 4, 5, 6}, - wantOK: false, - }, - { - desc: "nil", - src: nil, - want: 0, - wantRem: nil, - wantOK: false, - }, - } - - for _, tc := range testCases { - tc := tc // Capture range variable. - - t.Run(tc.desc, func(t *testing.T) { - t.Parallel() - - x, rem, ok := readi64(tc.src) - assert.Equal(t, tc.want, x, "int64 result does not match") - assert.Equal(t, tc.wantRem, rem, "remaining bytes do not match") - assert.Equal(t, tc.wantOK, ok, "OK does not match") - }) - } -}