diff --git a/internal/integration/csot_prose_test.go b/internal/integration/csot_prose_test.go index 28dcef7015..3923efd6e7 100644 --- a/internal/integration/csot_prose_test.go +++ b/internal/integration/csot_prose_test.go @@ -7,6 +7,7 @@ package integration import ( + "bytes" "context" "strings" "testing" @@ -17,6 +18,8 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" "go.mongodb.org/mongo-driver/v2/internal/integtest" + "go.mongodb.org/mongo-driver/v2/internal/mongoutil" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" ) @@ -83,6 +86,7 @@ func TestCSOTProse(t *testing.T) { assert.Equal(mt, started[1].CommandName, "insert", "expected a second insert event, got %v", started[1].CommandName) }) + mt.Run("8. server selection", func(mt *mtest.T) { cliOpts := options.Client().ApplyURI("mongodb://invalid/?serverSelectionTimeoutMS=100") mtOpts := mtest.NewOptions().ClientOptions(cliOpts).CreateCollection(false) @@ -107,7 +111,8 @@ func TestCSOTProse(t *testing.T) { cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=100&serverSelectionTimeoutMS=200") mtOpts = mtest.NewOptions().ClientOptions(cliOpts).CreateCollection(false) mt.RunOpts("timeoutMS honored for server selection if it's lower than serverSelectionTimeoutMS", mtOpts, func(mt *mtest.T) { - mt.Parallel() + // TODO(GODRIVER-3266): Why do parallel tests fail on windows builds? + // mt.Parallel() callback := func() bool { err := mt.Client.Ping(context.Background(), nil) @@ -126,7 +131,8 @@ func TestCSOTProse(t *testing.T) { cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=200&serverSelectionTimeoutMS=100") mtOpts = mtest.NewOptions().ClientOptions(cliOpts).CreateCollection(false) mt.RunOpts("serverSelectionTimeoutMS honored for server selection if it's lower than timeoutMS", mtOpts, func(mt *mtest.T) { - mt.Parallel() + // TODO(GODRIVER-3266): Why do parallel tests fail on windows builds? + // mt.Parallel() callback := func() bool { err := mt.Client.Ping(context.Background(), nil) @@ -145,7 +151,8 @@ func TestCSOTProse(t *testing.T) { cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=0&serverSelectionTimeoutMS=100") mtOpts = mtest.NewOptions().ClientOptions(cliOpts).CreateCollection(false) mt.RunOpts("serverSelectionTimeoutMS honored for server selection if timeoutMS=0", mtOpts, func(mt *mtest.T) { - mt.Parallel() + // TODO(GODRIVER-3266): Why do parallel tests fail on windows builds? + // mt.Parallel() callback := func() bool { err := mt.Client.Ping(context.Background(), nil) @@ -162,3 +169,327 @@ func TestCSOTProse(t *testing.T) { }) }) } + +func TestCSOTProse_GridFS(t *testing.T) { + mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) + + mt.RunOpts("6. gridfs - upload", mtest.NewOptions().MinServerVersion("4.4"), func(mt *mtest.T) { + mt.Run("uploads via openUploadStream can be timed out", func(mt *mtest.T) { + // Drop and re-create the db.fs.files and db.fs.chunks collections. + err := mt.Client.Database("db").Collection("fs.files").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop files") + + err = mt.Client.Database("db").Collection("fs.chunks").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop chunks") + + hosts, err := mongoutil.HostsFromURI(mtest.ClusterURI()) + require.NoError(mt, err) + + failpointHost := hosts[0] + + mt.ResetClient(options.Client(). + SetHosts([]string{failpointHost})) + + // Set a blocking "insert" fail point. + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"insert"}, + BlockConnection: true, + BlockTimeMS: 1250, + }, + }) + + // The automatic failpoint clearing may not clear failpoints set on + // specific hosts, so manually clear the failpoint we set on the specific + // mongos when the test is done. + defer func() { + mt.ResetClient(options.Client(). + SetHosts([]string{failpointHost})) + mt.ClearFailPoints() + }() + + // Create a new MongoClient with timeoutMS=1000. + cliOptions := options.Client().SetTimeout(1000 * time.Millisecond).ApplyURI(mtest.ClusterURI()). + SetHosts([]string{failpointHost}) + + integtest.AddTestServerAPIVersion(cliOptions) + + client, err := mongo.Connect(cliOptions) + assert.NoError(mt, err, "failed to connect to server") + + // Create a GridFS bucket that wraps the db database. + bucket := client.Database("db").GridFSBucket() + + uploadStream, err := bucket.OpenUploadStream(context.Background(), "filename") + require.NoError(mt, err, "failed to open upload stream") + + _, err = uploadStream.Write([]byte{0x12}) + require.NoError(mt, err, "failed to write to upload stream") + + err = uploadStream.Close() + assert.Error(t, err, context.DeadlineExceeded) + }) + + mt.Run("Aborting an upload stream can be timed out", func(mt *mtest.T) { + // Drop and re-create the db.fs.files and db.fs.chunks collections. + err := mt.Client.Database("db").Collection("fs.files").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop files") + + err = mt.Client.Database("db").Collection("fs.chunks").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop chunks") + + hosts, err := mongoutil.HostsFromURI(mtest.ClusterURI()) + require.NoError(mt, err) + + failpointHost := hosts[0] + + mt.ResetClient(options.Client(). + SetHosts([]string{failpointHost})) + + // Set a blocking "delete" fail point. + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"delete"}, + BlockConnection: true, + BlockTimeMS: 1250, + }, + }) + + // The automatic failpoint clearing may not clear failpoints set on + // specific hosts, so manually clear the failpoint we set on the specific + // mongos when the test is done. + defer func() { + mt.ResetClient(options.Client(). + SetHosts([]string{failpointHost})) + mt.ClearFailPoints() + }() + + // Create a new MongoClient with timeoutMS=1000. + cliOptions := options.Client().SetTimeout(1000 * time.Millisecond).ApplyURI(mtest.ClusterURI()). + SetHosts([]string{failpointHost}) + integtest.AddTestServerAPIVersion(cliOptions) + + client, err := mongo.Connect(cliOptions) + assert.NoError(mt, err, "failed to connect to server") + + // Create a GridFS bucket that wraps the db database. + bucket := client.Database("db").GridFSBucket(options.GridFSBucket().SetChunkSizeBytes(2)) + + // Call bucket.open_upload_stream() with the filename filename to create + // an upload stream (referred to as uploadStream). + uploadStream, err := bucket.OpenUploadStream(context.Background(), "filename") + require.NoError(mt, err) + + // Using uploadStream, upload the bytes [0x01, 0x02, 0x03, 0x04]. + _, err = uploadStream.Write([]byte{0x01, 0x02, 0x03, 0x04}) + require.NoError(mt, err) + + err = uploadStream.Abort() + assert.Error(mt, err, context.DeadlineExceeded) + }) + }) + + const test61 = "6.1 gridfs - upload and download with non-expiring client-level timeout" + mt.RunOpts(test61, mtest.NewOptions().MinServerVersion("4.4"), func(mt *mtest.T) { + // Drop and re-create the db.fs.files and db.fs.chunks collections. + err := mt.Client.Database("db").Collection("fs.files").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop files") + + err = mt.Client.Database("db").Collection("fs.chunks").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop chunks") + + // Create a new MongoClient with timeoutMS=500. + cliOptions := options.Client().SetTimeout(500 * time.Millisecond).ApplyURI(mtest.ClusterURI()) + integtest.AddTestServerAPIVersion(cliOptions) + + client, err := mongo.Connect(cliOptions) + assert.NoError(mt, err, "failed to connect to server") + + // Create a GridFS bucket that wraps the db database. + bucket := client.Database("db").GridFSBucket() + + mt.Run("UploadFromStream", func(mt *mtest.T) { + // Upload file and ensure it uploaded correctly. + fileID, err := bucket.UploadFromStream(context.Background(), "filename", bytes.NewReader([]byte{0x12})) + assert.NoError(mt, err, "failed to upload stream") + + buf := bytes.Buffer{} + + _, err = bucket.DownloadToStream(context.Background(), fileID, &buf) + assert.NoError(mt, err, "failed to download stream") + assert.Equal(mt, buf.Len(), 1) + assert.Equal(mt, buf.Bytes(), []byte{0x12}) + }) + + mt.Run("OpenUploadStream", func(mt *mtest.T) { + // Upload file and ensure it uploaded correctly. + uploadStream, err := bucket.OpenUploadStream(context.Background(), "filename2") + require.NoError(mt, err, "failed to open upload stream") + + _, err = uploadStream.Write([]byte{0x13}) + require.NoError(mt, err, "failed to write data to upload stream") + + err = uploadStream.Close() + require.NoError(mt, err, "failed to close upload stream") + + buf := bytes.Buffer{} + + _, err = bucket.DownloadToStream(context.Background(), uploadStream.FileID, &buf) + assert.NoError(mt, err, "failed to download stream") + assert.Equal(mt, buf.Len(), 1) + assert.Equal(mt, buf.Bytes(), []byte{0x13}) + }) + }) + + const test62 = "6.2 gridfs - upload with operation-level timeout" + mt.RunOpts(test62, mtest.NewOptions().MinServerVersion("4.4"), func(mt *mtest.T) { + // Drop and re-create the db.fs.files and db.fs.chunks collections. + err := mt.Client.Database("db").Collection("fs.files").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop files") + + err = mt.Client.Database("db").Collection("fs.chunks").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop chunks") + + hosts, err := mongoutil.HostsFromURI(mtest.ClusterURI()) + require.NoError(mt, err) + + failpointHost := hosts[0] + + mt.ResetClient(options.Client(). + SetHosts([]string{failpointHost})) + + // Set a blocking "insert" fail point. + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"insert"}, + BlockConnection: true, + BlockTimeMS: 200, + }, + }) + + // The automatic failpoint clearing may not clear failpoints set on + // specific hosts, so manually clear the failpoint we set on the specific + // mongos when the test is done. + defer func() { + mt.ResetClient(options.Client(). + SetHosts([]string{failpointHost})) + mt.ClearFailPoints() + }() + + cliOptions := options.Client().SetTimeout(100 * time.Millisecond).ApplyURI(mtest.ClusterURI()) + integtest.AddTestServerAPIVersion(cliOptions) + + client, err := mongo.Connect(cliOptions) + assert.NoError(mt, err, "failed to connect to server") + + // Create a GridFS bucket that wraps the db database. + bucket := client.Database("db").GridFSBucket() + + mt.Run("UploadFromStream", func(mt *mtest.T) { + + // If the operation-level context is not respected, then the client-level + // timeout will exceed deadline. + ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond) + defer cancel() + + // Upload file and ensure it uploaded correctly. + fileID, err := bucket.UploadFromStream(ctx, "filename", bytes.NewReader([]byte{0x12})) + require.NoError(mt, err, "failed to upload stream") + + buf := bytes.Buffer{} + + _, err = bucket.DownloadToStream(context.Background(), fileID, &buf) + assert.NoError(mt, err, "failed to download stream") + assert.Equal(mt, buf.Len(), 1) + assert.Equal(mt, buf.Bytes(), []byte{0x12}) + }) + + mt.Run("OpenUploadStream", func(mt *mtest.T) { + // If the operation-level context is not respected, then the client-level + // timeout will exceed deadline. + ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond) + defer cancel() + + // Upload file and ensure it uploaded correctly. + uploadStream, err := bucket.OpenUploadStream(ctx, "filename2") + require.NoError(mt, err, "failed to open upload stream") + + _, err = uploadStream.Write([]byte{0x13}) + require.NoError(mt, err, "failed to write data to upload stream") + + err = uploadStream.Close() + require.NoError(mt, err, "failed to close upload stream") + + buf := bytes.Buffer{} + + _, err = bucket.DownloadToStream(context.Background(), uploadStream.FileID, &buf) + assert.NoError(mt, err, "failed to download stream") + assert.Equal(mt, buf.Len(), 1) + assert.Equal(mt, buf.Bytes(), []byte{0x13}) + }) + }) + + const test63 = "6.3 gridfs - cancel context mid-stream" + mt.RunOpts(test63, mtest.NewOptions().MinServerVersion("4.4"), func(mt *mtest.T) { + // Drop and re-create the db.fs.files and db.fs.chunks collections. + err := mt.Client.Database("db").Collection("fs.files").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop files") + + err = mt.Client.Database("db").Collection("fs.chunks").Drop(context.Background()) + assert.NoError(mt, err, "failed to drop chunks") + + cliOptions := options.Client().ApplyURI(mtest.ClusterURI()) + integtest.AddTestServerAPIVersion(cliOptions) + + client, err := mongo.Connect(cliOptions) + assert.NoError(mt, err, "failed to connect to server") + + // Create a GridFS bucket that wraps the db database. + bucket := client.Database("db").GridFSBucket() + + mt.Run("Upload#Close", func(mt *mtest.T) { + uploadStream, err := bucket.OpenUploadStream(context.Background(), "filename") + require.NoError(mt, err) + + _ = uploadStream.Close() + + _, err = uploadStream.Write([]byte{0x13}) + assert.Error(mt, err, context.Canceled) + }) + + mt.Run("Upload#Abort", func(mt *mtest.T) { + uploadStream, err := bucket.OpenUploadStream(context.Background(), "filename2") + require.NoError(mt, err) + + _ = uploadStream.Abort() + + _, err = uploadStream.Write([]byte{0x13}) + assert.Error(mt, err, context.Canceled) + }) + + mt.Run("Download#Close", func(mt *mtest.T) { + fileID, err := bucket.UploadFromStream(context.Background(), "filename3", bytes.NewReader([]byte{0x12})) + require.NoError(mt, err, "failed to upload stream") + + downloadStream, err := bucket.OpenDownloadStream(context.Background(), fileID) + assert.NoError(mt, err) + + _ = downloadStream.Close() + + _, err = downloadStream.Read([]byte{}) + assert.Error(mt, err, context.Canceled) + }) + }) +} diff --git a/mongo/gridfs_bucket.go b/mongo/gridfs_bucket.go index 65fd7bc29b..8bb418ad18 100644 --- a/mongo/gridfs_bucket.go +++ b/mongo/gridfs_bucket.go @@ -85,7 +85,6 @@ func (b *GridFSBucket) OpenUploadStreamWithID( opts ...options.Lister[options.GridFSUploadOptions], ) (*GridFSUploadStream, error) { ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) - defer cancel() if err := b.checkFirstWrite(ctx); err != nil { return nil, err @@ -96,7 +95,7 @@ func (b *GridFSBucket) OpenUploadStreamWithID( return nil, err } - return newUploadStream(ctx, upload, fileID, filename, b.chunksColl, b.filesColl), nil + return newUploadStream(ctx, cancel, upload, fileID, filename, b.chunksColl, b.filesColl), nil } // UploadFromStream creates a fileID and uploads a file given a source stream. @@ -135,6 +134,9 @@ func (b *GridFSBucket) UploadFromStreamWithID( source io.Reader, opts ...options.Lister[options.GridFSUploadOptions], ) error { + ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) + defer cancel() + us, err := b.OpenUploadStreamWithID(ctx, fileID, filename, opts...) if err != nil { return err @@ -350,7 +352,6 @@ func (b *GridFSBucket) openDownloadStream( opts ...options.Lister[options.FindOneOptions], ) (*GridFSDownloadStream, error) { ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) - defer cancel() result := b.filesColl.FindOne(ctx, filter, opts...) @@ -369,7 +370,7 @@ func (b *GridFSBucket) openDownloadStream( foundFile := newFileFromResponse(resp) if foundFile.Length == 0 { - return newGridFSDownloadStream(ctx, nil, foundFile.ChunkSize, foundFile), nil + return newGridFSDownloadStream(ctx, cancel, nil, foundFile.ChunkSize, foundFile), nil } // For a file with non-zero length, chunkSize must exist so we know what size to expect when downloading chunks. @@ -384,7 +385,7 @@ func (b *GridFSBucket) openDownloadStream( // The chunk size can be overridden for individual files, so the expected chunk size should be the "chunkSize" // field from the files collection document, not the bucket's chunk size. - return newGridFSDownloadStream(ctx, chunksCursor, foundFile.ChunkSize, foundFile), nil + return newGridFSDownloadStream(ctx, cancel, chunksCursor, foundFile.ChunkSize, foundFile), nil } func (b *GridFSBucket) downloadToStream(ds *GridFSDownloadStream, stream io.Writer) (int64, error) { diff --git a/mongo/gridfs_download_stream.go b/mongo/gridfs_download_stream.go index f33515fe57..1cc9bf65fd 100644 --- a/mongo/gridfs_download_stream.go +++ b/mongo/gridfs_download_stream.go @@ -39,6 +39,7 @@ type GridFSDownloadStream struct { expectedChunk int32 // index of next expected chunk fileLen int64 ctx context.Context + cancel context.CancelFunc // The pointer returned by GetFile. This should not be used in the actual GridFSDownloadStream code outside of the // newGridFSDownloadStream constructor because the values can be mutated by the user after calling GetFile. Instead, @@ -95,7 +96,13 @@ func newFileFromResponse(resp findFileResponse) *GridFSFile { } } -func newGridFSDownloadStream(ctx context.Context, cursor *Cursor, chunkSize int32, file *GridFSFile) *GridFSDownloadStream { +func newGridFSDownloadStream( + ctx context.Context, + cancel context.CancelFunc, + cursor *Cursor, + chunkSize int32, + file *GridFSFile, +) *GridFSDownloadStream { numChunks := int32(math.Ceil(float64(file.Length) / float64(chunkSize))) return &GridFSDownloadStream{ @@ -107,11 +114,18 @@ func newGridFSDownloadStream(ctx context.Context, cursor *Cursor, chunkSize int3 fileLen: file.Length, file: file, ctx: ctx, + cancel: cancel, } } // Close closes this download stream. func (ds *GridFSDownloadStream) Close() error { + defer func() { + if ds.cancel != nil { + ds.cancel() + } + }() + if ds.closed { return ErrStreamClosed } diff --git a/mongo/gridfs_upload_stream.go b/mongo/gridfs_upload_stream.go index 4d0cc5d304..c1f9277412 100644 --- a/mongo/gridfs_upload_stream.go +++ b/mongo/gridfs_upload_stream.go @@ -40,11 +40,13 @@ type GridFSUploadStream struct { bufferIndex int fileLen int64 ctx context.Context + cancel context.CancelFunc } // NewUploadStream creates a new upload stream. func newUploadStream( ctx context.Context, + cancel context.CancelFunc, up *upload, fileID interface{}, filename string, @@ -59,11 +61,18 @@ func newUploadStream( filesColl: files, buffer: make([]byte, uploadBufferSize), ctx: ctx, + cancel: cancel, } } // Close writes file metadata to the files collection and cleans up any resources associated with the UploadStream. func (us *GridFSUploadStream) Close() error { + defer func() { + if us.cancel != nil { + us.cancel() + } + }() + if us.closed { return ErrStreamClosed } @@ -111,6 +120,12 @@ func (us *GridFSUploadStream) Write(p []byte) (int, error) { // Abort closes the stream and deletes all file chunks that have already been written. func (us *GridFSUploadStream) Abort() error { + defer func() { + if us.cancel != nil { + us.cancel() + } + }() + if us.closed { return ErrStreamClosed }