Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions internal/integration/gridfs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ import (
"context"
"io"
"math/rand"
"os"
"runtime"
"sync"
"testing"
"time"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
"go.mongodb.org/mongo-driver/v2/internal/integtest"
"go.mongodb.org/mongo-driver/v2/internal/israce"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo"
Expand Down Expand Up @@ -529,6 +532,41 @@ func TestGridFS(x *testing.T) {
})
}

func TestOpenUploadStreamConcurrently(t *testing.T) {
t.Parallel()

uri, err := integtest.MongoDBURI()
require.NoError(t, err, "error getting URI: %v", err)
opts := options.Client().ApplyURI(uri)
if os.Getenv("REQUIRE_API_VERSION") == "true" {
opts.SetServerAPIOptions(options.ServerAPI(options.ServerAPIVersion1))
}
client, err := mongo.Connect(opts)
require.NoError(t, err, "Connect error: %v", err)
defer func() {
_ = client.Disconnect(context.Background())
}()

db := client.Database(mtest.TestDB)
bucket := db.GridFSBucket()
defer func() {
_ = bucket.Drop(context.Background())
}()

const size = 10_000

wg := sync.WaitGroup{}
wg.Add(size)
for i := 0; i < size; i++ {
go func() {
defer wg.Done()
_, err := bucket.OpenUploadStream(context.Background(), "foo")
assert.NoError(t, err, "OpenUploadStream error: %v", err)
}()
}
wg.Wait()
}

func assertGridFSCollectionState(mt *mtest.T, coll *mongo.Collection, expectedName string, expectedNumDocuments int64) {
mt.Helper()

Expand Down
16 changes: 7 additions & 9 deletions mongo/gridfs_bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"errors"
"fmt"
"io"
"sync/atomic"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/csot"
Expand All @@ -37,6 +38,8 @@ var ErrMissingGridFSChunkSize = errors.New("files collection document does not c

// GridFSBucket represents a GridFS bucket.
type GridFSBucket struct {
firstWriteDone uint32

db *Database
chunksColl *Collection // collection to store file chunks
filesColl *Collection // collection to store file metadata
Expand All @@ -47,9 +50,8 @@ type GridFSBucket struct {
rc *readconcern.ReadConcern
rp *readpref.ReadPref

firstWriteDone bool
readBuf []byte
writeBuf []byte
readBuf []byte
writeBuf []byte
}

// upload contains options to upload a file to a bucket.
Expand Down Expand Up @@ -531,14 +533,10 @@ func (b *GridFSBucket) createIndexes(ctx context.Context) error {
}

func (b *GridFSBucket) checkFirstWrite(ctx context.Context) error {
if !b.firstWriteDone {
if atomic.CompareAndSwapUint32(&b.firstWriteDone, 0, 1) {
// before the first write operation, must determine if files collection is empty
// if so, create indexes if they do not already exist

if err := b.createIndexes(ctx); err != nil {
return err
}
b.firstWriteDone = true
return b.createIndexes(ctx)
}

return nil
Expand Down
Loading