From 9c023ab1eda20b20e96bf18dea1600eb9770c10c Mon Sep 17 00:00:00 2001 From: ssd04 Date: Wed, 8 Mar 2023 14:42:44 +0200 Subject: [PATCH 01/35] create mongodb type in sharded storage factory --- core/constants.go | 3 ++ .../storage/factory/shardedStorageFactory.go | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/core/constants.go b/core/constants.go index 26b7d8ad..adca0d28 100644 --- a/core/constants.go +++ b/core/constants.go @@ -8,3 +8,6 @@ type DBType string // LevelDB is the local levelDB const LevelDB DBType = "levelDB" + +// MongoDB is the mongo db identifier +const MongoDB DBType = "mongoDB" diff --git a/handlers/storage/factory/shardedStorageFactory.go b/handlers/storage/factory/shardedStorageFactory.go index 048893dd..c5c733c1 100644 --- a/handlers/storage/factory/shardedStorageFactory.go +++ b/handlers/storage/factory/shardedStorageFactory.go @@ -6,7 +6,9 @@ import ( "github.com/multiversx/multi-factor-auth-go-service/config" "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/multi-factor-auth-go-service/handlers" + "github.com/multiversx/multi-factor-auth-go-service/handlers/storage" "github.com/multiversx/multi-factor-auth-go-service/handlers/storage/bucket" + "github.com/multiversx/multi-factor-auth-go-service/mongodb" "github.com/multiversx/mx-chain-storage-go/storageUnit" ) @@ -26,12 +28,44 @@ func (ssf *shardedStorageFactory) Create() (core.ShardedStorageWithIndex, error) switch ssf.cfg.ShardedStorage.DBType { case core.LevelDB: return ssf.createLocalDB() + case core.MongoDB: + return ssf.createLocalDB() default: // TODO: implement other types of storage return nil, handlers.ErrInvalidConfig } } +func (ssf *shardedStorageFactory) createMongoDB() (core.ShardedStorageWithIndex, error) { + client, err := mongodb.NewMongoDBClient(ssf.cfg.MongoDB) + if err != nil { + return nil, err + } + + storer, err := storage.NewMongoDBStorerHandler(client, mongodb.UsersCollection) + if err != nil { + return nil, err + } + + bucketIDProvider, err := bucket.NewBucketIDProvider(1) + if err != nil { + return nil, err + } + + bucketIndexHandlers := make(map[uint32]core.BucketIndexHandler, 1) + bucketIndexHandlers[0], err = bucket.NewBucketIndexHandler(storer) + if err != nil { + return nil, err + } + + argsShardedStorageWithIndex := bucket.ArgShardedStorageWithIndex{ + BucketIDProvider: bucketIDProvider, + BucketHandlers: bucketIndexHandlers, + } + + return bucket.NewShardedStorageWithIndex(argsShardedStorageWithIndex) +} + func (ssf *shardedStorageFactory) createLocalDB() (core.ShardedStorageWithIndex, error) { numbOfBuckets := ssf.cfg.Buckets.NumberOfBuckets bucketIDProvider, err := bucket.NewBucketIDProvider(numbOfBuckets) From ed7b3097deb8c34301858af932438e29ab9db4a6 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Thu, 9 Mar 2023 14:51:28 +0200 Subject: [PATCH 02/35] setup for mongodb cluster with nginx proxy --- Makefile | 2 + docker/mongodb-cluster-full.yml | 76 +++++++++++++++++++++++++++++++++ docker/nginx/nginx.conf | 56 ++++++++++++++++++++++++ 3 files changed, 134 insertions(+) create mode 100644 docker/mongodb-cluster-full.yml create mode 100644 docker/nginx/nginx.conf diff --git a/Makefile b/Makefile index 2239ed7e..25b87745 100644 --- a/Makefile +++ b/Makefile @@ -85,6 +85,8 @@ docker-rm: docker-stop compose_file = docker/mongodb-cluster.yml ifeq ($(db_setup),redis) compose_file = docker/redis-cluster.yml +else ifeq ($(db_setup),mongodb-full) + compose_file = docker/mongodb-cluster-full.yml endif compose-new: diff --git a/docker/mongodb-cluster-full.yml b/docker/mongodb-cluster-full.yml new file mode 100644 index 00000000..7b93ae03 --- /dev/null +++ b/docker/mongodb-cluster-full.yml @@ -0,0 +1,76 @@ +version: '3' + +services: + mongodb0: + container_name: mongodb0 + image: mongo + ports: + - 27017:27017 + networks: + - mongo + depends_on: + - mongodb1 + - mongodb2 + links: + - mongodb1 + - mongodb2 + restart: always + entrypoint: [ "/usr/bin/mongod", "--bind_ip_all", "--replSet", "mongoReplSet" ] + + mongoinit: + image: mongo + volumes: + - ./mongodb/init.sh:/scripts/init.sh + networks: + - mongo + depends_on: + - mongodb0 + links: + - mongodb0 + restart: "no" + entrypoint: [ "bash", "-c", "sleep 10 && /scripts/init.sh"] + + mongodb1: + container_name: mongodb1 + image: mongo + ports: + - 27018:27017 + networks: + - mongo + restart: always + entrypoint: [ "/usr/bin/mongod", "--bind_ip_all", "--replSet", "mongoReplSet" ] + + mongodb2: + container_name: mongodb2 + image: mongo + ports: + - 27019:27017 + networks: + - mongo + restart: always + entrypoint: [ "/usr/bin/mongod", "--bind_ip_all", "--replSet", "mongoReplSet" ] + + tcs: + build: + context: .. + dockerfile: ../Dockerfile + image: multi-factor-auth:latest + ports: + - 8080:8080 + networks: + - mongo + + nginx: + image: nginx:latest + volumes: + - ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro + depends_on: + - tcs + networks: + - mongo + ports: + - 5000:5000 + +networks: + mongo: + driver: bridge diff --git a/docker/nginx/nginx.conf b/docker/nginx/nginx.conf new file mode 100644 index 00000000..7b94d638 --- /dev/null +++ b/docker/nginx/nginx.conf @@ -0,0 +1,56 @@ +user www-data; +worker_processes auto; +worker_rlimit_nofile 32768; + +error_log /var/log/nginx/error.log; +pid /var/run/nginx.pid; +include /usr/share/nginx/modules/*.conf; +debug_points abort; + +events { + worker_connections 8192; +} + +http { + real_ip_header X-Forwarded-For; + set_real_ip_from 10.0.0.0/8; + proxy_cache_path /var/cache/nginx/cache keys_zone=elasticsearch:10m inactive=60m; + + log_format main '$remote_addr - $remote_user [$time_local] "$request" ' + '$status $body_bytes_sent "$http_referer" ' + '"$http_user_agent" "$http_x_forwarded_for"'; + + access_log /var/log/nginx/access.log main; + error_log /var/log/nginx/error.log; + + sendfile on; + tcp_nopush on; + tcp_nodelay on; + keepalive_timeout 65; + types_hash_max_size 2048; + include /etc/nginx/mime.types; + default_type application/octet-stream; + + upstream tcs { + server tcs:8080; + keepalive 16; + } + + server { + listen 5000; + server_name tcs_proxy; + + location / { + proxy_http_version 1.1; + + proxy_connect_timeout 5s; + proxy_read_timeout 10s; + + client_max_body_size 100M; + + proxy_pass http://tcs; +# timeout for an idle keepalive connection to an upstream server will stay open + keepalive_timeout 50s; + } + } +} From a22bd9a5cdc2ab440f8282ebe4c987088320c69f Mon Sep 17 00:00:00 2001 From: ssd04 Date: Thu, 9 Mar 2023 15:51:16 +0200 Subject: [PATCH 03/35] user create mongo db client from factory --- handlers/storage/factory/shardedStorageFactory.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/handlers/storage/factory/shardedStorageFactory.go b/handlers/storage/factory/shardedStorageFactory.go index c5c733c1..89d70eff 100644 --- a/handlers/storage/factory/shardedStorageFactory.go +++ b/handlers/storage/factory/shardedStorageFactory.go @@ -37,12 +37,12 @@ func (ssf *shardedStorageFactory) Create() (core.ShardedStorageWithIndex, error) } func (ssf *shardedStorageFactory) createMongoDB() (core.ShardedStorageWithIndex, error) { - client, err := mongodb.NewMongoDBClient(ssf.cfg.MongoDB) + client, err := mongodb.CreateMongoDBClient(ssf.cfg.MongoDB) if err != nil { return nil, err } - storer, err := storage.NewMongoDBStorerHandler(client, mongodb.UsersCollection) + storer, err := storage.NewMongoDBStorerHandler(client, mongodb.UsersCollectionID) if err != nil { return nil, err } From b7cd62e64f0465c8b6e59372e66d5e81cffee683 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Thu, 9 Mar 2023 23:45:07 +0200 Subject: [PATCH 04/35] timeout options for mongo client --- handlers/storage/factory/shardedStorageFactory.go | 2 +- mongodb/mongoDBClientFactory.go | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/handlers/storage/factory/shardedStorageFactory.go b/handlers/storage/factory/shardedStorageFactory.go index 89d70eff..be6857a0 100644 --- a/handlers/storage/factory/shardedStorageFactory.go +++ b/handlers/storage/factory/shardedStorageFactory.go @@ -29,7 +29,7 @@ func (ssf *shardedStorageFactory) Create() (core.ShardedStorageWithIndex, error) case core.LevelDB: return ssf.createLocalDB() case core.MongoDB: - return ssf.createLocalDB() + return ssf.createMongoDB() default: // TODO: implement other types of storage return nil, handlers.ErrInvalidConfig diff --git a/mongodb/mongoDBClientFactory.go b/mongodb/mongoDBClientFactory.go index ff680ffe..196a959b 100644 --- a/mongodb/mongoDBClientFactory.go +++ b/mongodb/mongoDBClientFactory.go @@ -1,14 +1,26 @@ package mongodb import ( + "time" + "github.com/multiversx/multi-factor-auth-go-service/config" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) +const ( + connectTimeoutSec = 60 + operationTimeoutSec = 60 +) + // CreateMongoDBClient will create a new mongo db client instance func CreateMongoDBClient(cfg config.MongoDBConfig) (MongoDBClient, error) { - client, err := mongo.NewClient(options.Client().ApplyURI(cfg.URI)) + opts := options.Client() + opts.SetConnectTimeout(connectTimeoutSec * time.Second) + opts.SetTimeout(operationTimeoutSec * time.Second) + opts.ApplyURI(cfg.URI) + + client, err := mongo.NewClient(opts) if err != nil { return nil, err } From b6d60e8cabc51f7dfb1f89a61b9aedc48a9f04d1 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Fri, 10 Mar 2023 10:02:22 +0200 Subject: [PATCH 05/35] added separate index handler for mongo --- handlers/storage/bucket/bucketIndexHandler.go | 16 ++-- .../storage/bucket/mongodbIndexHandler.go | 90 +++++++++++++++++++ .../storage/factory/shardedStorageFactory.go | 2 +- mongodb/dbClient.go | 66 ++++++++++++++ mongodb/interface.go | 2 + mongodb/mongodbClientWrapper.go | 5 ++ 6 files changed, 172 insertions(+), 9 deletions(-) create mode 100644 handlers/storage/bucket/mongodbIndexHandler.go diff --git a/handlers/storage/bucket/bucketIndexHandler.go b/handlers/storage/bucket/bucketIndexHandler.go index ac7fc822..e7e48789 100644 --- a/handlers/storage/bucket/bucketIndexHandler.go +++ b/handlers/storage/bucket/bucketIndexHandler.go @@ -33,7 +33,7 @@ func NewBucketIndexHandler(bucket core.Storer) (*bucketIndexHandler, error) { return handler, nil } - err = handler.saveNewIndex(0) + err = saveNewIndex(handler.bucket, 0) if err != nil { return nil, err } @@ -46,14 +46,14 @@ func (handler *bucketIndexHandler) AllocateBucketIndex() (uint32, error) { handler.mut.Lock() defer handler.mut.Unlock() - index, err := handler.getIndex() + index, err := getIndex(handler.bucket) if err != nil { return 0, err } index++ - return index, handler.saveNewIndex(index) + return index, saveNewIndex(handler.bucket, index) } // Put adds data to the bucket @@ -76,7 +76,7 @@ func (handler *bucketIndexHandler) GetLastIndex() (uint32, error) { handler.mut.RLock() defer handler.mut.RUnlock() - return handler.getIndex() + return getIndex(handler.bucket) } // Close closes the internal bucket @@ -88,8 +88,8 @@ func (handler *bucketIndexHandler) Close() error { } // must be called under mutex protection -func (handler *bucketIndexHandler) getIndex() (uint32, error) { - lastIndexBytes, err := handler.bucket.Get([]byte(lastIndexKey)) +func getIndex(storer core.Storer) (uint32, error) { + lastIndexBytes, err := storer.Get([]byte(lastIndexKey)) if err != nil { return 0, err } @@ -98,10 +98,10 @@ func (handler *bucketIndexHandler) getIndex() (uint32, error) { } // must be called under mutex protection -func (handler *bucketIndexHandler) saveNewIndex(newIndex uint32) error { +func saveNewIndex(storer core.Storer, newIndex uint32) error { latestIndexBytes := make([]byte, uint32Bytes) binary.BigEndian.PutUint32(latestIndexBytes, newIndex) - return handler.bucket.Put([]byte(lastIndexKey), latestIndexBytes) + return storer.Put([]byte(lastIndexKey), latestIndexBytes) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/handlers/storage/bucket/mongodbIndexHandler.go b/handlers/storage/bucket/mongodbIndexHandler.go new file mode 100644 index 00000000..6ec0d052 --- /dev/null +++ b/handlers/storage/bucket/mongodbIndexHandler.go @@ -0,0 +1,90 @@ +package bucket + +import ( + "sync" + + "github.com/multiversx/multi-factor-auth-go-service/core" + "github.com/multiversx/multi-factor-auth-go-service/mongodb" + "github.com/multiversx/mx-chain-core-go/core/check" +) + +const initialIndexValue = 1 + +type mongodbIndexHandler struct { + storer core.Storer + mongodbClient mongodb.MongoDBClient + mut sync.RWMutex +} + +// NewMongoDBIndexHandler returns a new instance of a bucket index handler +func NewMongoDBIndexHandler(storer core.Storer, mongoClient mongodb.MongoDBClient) (*mongodbIndexHandler, error) { + if check.IfNil(storer) { + return nil, core.ErrNilBucket + } + + handler := &mongodbIndexHandler{ + storer: storer, + mongodbClient: mongoClient, + } + + err := storer.Has([]byte(lastIndexKey)) + if err == nil { + return handler, nil + } + + err = saveNewIndex(handler.storer, initialIndexValue) + if err != nil { + return nil, err + } + + return handler, nil +} + +// AllocateBucketIndex allocates a new index and returns it +func (handler *mongodbIndexHandler) AllocateBucketIndex() (uint32, error) { + handler.mut.Lock() + defer handler.mut.Unlock() + + newIndex, err := handler.mongodbClient.IncrementWithTransaction(mongodb.UsersCollectionID, []byte(lastIndexKey)) + if err != nil { + return 0, err + } + + return newIndex, nil +} + +// Put adds data to the bucket +func (handler *mongodbIndexHandler) Put(key, data []byte) error { + return handler.storer.Put(key, data) +} + +// Get returns the value for the key from the bucket +func (handler *mongodbIndexHandler) Get(key []byte) ([]byte, error) { + return handler.storer.Get(key) +} + +// Has returns true if the key exists in the bucket +func (handler *mongodbIndexHandler) Has(key []byte) error { + return handler.storer.Has(key) +} + +// GetLastIndex returns the last index that was allocated +func (handler *mongodbIndexHandler) GetLastIndex() (uint32, error) { + handler.mut.RLock() + defer handler.mut.RUnlock() + + return getIndex(handler.storer) +} + +// Close closes the internal bucket +func (handler *mongodbIndexHandler) Close() error { + handler.mut.Lock() + defer handler.mut.Unlock() + + return handler.storer.Close() +} + +// IsInterfaceNil returns true if there is no value under the interface +func (handler *mongodbIndexHandler) IsInterfaceNil() bool { + return handler == nil +} diff --git a/handlers/storage/factory/shardedStorageFactory.go b/handlers/storage/factory/shardedStorageFactory.go index be6857a0..3289630f 100644 --- a/handlers/storage/factory/shardedStorageFactory.go +++ b/handlers/storage/factory/shardedStorageFactory.go @@ -53,7 +53,7 @@ func (ssf *shardedStorageFactory) createMongoDB() (core.ShardedStorageWithIndex, } bucketIndexHandlers := make(map[uint32]core.BucketIndexHandler, 1) - bucketIndexHandlers[0], err = bucket.NewBucketIndexHandler(storer) + bucketIndexHandlers[0], err = bucket.NewMongoDBIndexHandler(storer, client) if err != nil { return nil, err } diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index ae1435f4..184254a8 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -2,13 +2,18 @@ package mongodb import ( "context" + "encoding/binary" "errors" "github.com/multiversx/mx-chain-core-go/core/check" + logger "github.com/multiversx/mx-chain-logger-go" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) +var log = logger.GetOrCreate("mongodb") + // CollectionID defines mongodb collection type type CollectionID string @@ -74,6 +79,8 @@ func (mdc *mongodbClient) Put(collID CollectionID, key []byte, data []byte) erro opts := options.Update().SetUpsert(true) + log.Debug("Put", "key", string(key), "value", string(data)) + _, err := coll.UpdateOne(mdc.ctx, filter, update, opts) if err != nil { return err @@ -106,12 +113,15 @@ func (mdc *mongodbClient) Get(collID CollectionID, key []byte) ([]byte, error) { return nil, err } + log.Debug("Get", "key", string(key)) + return entry.Value, nil } // Has will return true if the provided key exists in the collection func (mdc *mongodbClient) Has(collID CollectionID, key []byte) error { _, err := mdc.findOne(collID, key) + log.Debug("Has", "key", string(key)) return err } @@ -132,6 +142,62 @@ func (mdc *mongodbClient) Remove(collID CollectionID, key []byte) error { return nil } +// IncrementWithTransaction will increment the value for the provided key, within a transaction +func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []byte) (uint32, error) { + coll, ok := mdc.collections[collID] + if !ok { + return 0, ErrCollectionNotFound + } + + callback := func(sessCtx mongo.SessionContext) (interface{}, error) { + filter := bson.D{{Key: "_id", Value: string(key)}} + + entry := &mongoEntry{} + err := coll.FindOne(sessCtx, filter).Decode(entry) + if err != nil { + return nil, err + } + + uint32Value := binary.BigEndian.Uint32(entry.Value) + index := uint32Value + 1 + + latestIndexBytes := make([]byte, 4) + binary.BigEndian.PutUint32(latestIndexBytes, index) + + filter = bson.D{{Key: "_id", Value: string(key)}} + update := bson.D{{Key: "$set", + Value: bson.D{ + {Key: "_id", Value: string(key)}, + {Key: "value", Value: latestIndexBytes}, + }, + }} + + opts := options.Update().SetUpsert(true) + + _, err = coll.UpdateOne(sessCtx, filter, update, opts) + if err != nil { + return nil, err + } + + return index, nil + } + + // Step 2: Start a session and run the callback using WithTransaction. + session, err := mdc.client.StartSession() + if err != nil { + return 0, err + } + defer session.EndSession(mdc.ctx) + + newIndex, err := session.WithTransaction(mdc.ctx, callback) + if err != nil { + return 0, err + } + index := newIndex.(uint32) + + return index, nil +} + // Close will close the mongodb client func (mdc *mongodbClient) Close() error { return mdc.client.Disconnect(mdc.ctx) diff --git a/mongodb/interface.go b/mongodb/interface.go index d554c58a..c93c987c 100644 --- a/mongodb/interface.go +++ b/mongodb/interface.go @@ -12,6 +12,7 @@ type MongoDBClientWrapper interface { Connect(ctx context.Context) error Disconnect(ctx context.Context) error DBCollection(dbName string, collName string) MongoDBCollection + StartSession() (mongo.Session, error) IsInterfaceNil() bool } @@ -28,6 +29,7 @@ type MongoDBClient interface { Get(coll CollectionID, key []byte) ([]byte, error) Has(coll CollectionID, key []byte) error Remove(coll CollectionID, key []byte) error + IncrementWithTransaction(collID CollectionID, key []byte) (uint32, error) Close() error IsInterfaceNil() bool } diff --git a/mongodb/mongodbClientWrapper.go b/mongodb/mongodbClientWrapper.go index ab578186..02ad30a3 100644 --- a/mongodb/mongodbClientWrapper.go +++ b/mongodb/mongodbClientWrapper.go @@ -31,6 +31,11 @@ func (m *mongoDBClientWrapper) DBCollection(dbName string, coll string) MongoDBC return m.client.Database(dbName).Collection(coll) } +// DBCollection will return the specified collection object +func (m *mongoDBClientWrapper) StartSession() (mongo.Session, error) { + return m.client.StartSession() +} + // IsInterfaceNil returns true if there is no value under the interface func (m *mongoDBClientWrapper) IsInterfaceNil() bool { return m == nil From ba6f7a713ffa46fc6090a9d101df3607ed3057a1 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Fri, 10 Mar 2023 14:11:18 +0200 Subject: [PATCH 06/35] unit tests for mongo db index handler --- .../storage/bucket/mongodbIndexHandler.go | 11 +- .../bucket/mongodbIndexHandler_test.go | 132 ++++++++++++++++++ mongodb/dbClient.go | 24 +++- mongodb/dbClient_test.go | 49 ++++++- mongodb/interface.go | 9 +- mongodb/mongodbClientWrapper.go | 2 +- testscommon/mongoDBClientStub.go | 32 +++++ testscommon/mongoDBClientWrapperStub.go | 31 ++-- 8 files changed, 260 insertions(+), 30 deletions(-) create mode 100644 handlers/storage/bucket/mongodbIndexHandler_test.go diff --git a/handlers/storage/bucket/mongodbIndexHandler.go b/handlers/storage/bucket/mongodbIndexHandler.go index 6ec0d052..d66d4ff5 100644 --- a/handlers/storage/bucket/mongodbIndexHandler.go +++ b/handlers/storage/bucket/mongodbIndexHandler.go @@ -19,7 +19,10 @@ type mongodbIndexHandler struct { // NewMongoDBIndexHandler returns a new instance of a bucket index handler func NewMongoDBIndexHandler(storer core.Storer, mongoClient mongodb.MongoDBClient) (*mongodbIndexHandler, error) { if check.IfNil(storer) { - return nil, core.ErrNilBucket + return nil, core.ErrNilStorer + } + if check.IfNil(mongoClient) { + return nil, core.ErrNilMongoDBClient } handler := &mongodbIndexHandler{ @@ -53,17 +56,17 @@ func (handler *mongodbIndexHandler) AllocateBucketIndex() (uint32, error) { return newIndex, nil } -// Put adds data to the bucket +// Put adds data to storer func (handler *mongodbIndexHandler) Put(key, data []byte) error { return handler.storer.Put(key, data) } -// Get returns the value for the key from the bucket +// Get returns the value for the key from storer func (handler *mongodbIndexHandler) Get(key []byte) ([]byte, error) { return handler.storer.Get(key) } -// Has returns true if the key exists in the bucket +// Has returns true if the key exists in storer func (handler *mongodbIndexHandler) Has(key []byte) error { return handler.storer.Has(key) } diff --git a/handlers/storage/bucket/mongodbIndexHandler_test.go b/handlers/storage/bucket/mongodbIndexHandler_test.go new file mode 100644 index 00000000..b715eff1 --- /dev/null +++ b/handlers/storage/bucket/mongodbIndexHandler_test.go @@ -0,0 +1,132 @@ +package bucket + +import ( + "encoding/binary" + "sync" + "testing" + + "github.com/multiversx/multi-factor-auth-go-service/core" + "github.com/multiversx/multi-factor-auth-go-service/mongodb" + "github.com/multiversx/multi-factor-auth-go-service/testscommon" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" +) + +func TestNewMongoDBIndexHandler(t *testing.T) { + t.Parallel() + + t.Run("nil storer should error", func(t *testing.T) { + t.Parallel() + + handler, err := NewMongoDBIndexHandler(nil, &testscommon.MongoDBClientStub{}) + assert.Equal(t, core.ErrNilStorer, err) + assert.True(t, check.IfNil(handler)) + }) + + t.Run("nil mongo clinet should error", func(t *testing.T) { + t.Parallel() + + handler, err := NewMongoDBIndexHandler(&testscommon.StorerStub{}, nil) + assert.Equal(t, core.ErrNilMongoDBClient, err) + assert.True(t, check.IfNil(handler)) + }) + + t.Run("should work, bucket has lastIndexKey", func(t *testing.T) { + t.Parallel() + + handler, err := NewMongoDBIndexHandler(&testscommon.StorerStub{ + HasCalled: func(key []byte) error { + assert.Equal(t, []byte(lastIndexKey), key) + return nil + }, + }, &testscommon.MongoDBClientStub{}) + assert.Nil(t, err) + assert.False(t, check.IfNil(handler)) + }) + + t.Run("should work, empty bucket", func(t *testing.T) { + t.Parallel() + + handler, err := NewMongoDBIndexHandler(&testscommon.StorerStub{ + HasCalled: func(key []byte) error { + assert.Equal(t, []byte(lastIndexKey), key) + return expectedErr + }, + PutCalled: func(key, data []byte) error { + assert.Equal(t, []byte(lastIndexKey), key) + return nil + }, + }, &testscommon.MongoDBClientStub{}) + assert.Nil(t, err) + assert.False(t, check.IfNil(handler)) + }) + + t.Run("empty bucket and put lastIndexKey fails", func(t *testing.T) { + t.Parallel() + + handler, err := NewMongoDBIndexHandler(&testscommon.StorerStub{ + HasCalled: func(key []byte) error { + assert.Equal(t, []byte(lastIndexKey), key) + return expectedErr + }, + PutCalled: func(key, data []byte) error { + assert.Equal(t, []byte(lastIndexKey), key) + return expectedErr + }, + }, &testscommon.MongoDBClientStub{}) + assert.Equal(t, expectedErr, err) + assert.True(t, check.IfNil(handler)) + }) +} + +func TestMongoDBIndexHandler_Operations(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r != nil { + assert.Fail(t, "should have not panicked") + } + }() + + handler, _ := NewMongoDBIndexHandler(&testscommon.StorerStub{ + GetCalled: func(key []byte) ([]byte, error) { + index := make([]byte, uint32Bytes) + binary.BigEndian.PutUint32(index, 0) + return index, nil + }, + }, &testscommon.MongoDBClientStub{ + IncrementWithTransactionCalled: func(coll mongodb.CollectionID, key []byte) (uint32, error) { + return 1, nil + }, + }) + + numCalls := 10000 + var wg sync.WaitGroup + wg.Add(numCalls) + for i := 0; i < numCalls; i++ { + go func(idx int) { + switch idx % 6 { + case 0: + _, err := handler.AllocateBucketIndex() + assert.Nil(t, err) + case 1: + assert.Nil(t, handler.Put([]byte("key"), []byte("data"))) + case 2: + _, err := handler.Get([]byte("key")) + assert.Nil(t, err) + case 3: + assert.Nil(t, handler.Has([]byte("key"))) + case 4: + assert.Nil(t, handler.Close()) + case 5: + _, err := handler.GetLastIndex() + assert.Nil(t, err) + default: + assert.Fail(t, "should not hit default") + } + wg.Done() + }(i) + } + wg.Wait() +} diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 184254a8..66959dbc 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" + "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/mx-chain-core-go/core/check" logger "github.com/multiversx/mx-chain-logger-go" "go.mongodb.org/mongo-driver/bson" @@ -158,11 +159,7 @@ func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []by return nil, err } - uint32Value := binary.BigEndian.Uint32(entry.Value) - index := uint32Value + 1 - - latestIndexBytes := make([]byte, 4) - binary.BigEndian.PutUint32(latestIndexBytes, index) + latestIndexBytes, newIndex := incrementIntegerFromBytes(entry.Value) filter = bson.D{{Key: "_id", Value: string(key)}} update := bson.D{{Key: "$set", @@ -179,7 +176,7 @@ func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []by return nil, err } - return index, nil + return newIndex, nil } // Step 2: Start a session and run the callback using WithTransaction. @@ -193,11 +190,24 @@ func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []by if err != nil { return 0, err } - index := newIndex.(uint32) + index, ok := newIndex.(uint32) + if !ok { + return 0, core.ErrInvalidValue + } return index, nil } +func incrementIntegerFromBytes(value []byte) ([]byte, uint32) { + uint32Value := binary.BigEndian.Uint32(value) + newIndex := uint32Value + 1 + + newIndexBytes := make([]byte, 4) + binary.BigEndian.PutUint32(newIndexBytes, newIndex) + + return newIndexBytes, newIndex +} + // Close will close the mongodb client func (mdc *mongodbClient) Close() error { return mdc.client.Disconnect(mdc.ctx) diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index 867c6ede..305335e9 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -13,7 +13,10 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -type testStruct struct{} +type testStruct struct { + Key string `bson:"_id"` + Value []byte `bson:"value"` +} func TestNewMongoDBClient(t *testing.T) { t.Parallel() @@ -223,6 +226,50 @@ func TestMongoDBClient_Remove(t *testing.T) { }) } +func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { + t.Parallel() + + updateWasCalled := false + findWasCalled := false + sessionWasCalled := false + + mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ + DBCollectionCalled: func(dbName, collName string) mongodb.MongoDBCollection { + require.Equal(t, string(mongodb.UsersCollectionID), collName) + + return &testscommon.MongoDBCollectionStub{ + FindOneCalled: func(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult { + findWasCalled = true + return mongo.NewSingleResultFromDocument(&testStruct{Key: "key", Value: []byte{0, 0, 0, 1}}, nil, bson.DefaultRegistry) + }, + UpdateOneCalled: func(ctx context.Context, filter, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + updateWasCalled = true + return nil, nil + }, + } + }, + StartSessionCalled: func() (mongodb.MongoDBSession, error) { + sessionWasCalled = true + return &testscommon.MongoDBSessionStub{ + WithTransactionCalled: func(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { + fn(mongo.NewSessionContext(context.TODO(), mongo.SessionFromContext(context.TODO()))) + return uint32(5), nil + }, + }, nil + }, + } + + client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") + require.Nil(t, err) + + _, err = client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) + require.Nil(t, err) + + require.True(t, findWasCalled) + require.True(t, updateWasCalled) + require.True(t, sessionWasCalled) +} + func TestMongoDBClient_Close(t *testing.T) { t.Parallel() diff --git a/mongodb/interface.go b/mongodb/interface.go index c93c987c..b96f56ec 100644 --- a/mongodb/interface.go +++ b/mongodb/interface.go @@ -12,7 +12,7 @@ type MongoDBClientWrapper interface { Connect(ctx context.Context) error Disconnect(ctx context.Context) error DBCollection(dbName string, collName string) MongoDBCollection - StartSession() (mongo.Session, error) + StartSession() (MongoDBSession, error) IsInterfaceNil() bool } @@ -33,3 +33,10 @@ type MongoDBClient interface { Close() error IsInterfaceNil() bool } + +// MongoDBSession defines what a mongodb session should do +type MongoDBSession interface { + WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), + opts ...*options.TransactionOptions) (interface{}, error) + EndSession(context.Context) +} diff --git a/mongodb/mongodbClientWrapper.go b/mongodb/mongodbClientWrapper.go index 02ad30a3..df98b9f9 100644 --- a/mongodb/mongodbClientWrapper.go +++ b/mongodb/mongodbClientWrapper.go @@ -32,7 +32,7 @@ func (m *mongoDBClientWrapper) DBCollection(dbName string, coll string) MongoDBC } // DBCollection will return the specified collection object -func (m *mongoDBClientWrapper) StartSession() (mongo.Session, error) { +func (m *mongoDBClientWrapper) StartSession() (MongoDBSession, error) { return m.client.StartSession() } diff --git a/testscommon/mongoDBClientStub.go b/testscommon/mongoDBClientStub.go index b13deca5..3e5a7602 100644 --- a/testscommon/mongoDBClientStub.go +++ b/testscommon/mongoDBClientStub.go @@ -13,6 +13,7 @@ type MongoDBClientWrapperStub struct { DBCollectionCalled func(dbName string, collName string) mongodb.MongoDBCollection ConnectCalled func(ctx context.Context) error DisconnectCalled func(ctx context.Context) error + StartSessionCalled func() (mongodb.MongoDBSession, error) } // DBCollection - @@ -42,6 +43,15 @@ func (m *MongoDBClientWrapperStub) Disconnect(ctx context.Context) error { return nil } +// StartSession - +func (m *MongoDBClientWrapperStub) StartSession() (mongodb.MongoDBSession, error) { + if m.StartSessionCalled != nil { + return m.StartSessionCalled() + } + + return nil, nil +} + // IsInterfaceNil - func (m *MongoDBClientWrapperStub) IsInterfaceNil() bool { return m == nil @@ -81,3 +91,25 @@ func (m *MongoDBCollectionStub) DeleteOne(ctx context.Context, filter interface{ return nil, nil } + +// MongoDBSessionStub - +type MongoDBSessionStub struct { + WithTransactionCalled func(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) + EndSessionCalled func(_ context.Context) +} + +// WithTransaction - +func (m *MongoDBSessionStub) WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { + if m.WithTransactionCalled != nil { + return m.WithTransactionCalled(ctx, fn) + } + + return nil, nil +} + +// EndSession - +func (m *MongoDBSessionStub) EndSession(ctx context.Context) { + if m.EndSessionCalled != nil { + m.EndSessionCalled(ctx) + } +} diff --git a/testscommon/mongoDBClientWrapperStub.go b/testscommon/mongoDBClientWrapperStub.go index 2f5e04f9..e78b9c94 100644 --- a/testscommon/mongoDBClientWrapperStub.go +++ b/testscommon/mongoDBClientWrapperStub.go @@ -2,26 +2,16 @@ package testscommon import ( "github.com/multiversx/multi-factor-auth-go-service/mongodb" - "go.mongodb.org/mongo-driver/mongo" ) // MongoDBClientStub implemented mongodb client wraper interface type MongoDBClientStub struct { - GetCollectionCalled func(coll mongodb.CollectionID) *mongo.Collection - PutCalled func(coll mongodb.CollectionID, key []byte, data []byte) error - GetCalled func(coll mongodb.CollectionID, key []byte) ([]byte, error) - HasCalled func(coll mongodb.CollectionID, key []byte) error - RemoveCalled func(coll mongodb.CollectionID, key []byte) error - CloseCalled func() error -} - -// GetCollection - -func (m *MongoDBClientStub) GetCollection(coll mongodb.CollectionID) *mongo.Collection { - if m.GetCollectionCalled != nil { - return m.GetCollectionCalled(coll) - } - - return nil + PutCalled func(coll mongodb.CollectionID, key []byte, data []byte) error + GetCalled func(coll mongodb.CollectionID, key []byte) ([]byte, error) + HasCalled func(coll mongodb.CollectionID, key []byte) error + RemoveCalled func(coll mongodb.CollectionID, key []byte) error + IncrementWithTransactionCalled func(coll mongodb.CollectionID, key []byte) (uint32, error) + CloseCalled func() error } // Put - @@ -60,6 +50,15 @@ func (m *MongoDBClientStub) Remove(coll mongodb.CollectionID, key []byte) error return nil } +// IncrementWithTransaction - +func (m *MongoDBClientStub) IncrementWithTransaction(coll mongodb.CollectionID, key []byte) (uint32, error) { + if m.IncrementWithTransactionCalled != nil { + return m.IncrementWithTransactionCalled(coll, key) + } + + return 0, nil +} + // Close - func (m *MongoDBClientStub) Close() error { if m.CloseCalled != nil { From 07ab18fa63a064fb7d67fe1210919d550448a1d2 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Fri, 10 Mar 2023 15:02:26 +0200 Subject: [PATCH 07/35] small setup changes --- Makefile | 1 + core/errors.go | 3 +++ docker/mongodb-cluster-full.yml | 2 +- docker/nginx/nginx.conf | 3 ++- 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 28ca8c8b..27419ee4 100644 --- a/Makefile +++ b/Makefile @@ -63,6 +63,7 @@ docker-run: -p 8080:8080 \ --name ${container_name} \ ${image}:${image_tag} + docker rm ${container_name} docker-new: docker-build docker-run diff --git a/core/errors.go b/core/errors.go index 6f22426f..f92fcc83 100644 --- a/core/errors.go +++ b/core/errors.go @@ -28,3 +28,6 @@ var ErrNilBucket = errors.New("nil bucket") // ErrNilMongoDBClient signals that a nil mongodb client has been provided var ErrNilMongoDBClient = errors.New("nil mongodb client") + +// ErrNilStorer is raised when a nil storer has been provided +var ErrNilStorer = errors.New("nil storer") diff --git a/docker/mongodb-cluster-full.yml b/docker/mongodb-cluster-full.yml index 7b93ae03..5b4091b0 100644 --- a/docker/mongodb-cluster-full.yml +++ b/docker/mongodb-cluster-full.yml @@ -53,7 +53,7 @@ services: tcs: build: context: .. - dockerfile: ../Dockerfile + dockerfile: Dockerfile image: multi-factor-auth:latest ports: - 8080:8080 diff --git a/docker/nginx/nginx.conf b/docker/nginx/nginx.conf index 7b94d638..011d289a 100644 --- a/docker/nginx/nginx.conf +++ b/docker/nginx/nginx.conf @@ -49,7 +49,8 @@ http { client_max_body_size 100M; proxy_pass http://tcs; -# timeout for an idle keepalive connection to an upstream server will stay open + + # timeout for an idle keepalive connection to an upstream server will stay open keepalive_timeout 50s; } } From 74feef99bf3d0efc547222bcbf27e1a4af581f11 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Fri, 10 Mar 2023 22:11:07 +0200 Subject: [PATCH 08/35] add second tcs instance in docker compose mongo full setup --- docker/mongodb-cluster-full.yml | 15 +++++++++++++-- docker/nginx/nginx.conf | 3 ++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/docker/mongodb-cluster-full.yml b/docker/mongodb-cluster-full.yml index 5b4091b0..c3da8a11 100644 --- a/docker/mongodb-cluster-full.yml +++ b/docker/mongodb-cluster-full.yml @@ -50,7 +50,7 @@ services: restart: always entrypoint: [ "/usr/bin/mongod", "--bind_ip_all", "--replSet", "mongoReplSet" ] - tcs: + tcs0: build: context: .. dockerfile: Dockerfile @@ -60,12 +60,23 @@ services: networks: - mongo + tcs1: + build: + context: .. + dockerfile: Dockerfile + image: multi-factor-auth:latest + ports: + - 8081:8080 + networks: + - mongo + nginx: image: nginx:latest volumes: - ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro depends_on: - - tcs + - tcs0 + - tcs1 networks: - mongo ports: diff --git a/docker/nginx/nginx.conf b/docker/nginx/nginx.conf index 011d289a..c0ffecf0 100644 --- a/docker/nginx/nginx.conf +++ b/docker/nginx/nginx.conf @@ -32,7 +32,8 @@ http { default_type application/octet-stream; upstream tcs { - server tcs:8080; + server tcs0:8080; + server tcs1:8080; keepalive 16; } From 4b886710ec349f9843dcde86c3c32cab991f1cfb Mon Sep 17 00:00:00 2001 From: ssd04 Date: Mon, 13 Mar 2023 15:20:04 +0200 Subject: [PATCH 09/35] rename sharded storage with index interface --- core/interface.go | 4 ++-- handlers/interface.go | 2 +- handlers/storage/dbOTPHandler.go | 12 ++++++------ handlers/storage/factory/shardedStorageFactory.go | 6 +++--- resolver/serviceResolver.go | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/core/interface.go b/core/interface.go index 82f6f972..231210c8 100644 --- a/core/interface.go +++ b/core/interface.go @@ -81,8 +81,8 @@ type BucketIndexHandler interface { IsInterfaceNil() bool } -// ShardedStorageWithIndex defines the methods for a component that holds multiple BucketIndexHandler -type ShardedStorageWithIndex interface { +// StorageWithIndex defines the methods for a component that holds multiple BucketIndexHandler +type StorageWithIndex interface { AllocateIndex(address []byte) (uint32, error) Put(key, data []byte) error Get(key []byte) ([]byte, error) diff --git a/handlers/interface.go b/handlers/interface.go index 3040f0cc..faf84c16 100644 --- a/handlers/interface.go +++ b/handlers/interface.go @@ -30,6 +30,6 @@ type OTP interface { // ShardedStorageFactory defines the methods available for a sharded storage factory type ShardedStorageFactory interface { - Create() (core.ShardedStorageWithIndex, error) + Create() (core.StorageWithIndex, error) IsInterfaceNil() bool } diff --git a/handlers/storage/dbOTPHandler.go b/handlers/storage/dbOTPHandler.go index 58cd4d83..84800584 100644 --- a/handlers/storage/dbOTPHandler.go +++ b/handlers/storage/dbOTPHandler.go @@ -17,15 +17,15 @@ const ( // ArgDBOTPHandler is the DTO used to create a new instance of dbOTPHandler type ArgDBOTPHandler struct { - DB core.ShardedStorageWithIndex - TOTPHandler handlers.TOTPHandler + DB core.StorageWithIndex + TOTPHandler handlers.TOTPHandler Marshaller core.Marshaller DelayBetweenOTPUpdatesInSec int64 } type dbOTPHandler struct { - db core.ShardedStorageWithIndex - totpHandler handlers.TOTPHandler + db core.StorageWithIndex + totpHandler handlers.TOTPHandler marshaller core.Marshaller getTimeHandler func() time.Time delayBetweenOTPUpdatesInSec int64 @@ -40,8 +40,8 @@ func NewDBOTPHandler(args ArgDBOTPHandler) (*dbOTPHandler, error) { } handler := &dbOTPHandler{ - db: args.DB, - totpHandler: args.TOTPHandler, + db: args.DB, + totpHandler: args.TOTPHandler, getTimeHandler: time.Now, marshaller: args.Marshaller, delayBetweenOTPUpdatesInSec: args.DelayBetweenOTPUpdatesInSec, diff --git a/handlers/storage/factory/shardedStorageFactory.go b/handlers/storage/factory/shardedStorageFactory.go index 3289630f..d3670eae 100644 --- a/handlers/storage/factory/shardedStorageFactory.go +++ b/handlers/storage/factory/shardedStorageFactory.go @@ -24,7 +24,7 @@ func NewShardedStorageFactory(config config.Config) *shardedStorageFactory { } // Create returns a new instance of ShardedStorageWithIndex -func (ssf *shardedStorageFactory) Create() (core.ShardedStorageWithIndex, error) { +func (ssf *shardedStorageFactory) Create() (core.StorageWithIndex, error) { switch ssf.cfg.ShardedStorage.DBType { case core.LevelDB: return ssf.createLocalDB() @@ -36,7 +36,7 @@ func (ssf *shardedStorageFactory) Create() (core.ShardedStorageWithIndex, error) } } -func (ssf *shardedStorageFactory) createMongoDB() (core.ShardedStorageWithIndex, error) { +func (ssf *shardedStorageFactory) createMongoDB() (core.StorageWithIndex, error) { client, err := mongodb.CreateMongoDBClient(ssf.cfg.MongoDB) if err != nil { return nil, err @@ -66,7 +66,7 @@ func (ssf *shardedStorageFactory) createMongoDB() (core.ShardedStorageWithIndex, return bucket.NewShardedStorageWithIndex(argsShardedStorageWithIndex) } -func (ssf *shardedStorageFactory) createLocalDB() (core.ShardedStorageWithIndex, error) { +func (ssf *shardedStorageFactory) createLocalDB() (core.StorageWithIndex, error) { numbOfBuckets := ssf.cfg.Buckets.NumberOfBuckets bucketIDProvider, err := bucket.NewBucketIDProvider(numbOfBuckets) if err != nil { diff --git a/resolver/serviceResolver.go b/resolver/serviceResolver.go index 76ec7358..7bc3d7e9 100644 --- a/resolver/serviceResolver.go +++ b/resolver/serviceResolver.go @@ -46,7 +46,7 @@ type ArgServiceResolver struct { SignatureVerifier builders.Signer GuardedTxBuilder core.GuardedTxBuilder RequestTime time.Duration - RegisteredUsersDB core.ShardedStorageWithIndex + RegisteredUsersDB core.StorageWithIndex KeyGen crypto.KeyGenerator CryptoComponentsHolderFactory CryptoComponentsHolderFactory SkipTxUserSigVerify bool @@ -64,7 +64,7 @@ type serviceResolver struct { requestTime time.Duration signatureVerifier builders.Signer guardedTxBuilder core.GuardedTxBuilder - registeredUsersDB core.ShardedStorageWithIndex + registeredUsersDB core.StorageWithIndex managedPrivateKey crypto.PrivateKey keyGen crypto.KeyGenerator cryptoComponentsHolderFactory CryptoComponentsHolderFactory From a2cdb2c9048e9f7d93c5d17263196fd43ffffe31 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Mon, 13 Mar 2023 15:20:32 +0200 Subject: [PATCH 10/35] fix reference to mongo in factory --- handlers/storage/factory/shardedStorageFactory.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/handlers/storage/factory/shardedStorageFactory.go b/handlers/storage/factory/shardedStorageFactory.go index d3670eae..938d9a9a 100644 --- a/handlers/storage/factory/shardedStorageFactory.go +++ b/handlers/storage/factory/shardedStorageFactory.go @@ -6,8 +6,8 @@ import ( "github.com/multiversx/multi-factor-auth-go-service/config" "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/multi-factor-auth-go-service/handlers" - "github.com/multiversx/multi-factor-auth-go-service/handlers/storage" "github.com/multiversx/multi-factor-auth-go-service/handlers/storage/bucket" + "github.com/multiversx/multi-factor-auth-go-service/handlers/storage/mongo" "github.com/multiversx/multi-factor-auth-go-service/mongodb" "github.com/multiversx/mx-chain-storage-go/storageUnit" ) @@ -42,7 +42,7 @@ func (ssf *shardedStorageFactory) createMongoDB() (core.StorageWithIndex, error) return nil, err } - storer, err := storage.NewMongoDBStorerHandler(client, mongodb.UsersCollectionID) + storer, err := mongo.NewMongoDBStorerHandler(client, mongodb.UsersCollectionID) if err != nil { return nil, err } From 0834a8af4b59157de109d997b9d8e1410ad58a89 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Mon, 13 Mar 2023 21:11:02 +0200 Subject: [PATCH 11/35] added update with checker implementation --- core/interface.go | 6 ++ handlers/storage/bucket/bucketIndexHandler.go | 18 ++++ .../storage/bucket/mongodbIndexHandler.go | 4 + .../storage/bucket/shardedStorageWithIndex.go | 9 ++ handlers/storage/dbOTPHandler.go | 65 +++++++++++--- .../storage/factory/shardedStorageFactory.go | 6 +- mongodb/dbClient.go | 63 +++++++++++++- mongodb/dbClient_test.go | 86 +++++++++---------- mongodb/interface.go | 10 ++- mongodb/mongodbClientWrapper.go | 2 +- 10 files changed, 206 insertions(+), 63 deletions(-) diff --git a/core/interface.go b/core/interface.go index 231210c8..9a7365c6 100644 --- a/core/interface.go +++ b/core/interface.go @@ -78,6 +78,7 @@ type BucketIndexHandler interface { Close() error AllocateBucketIndex() (uint32, error) GetLastIndex() (uint32, error) + UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error IsInterfaceNil() bool } @@ -91,3 +92,8 @@ type StorageWithIndex interface { Count() (uint32, error) IsInterfaceNil() bool } + +type StorageWithIndexChecker interface { + StorageWithIndex + UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error +} diff --git a/handlers/storage/bucket/bucketIndexHandler.go b/handlers/storage/bucket/bucketIndexHandler.go index e7e48789..2bccac2a 100644 --- a/handlers/storage/bucket/bucketIndexHandler.go +++ b/handlers/storage/bucket/bucketIndexHandler.go @@ -71,6 +71,24 @@ func (handler *bucketIndexHandler) Has(key []byte) error { return handler.bucket.Has(key) } +func (handler *bucketIndexHandler) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { + data, err := handler.bucket.Get(key) + if err != nil { + return nil + } + + newData, err := fn(data) + if err != nil { + return err + } + newDataBytes, ok := newData.([]byte) + if !ok { + return core.ErrInvalidValue + } + + return handler.bucket.Put(key, newDataBytes) +} + // GetLastIndex returns the last index that was allocated func (handler *bucketIndexHandler) GetLastIndex() (uint32, error) { handler.mut.RLock() diff --git a/handlers/storage/bucket/mongodbIndexHandler.go b/handlers/storage/bucket/mongodbIndexHandler.go index d66d4ff5..783aaca3 100644 --- a/handlers/storage/bucket/mongodbIndexHandler.go +++ b/handlers/storage/bucket/mongodbIndexHandler.go @@ -71,6 +71,10 @@ func (handler *mongodbIndexHandler) Has(key []byte) error { return handler.storer.Has(key) } +func (handler *mongodbIndexHandler) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { + return handler.mongodbClient.ReadWriteWithCheck(mongodb.UsersCollectionID, key, fn) +} + // GetLastIndex returns the last index that was allocated func (handler *mongodbIndexHandler) GetLastIndex() (uint32, error) { handler.mut.RLock() diff --git a/handlers/storage/bucket/shardedStorageWithIndex.go b/handlers/storage/bucket/shardedStorageWithIndex.go index 92723aea..d8925a80 100644 --- a/handlers/storage/bucket/shardedStorageWithIndex.go +++ b/handlers/storage/bucket/shardedStorageWithIndex.go @@ -95,6 +95,15 @@ func (sswi *shardedStorageWithIndex) Has(key []byte) error { return bucket.Has(key) } +func (sswi *shardedStorageWithIndex) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { + bucket, _, err := sswi.getBucketForKey(key) + if err != nil { + return err + } + + return bucket.UpdateWithCheck(key, fn) +} + // Count returns the number of elements in all buckets func (sswi *shardedStorageWithIndex) Count() (uint32, error) { count := uint32(0) diff --git a/handlers/storage/dbOTPHandler.go b/handlers/storage/dbOTPHandler.go index 84800584..9b92ac90 100644 --- a/handlers/storage/dbOTPHandler.go +++ b/handlers/storage/dbOTPHandler.go @@ -17,14 +17,14 @@ const ( // ArgDBOTPHandler is the DTO used to create a new instance of dbOTPHandler type ArgDBOTPHandler struct { - DB core.StorageWithIndex + DB core.StorageWithIndexChecker TOTPHandler handlers.TOTPHandler Marshaller core.Marshaller DelayBetweenOTPUpdatesInSec int64 } type dbOTPHandler struct { - db core.StorageWithIndex + db core.StorageWithIndexChecker totpHandler handlers.TOTPHandler marshaller core.Marshaller getTimeHandler func() time.Time @@ -85,19 +85,31 @@ func (handler *dbOTPHandler) Save(account, guardian []byte, otp handlers.OTP) er return handler.saveNewOTP(key, otp) } - oldOTPInfo, err := handler.getOldOTPInfo(key) - if err != nil { - return err + checker := func(data interface{}) (interface{}, error) { + otpInfoBytes, ok := data.([]byte) + if !ok { + return nil, core.ErrInvalidValue + } + + err := handler.checkOtpUpdateAllowed(otpInfoBytes) + if err != nil { + return nil, err + } + + buff, err := handler.getMarshalledOtpData(otp) + if err != nil { + return nil, err + } + + return buff, nil } - currentTimestamp := handler.getTimeHandler().Unix() - isOTPUpdateAllowed := oldOTPInfo.LastTOTPChangeTimestamp+handler.delayBetweenOTPUpdatesInSec < currentTimestamp - if !isOTPUpdateAllowed { - return fmt.Errorf("%w, last update was %d seconds ago", - handlers.ErrRegistrationFailed, currentTimestamp-oldOTPInfo.LastTOTPChangeTimestamp) + err = handler.db.UpdateWithCheck(key, checker) + if err != nil { + return err } - return handler.saveNewOTP(key, otp) + return nil } // Get returns the one time password @@ -126,6 +138,37 @@ func (handler *dbOTPHandler) getOldOTPInfo(key []byte) (*core.OTPInfo, error) { return otpInfo, nil } +func (handler *dbOTPHandler) getMarshalledOtpData(otp handlers.OTP) ([]byte, error) { + newOtpInfo := &core.OTPInfo{ + LastTOTPChangeTimestamp: handler.getTimeHandler().Unix(), + } + + var err error + newOtpInfo.OTP, err = otp.ToBytes() + if err != nil { + return nil, err + } + + return handler.marshaller.Marshal(newOtpInfo) +} + +func (handler *dbOTPHandler) checkOtpUpdateAllowed(otpInfoBytes []byte) error { + otpInfo := &core.OTPInfo{} + err := handler.marshaller.Unmarshal(otpInfo, otpInfoBytes) + if err != nil { + return err + } + + currentTimestamp := handler.getTimeHandler().Unix() + isOTPUpdateAllowed := otpInfo.LastTOTPChangeTimestamp+handler.delayBetweenOTPUpdatesInSec < currentTimestamp + if !isOTPUpdateAllowed { + return fmt.Errorf("%w, last update was %d seconds ago", + handlers.ErrRegistrationFailed, currentTimestamp-otpInfo.LastTOTPChangeTimestamp) + } + + return nil +} + func (handler *dbOTPHandler) saveNewOTP(key []byte, otp handlers.OTP) error { otpInfo := &core.OTPInfo{ LastTOTPChangeTimestamp: handler.getTimeHandler().Unix(), diff --git a/handlers/storage/factory/shardedStorageFactory.go b/handlers/storage/factory/shardedStorageFactory.go index 938d9a9a..f2a4a89d 100644 --- a/handlers/storage/factory/shardedStorageFactory.go +++ b/handlers/storage/factory/shardedStorageFactory.go @@ -24,7 +24,7 @@ func NewShardedStorageFactory(config config.Config) *shardedStorageFactory { } // Create returns a new instance of ShardedStorageWithIndex -func (ssf *shardedStorageFactory) Create() (core.StorageWithIndex, error) { +func (ssf *shardedStorageFactory) Create() (core.StorageWithIndexChecker, error) { switch ssf.cfg.ShardedStorage.DBType { case core.LevelDB: return ssf.createLocalDB() @@ -36,7 +36,7 @@ func (ssf *shardedStorageFactory) Create() (core.StorageWithIndex, error) { } } -func (ssf *shardedStorageFactory) createMongoDB() (core.StorageWithIndex, error) { +func (ssf *shardedStorageFactory) createMongoDB() (core.StorageWithIndexChecker, error) { client, err := mongodb.CreateMongoDBClient(ssf.cfg.MongoDB) if err != nil { return nil, err @@ -66,7 +66,7 @@ func (ssf *shardedStorageFactory) createMongoDB() (core.StorageWithIndex, error) return bucket.NewShardedStorageWithIndex(argsShardedStorageWithIndex) } -func (ssf *shardedStorageFactory) createLocalDB() (core.StorageWithIndex, error) { +func (ssf *shardedStorageFactory) createLocalDB() (core.StorageWithIndexChecker, error) { numbOfBuckets := ssf.cfg.Buckets.NumberOfBuckets bucketIDProvider, err := bucket.NewBucketIDProvider(numbOfBuckets) if err != nil { diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 5feed1bb..5e4291fd 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -10,6 +10,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/writeconcern" ) var log = logger.GetOrCreate("mongodb") @@ -76,7 +77,7 @@ func (mdc *mongodbClient) Put(collID CollectionID, key []byte, data []byte) erro opts := options.Update().SetUpsert(true) - log.Debug("Put", "key", string(key), "value", string(data)) + log.Trace("Put", "key", string(key), "value", string(data)) _, err := coll.UpdateOne(mdc.ctx, filter, update, opts) if err != nil { @@ -110,7 +111,7 @@ func (mdc *mongodbClient) Get(collID CollectionID, key []byte) ([]byte, error) { return nil, err } - log.Debug("Get", "key", string(key)) + log.Trace("Get", "key", string(key)) return entry.Value, nil } @@ -118,7 +119,7 @@ func (mdc *mongodbClient) Get(collID CollectionID, key []byte) ([]byte, error) { // Has will return true if the provided key exists in the collection func (mdc *mongodbClient) Has(collID CollectionID, key []byte) error { _, err := mdc.findOne(collID, key) - log.Debug("Has", "key", string(key)) + log.Trace("Has", "key", string(key)) return err } @@ -139,6 +140,61 @@ func (mdc *mongodbClient) Remove(collID CollectionID, key []byte) error { return nil } +// IncrementWithTransaction will increment the value for the provided key, within a transaction +func (mdc *mongodbClient) ReadWriteWithCheck( + collID CollectionID, + key []byte, + checker func(data interface{}) (interface{}, error), +) error { + session, err := mdc.client.StartSession() + if err != nil { + return err + } + defer session.EndSession(mdc.ctx) + + wc := writeconcern.New(writeconcern.WMajority()) + txnOptions := options.Transaction().SetWriteConcern(wc) + + sessionCallback := func(ctx mongo.SessionContext) error { + err := session.StartTransaction(txnOptions) + if err != nil { + return err + } + + value, err := mdc.Get(collID, key) + if err != nil { + return err + } + + retValue, err := checker(value) + if err != nil { + return err + } + retValueBytes, ok := retValue.([]byte) + if !ok { + return core.ErrInvalidValue + } + + err = mdc.Put(collID, key, retValueBytes) + + if err = session.CommitTransaction(ctx); err != nil { + return err + } + + return nil + } + + err = mongo.WithSession(mdc.ctx, session, sessionCallback) + if err != nil { + if err := session.AbortTransaction(mdc.ctx); err != nil { + return err + } + return err + } + + return nil +} + // IncrementWithTransaction will increment the value for the provided key, within a transaction func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []byte) (uint32, error) { coll, ok := mdc.collections[collID] @@ -175,7 +231,6 @@ func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []by return newIndex, nil } - // Step 2: Start a session and run the callback using WithTransaction. session, err := mdc.client.StartSession() if err != nil { return 0, err diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index 305335e9..5acc13da 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -226,49 +226,49 @@ func TestMongoDBClient_Remove(t *testing.T) { }) } -func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { - t.Parallel() - - updateWasCalled := false - findWasCalled := false - sessionWasCalled := false - - mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ - DBCollectionCalled: func(dbName, collName string) mongodb.MongoDBCollection { - require.Equal(t, string(mongodb.UsersCollectionID), collName) - - return &testscommon.MongoDBCollectionStub{ - FindOneCalled: func(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult { - findWasCalled = true - return mongo.NewSingleResultFromDocument(&testStruct{Key: "key", Value: []byte{0, 0, 0, 1}}, nil, bson.DefaultRegistry) - }, - UpdateOneCalled: func(ctx context.Context, filter, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { - updateWasCalled = true - return nil, nil - }, - } - }, - StartSessionCalled: func() (mongodb.MongoDBSession, error) { - sessionWasCalled = true - return &testscommon.MongoDBSessionStub{ - WithTransactionCalled: func(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { - fn(mongo.NewSessionContext(context.TODO(), mongo.SessionFromContext(context.TODO()))) - return uint32(5), nil - }, - }, nil - }, - } - - client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") - require.Nil(t, err) - - _, err = client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) - require.Nil(t, err) - - require.True(t, findWasCalled) - require.True(t, updateWasCalled) - require.True(t, sessionWasCalled) -} +// func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { +// t.Parallel() + +// updateWasCalled := false +// findWasCalled := false +// sessionWasCalled := false + +// mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ +// DBCollectionCalled: func(dbName, collName string) mongodb.MongoDBCollection { +// require.Equal(t, string(mongodb.UsersCollectionID), collName) + +// return &testscommon.MongoDBCollectionStub{ +// FindOneCalled: func(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult { +// findWasCalled = true +// return mongo.NewSingleResultFromDocument(&testStruct{Key: "key", Value: []byte{0, 0, 0, 1}}, nil, bson.DefaultRegistry) +// }, +// UpdateOneCalled: func(ctx context.Context, filter, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { +// updateWasCalled = true +// return nil, nil +// }, +// } +// }, +// StartSessionCalled: func() (mongo.Session, error) { +// sessionWasCalled = true +// return &testscommon.MongoDBSessionStub{ +// WithTransactionCalled: func(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { +// fn(mongo.NewSessionContext(context.TODO(), mongo.SessionFromContext(context.TODO()))) +// return uint32(5), nil +// }, +// }, nil +// }, +// } + +// client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") +// require.Nil(t, err) + +// _, err = client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) +// require.Nil(t, err) + +// require.True(t, findWasCalled) +// require.True(t, updateWasCalled) +// require.True(t, sessionWasCalled) +// } func TestMongoDBClient_Close(t *testing.T) { t.Parallel() diff --git a/mongodb/interface.go b/mongodb/interface.go index b96f56ec..bda97387 100644 --- a/mongodb/interface.go +++ b/mongodb/interface.go @@ -12,7 +12,7 @@ type MongoDBClientWrapper interface { Connect(ctx context.Context) error Disconnect(ctx context.Context) error DBCollection(dbName string, collName string) MongoDBCollection - StartSession() (MongoDBSession, error) + StartSession() (mongo.Session, error) IsInterfaceNil() bool } @@ -30,12 +30,20 @@ type MongoDBClient interface { Has(coll CollectionID, key []byte) error Remove(coll CollectionID, key []byte) error IncrementWithTransaction(collID CollectionID, key []byte) (uint32, error) + ReadWriteWithCheck( + collID CollectionID, + key []byte, + checker func(data interface{}) (interface{}, error), + ) error Close() error IsInterfaceNil() bool } // MongoDBSession defines what a mongodb session should do type MongoDBSession interface { + StartTransaction(...*options.TransactionOptions) error + AbortTransaction(context.Context) error + CommitTransaction(context.Context) error WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) EndSession(context.Context) diff --git a/mongodb/mongodbClientWrapper.go b/mongodb/mongodbClientWrapper.go index df98b9f9..02ad30a3 100644 --- a/mongodb/mongodbClientWrapper.go +++ b/mongodb/mongodbClientWrapper.go @@ -32,7 +32,7 @@ func (m *mongoDBClientWrapper) DBCollection(dbName string, coll string) MongoDBC } // DBCollection will return the specified collection object -func (m *mongoDBClientWrapper) StartSession() (MongoDBSession, error) { +func (m *mongoDBClientWrapper) StartSession() (mongo.Session, error) { return m.client.StartSession() } From f32a5139b822f0a7122e1bcd7b91867434f15c2a Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 14 Mar 2023 10:59:44 +0200 Subject: [PATCH 12/35] fix unit tests in db otp handler --- handlers/storage/dbOTPHandler_test.go | 27 ++++++++++++++++++++-- testscommon/bucketIndexHandlerStub.go | 10 ++++++++ testscommon/mongoDBClientWrapperStub.go | 18 +++++++++++++++ testscommon/shardedStorageWithIndexMock.go | 14 +++++++++++ testscommon/shardedStorageWithIndexStub.go | 10 ++++++++ 5 files changed, 77 insertions(+), 2 deletions(-) diff --git a/handlers/storage/dbOTPHandler_test.go b/handlers/storage/dbOTPHandler_test.go index 48c56cdc..bbcab4bc 100644 --- a/handlers/storage/dbOTPHandler_test.go +++ b/handlers/storage/dbOTPHandler_test.go @@ -21,7 +21,7 @@ var expectedErr = errors.New("expected error") func createMockArgs() storage.ArgDBOTPHandler { return storage.ArgDBOTPHandler{ - DB: testscommon.NewShardedStorageWithIndexMock(), + DB: testscommon.NewShardedStorageWithIndexMock(), TOTPHandler: &testscommon.TOTPHandlerStub{}, Marshaller: &testscommon.MarshallerStub{}, DelayBetweenOTPUpdatesInSec: 5, @@ -201,6 +201,10 @@ func TestDBOTPHandler_Save(t *testing.T) { args := createMockArgs() args.DB = &testscommon.ShardedStorageWithIndexStub{ + UpdateWithCheckCalled: func(key []byte, fn func(data interface{}) (interface{}, error)) error { + _, err := args.DB.Get(key) + return err + }, GetCalled: func(key []byte) ([]byte, error) { return nil, expectedErr }, @@ -217,7 +221,12 @@ func TestDBOTPHandler_Save(t *testing.T) { t.Parallel() args := createMockArgs() - args.DB = &testscommon.ShardedStorageWithIndexStub{} + args.DB = &testscommon.ShardedStorageWithIndexStub{ + UpdateWithCheckCalled: func(key []byte, fn func(data interface{}) (interface{}, error)) error { + _, err := fn([]byte("badEncodedData")) + return err + }, + } args.Marshaller = &testscommon.MarshallerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { return expectedErr @@ -347,6 +356,20 @@ func TestDBOTPHandler_Save(t *testing.T) { GetCalled: func(key []byte) ([]byte, error) { return mockDB.Get(key) }, + UpdateWithCheckCalled: func(key []byte, fn func(data interface{}) (interface{}, error)) error { + data, err := mockDB.Get(key) + if err != nil { + return err + } + + newData, err := fn(data) + if err != nil { + return err + } + + atomic.AddUint32(&putCounter, 1) + return mockDB.Put(key, newData.([]byte)) + }, } args.Marshaller = &mock.MarshalizerMock{} handler, err := storage.NewDBOTPHandler(args) diff --git a/testscommon/bucketIndexHandlerStub.go b/testscommon/bucketIndexHandlerStub.go index a69413b7..74351094 100644 --- a/testscommon/bucketIndexHandlerStub.go +++ b/testscommon/bucketIndexHandlerStub.go @@ -8,6 +8,7 @@ type BucketIndexHandlerStub struct { CloseCalled func() error AllocateBucketIndexCalled func() (uint32, error) GetLastIndexCalled func() (uint32, error) + UpdateWithCheckCalled func(key []byte, fn func(data interface{}) (interface{}, error)) error } // Put - @@ -58,6 +59,15 @@ func (stub *BucketIndexHandlerStub) GetLastIndex() (uint32, error) { return 0, nil } +// UpdateWithCheck - +func (stub *BucketIndexHandlerStub) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { + if stub.UpdateWithCheckCalled != nil { + return stub.UpdateWithCheckCalled(key, fn) + } + + return nil +} + // IsInterfaceNil - func (stub *BucketIndexHandlerStub) IsInterfaceNil() bool { return stub == nil diff --git a/testscommon/mongoDBClientWrapperStub.go b/testscommon/mongoDBClientWrapperStub.go index e78b9c94..38a002de 100644 --- a/testscommon/mongoDBClientWrapperStub.go +++ b/testscommon/mongoDBClientWrapperStub.go @@ -12,6 +12,11 @@ type MongoDBClientStub struct { RemoveCalled func(coll mongodb.CollectionID, key []byte) error IncrementWithTransactionCalled func(coll mongodb.CollectionID, key []byte) (uint32, error) CloseCalled func() error + ReadWriteWithCheckCalled func( + collID mongodb.CollectionID, + key []byte, + checker func(data interface{}) (interface{}, error), + ) error } // Put - @@ -59,6 +64,19 @@ func (m *MongoDBClientStub) IncrementWithTransaction(coll mongodb.CollectionID, return 0, nil } +// ReadWriteWithCheck - +func (m *MongoDBClientStub) ReadWriteWithCheck( + collID mongodb.CollectionID, + key []byte, + checker func(data interface{}) (interface{}, error), +) error { + if m.ReadWriteWithCheckCalled != nil { + return m.ReadWriteWithCheckCalled(collID, key, checker) + } + + return nil +} + // Close - func (m *MongoDBClientStub) Close() error { if m.CloseCalled != nil { diff --git a/testscommon/shardedStorageWithIndexMock.go b/testscommon/shardedStorageWithIndexMock.go index 1fd7e354..52603e97 100644 --- a/testscommon/shardedStorageWithIndexMock.go +++ b/testscommon/shardedStorageWithIndexMock.go @@ -64,6 +64,20 @@ func (mock *shardedStorageWithIndexMock) Count() (uint32, error) { return uint32(len(mock.cache)), nil } +func (mock *shardedStorageWithIndexMock) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { + data, err := mock.Get(key) + if err != nil { + return err + } + + newData, err := fn(data) + if err != nil { + return err + } + + return mock.Put(key, newData.([]byte)) +} + // IsInterfaceNil - func (mock *shardedStorageWithIndexMock) IsInterfaceNil() bool { return mock == nil diff --git a/testscommon/shardedStorageWithIndexStub.go b/testscommon/shardedStorageWithIndexStub.go index f13cae93..6d07deb4 100644 --- a/testscommon/shardedStorageWithIndexStub.go +++ b/testscommon/shardedStorageWithIndexStub.go @@ -9,6 +9,7 @@ type ShardedStorageWithIndexStub struct { CloseCalled func() error AllocateBucketIndexCalled func(address []byte) (uint32, error) CountCalled func() (uint32, error) + UpdateWithCheckCalled func(key []byte, fn func(data interface{}) (interface{}, error)) error } // AllocateIndex - @@ -67,6 +68,15 @@ func (stub *ShardedStorageWithIndexStub) Count() (uint32, error) { return 0, nil } +// UpdateWithCheck - +func (stub *ShardedStorageWithIndexStub) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { + if stub.UpdateWithCheckCalled != nil { + return stub.UpdateWithCheckCalled(key, fn) + } + + return nil +} + // IsInterfaceNil - func (stub *ShardedStorageWithIndexStub) IsInterfaceNil() bool { return stub == nil From e9809d68f8772d472193ba9b3a87f06c399780e7 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 14 Mar 2023 15:44:35 +0200 Subject: [PATCH 13/35] use mtest package for mongo db client unit testing --- mongodb/dbClient.go | 11 +- mongodb/dbClient_test.go | 360 +++++++++++++++---------------- mongodb/mongoDBClientFactory.go | 4 +- mongodb/mongodbClientWrapper.go | 42 ---- testscommon/mongoDBClientStub.go | 115 ---------- 5 files changed, 186 insertions(+), 346 deletions(-) delete mode 100644 mongodb/mongodbClientWrapper.go delete mode 100644 testscommon/mongoDBClientStub.go diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 5e4291fd..8c6d4125 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "github.com/multiversx/multi-factor-auth-go-service/core" - "github.com/multiversx/mx-chain-core-go/core/check" logger "github.com/multiversx/mx-chain-logger-go" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -29,14 +28,14 @@ type mongoEntry struct { } type mongodbClient struct { - client MongoDBClientWrapper + client *mongo.Client collections map[CollectionID]MongoDBCollection ctx context.Context } // NewClient will create a new mongodb client instance -func NewClient(client MongoDBClientWrapper, dbName string) (*mongodbClient, error) { - if check.IfNil(client) { +func NewClient(client *mongo.Client, dbName string) (*mongodbClient, error) { + if client == nil { return nil, ErrNilMongoDBClientWrapper } if dbName == "" { @@ -51,7 +50,7 @@ func NewClient(client MongoDBClientWrapper, dbName string) (*mongodbClient, erro } collections := make(map[CollectionID]MongoDBCollection) - collections[UsersCollectionID] = client.DBCollection(dbName, string(UsersCollectionID)) + collections[UsersCollectionID] = client.Database(dbName).Collection(string(UsersCollectionID)) return &mongodbClient{ client: client, @@ -140,7 +139,7 @@ func (mdc *mongodbClient) Remove(collID CollectionID, key []byte) error { return nil } -// IncrementWithTransaction will increment the value for the provided key, within a transaction +// ReadWriteWithCheck will perform read and write operation with a provided checker func (mdc *mongodbClient) ReadWriteWithCheck( collID CollectionID, key []byte, diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index 5acc13da..32936aee 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -1,16 +1,15 @@ package mongodb_test import ( - "context" + "bytes" + "encoding/binary" "errors" "testing" "github.com/multiversx/multi-factor-auth-go-service/mongodb" - "github.com/multiversx/multi-factor-auth-go-service/testscommon" "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/integration/mtest" ) type testStruct struct { @@ -21,138 +20,107 @@ type testStruct struct { func TestNewMongoDBClient(t *testing.T) { t.Parallel() - t.Run("nil client wrapper, should fail", func(t *testing.T) { - t.Parallel() + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() + mt.Run("nil client wrapper, should fail", func(mt *mtest.T) { client, err := mongodb.NewClient(nil, "dbName") - require.Nil(t, client) - require.Equal(t, mongodb.ErrNilMongoDBClientWrapper, err) + require.Nil(mt, client) + require.Equal(mt, mongodb.ErrNilMongoDBClientWrapper, err) }) - t.Run("empty db name, should fail", func(t *testing.T) { - t.Parallel() - - client, err := mongodb.NewClient(&testscommon.MongoDBClientWrapperStub{}, "") - require.Nil(t, client) - require.Equal(t, mongodb.ErrEmptyMongoDBName, err) + mt.Run("empty db name, should fail", func(mt *mtest.T) { + client, err := mongodb.NewClient(mt.Client, "") + require.Nil(mt, client) + require.Equal(mt, mongodb.ErrEmptyMongoDBName, err) }) - t.Run("failed to connect", func(t *testing.T) { - t.Parallel() - - expectedErr := errors.New("expected err") - mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ - ConnectCalled: func(ctx context.Context) error { - return expectedErr - }, - } - - client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") - require.Nil(t, client) - require.Equal(t, expectedErr, err) - }) - - t.Run("should work", func(t *testing.T) { - t.Parallel() - - client, err := mongodb.NewClient(&testscommon.MongoDBClientWrapperStub{}, "dbName") - require.Nil(t, err) - require.False(t, client.IsInterfaceNil()) + mt.Run("should work", func(mt *mtest.T) { + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(mt, err) + require.False(mt, client.IsInterfaceNil()) }) } func TestMongoDBClient_Put(t *testing.T) { t.Parallel() - t.Run("collection not found", func(t *testing.T) { - t.Parallel() + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() - client, err := mongodb.NewClient(&testscommon.MongoDBClientWrapperStub{}, "dbName") + mt.Run("collection not found", func(mt *mtest.T) { + client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(t, err) err = client.Put("another coll", []byte("key1"), []byte("data")) require.Equal(t, mongodb.ErrCollectionNotFound, err) }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - wasCalled := false - mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ - DBCollectionCalled: func(dbName, collName string) mongodb.MongoDBCollection { - require.Equal(t, string(mongodb.UsersCollectionID), collName) - - return &testscommon.MongoDBCollectionStub{ - UpdateOneCalled: func(ctx context.Context, filter, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { - wasCalled = true - return nil, nil - }, - } - }, - } + mt.Run("should fail", func(mt *mtest.T) { + expectedErr := errors.New("expected error") + mt.AddMockResponses(mtest.CreateCommandErrorResponse(mtest.CommandError{ + Code: 1, + Message: expectedErr.Error(), + })) - client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") + client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(t, err) err = client.Put(mongodb.UsersCollectionID, []byte("key1"), []byte("data")) + require.Equal(t, expectedErr.Error(), err.Error()) + }) + + mt.Run("should work", func(mt *mtest.T) { + mt.AddMockResponses( + mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ + {"_id", "key2"}, + {"value", []byte("value")}, + })) + + client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(t, err) - require.True(t, wasCalled) + err = client.Put(mongodb.UsersCollectionID, []byte("key1"), []byte("data")) + require.Nil(t, err) }) } func TestMongoDBClient_Get(t *testing.T) { t.Parallel() - t.Run("collection not found", func(t *testing.T) { - t.Parallel() + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() - client, err := mongodb.NewClient(&testscommon.MongoDBClientWrapperStub{}, "dbName") + mt.Run("collection not found", func(mt *mtest.T) { + client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(t, err) _, err = client.Get("another coll", []byte("key1")) require.Equal(t, mongodb.ErrCollectionNotFound, err) }) - t.Run("find one entry failed", func(t *testing.T) { - t.Parallel() - + mt.Run("find one entry failed", func(mt *mtest.T) { expectedErr := errors.New("expected err") - mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ - DBCollectionCalled: func(dbName, collName string) mongodb.MongoDBCollection { - require.Equal(t, string(mongodb.UsersCollectionID), collName) - - return &testscommon.MongoDBCollectionStub{ - FindOneCalled: func(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult { - return mongo.NewSingleResultFromDocument(&testStruct{}, expectedErr, bson.DefaultRegistry) - }, - } - }, - } + mt.AddMockResponses(mtest.CreateCommandErrorResponse(mtest.CommandError{ + Code: 1, + Message: expectedErr.Error(), + })) - client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") + client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(t, err) _, err = client.Get(mongodb.UsersCollectionID, []byte("key1")) - require.Equal(t, expectedErr, err) + require.Equal(t, expectedErr.Error(), err.Error()) }) - t.Run("should work", func(t *testing.T) { - t.Parallel() + mt.Run("should work", func(mt *mtest.T) { + mt.AddMockResponses( + mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ + {"_id", "key2"}, + {"value", []byte("value")}, + })) - mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ - DBCollectionCalled: func(dbName, collName string) mongodb.MongoDBCollection { - require.Equal(t, string(mongodb.UsersCollectionID), collName) - - return &testscommon.MongoDBCollectionStub{ - FindOneCalled: func(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult { - return mongo.NewSingleResultFromDocument(&testStruct{}, nil, bson.DefaultRegistry) - }, - } - }, - } - - client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") + client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(t, err) _, err = client.Get(mongodb.UsersCollectionID, []byte("key1")) @@ -163,22 +131,31 @@ func TestMongoDBClient_Get(t *testing.T) { func TestMongoDBClient_Has(t *testing.T) { t.Parallel() - t.Run("should work", func(t *testing.T) { - t.Parallel() + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() - mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ - DBCollectionCalled: func(dbName, collName string) mongodb.MongoDBCollection { - require.Equal(t, string(mongodb.UsersCollectionID), collName) + mt.Run("should fail", func(mt *mtest.T) { + expectedErr := errors.New("expected err") + mt.AddMockResponses(mtest.CreateCommandErrorResponse(mtest.CommandError{ + Code: 1, + Message: expectedErr.Error(), + })) - return &testscommon.MongoDBCollectionStub{ - FindOneCalled: func(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult { - return mongo.NewSingleResultFromDocument(&testStruct{}, nil, bson.DefaultRegistry) - }, - } - }, - } + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(t, err) + + err = client.Has(mongodb.UsersCollectionID, []byte("key1")) + require.Equal(t, expectedErr.Error(), err.Error()) + }) - client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") + mt.Run("should work", func(mt *mtest.T) { + mt.AddMockResponses( + mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ + {"_id", "key2"}, + {"value", []byte("value")}, + })) + + client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(t, err) err = client.Has(mongodb.UsersCollectionID, []byte("key1")) @@ -189,101 +166,124 @@ func TestMongoDBClient_Has(t *testing.T) { func TestMongoDBClient_Remove(t *testing.T) { t.Parallel() - t.Run("collection not found", func(t *testing.T) { - t.Parallel() + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() - client, err := mongodb.NewClient(&testscommon.MongoDBClientWrapperStub{}, "dbName") + mt.Run("collection not found", func(mt *mtest.T) { + client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(t, err) err = client.Remove("another coll", []byte("key1")) require.Equal(t, mongodb.ErrCollectionNotFound, err) }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - wasCalled := false - mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ - DBCollectionCalled: func(dbName, collName string) mongodb.MongoDBCollection { - require.Equal(t, string(mongodb.UsersCollectionID), collName) - - return &testscommon.MongoDBCollectionStub{ - DeleteOneCalled: func(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { - wasCalled = true - return nil, nil - }, - } - }, - } + mt.Run("should work", func(mt *mtest.T) { + mt.AddMockResponses( + mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ + {"_id", "key1"}, + }), + ) - client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") + client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(t, err) err = client.Remove(mongodb.UsersCollectionID, []byte("key1")) require.Nil(t, err) + }) +} + +func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { + t.Parallel() + + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() + + mt.Run("failed to create session", func(mt *mtest.T) { + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(t, err) + + expectedErr := errors.New("expected error") + + mt.AddMockResponses( + mtest.CreateCommandErrorResponse(mtest.CommandError{ + Code: 1, + Message: expectedErr.Error(), + }), + ) + + _, err = client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) + require.Equal(t, expectedErr.Error(), err.Error()) + }) + + mt.Run("should work", func(mt *mtest.T) { + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(t, err) + + newIndexBytes := make([]byte, 4) + binary.BigEndian.PutUint32(newIndexBytes, 3) + + mt.AddMockResponses( + mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ + {"_id", "key2"}, + {"value", newIndexBytes}, + }), + mtest.CreateSuccessResponse(), + mtest.CreateSuccessResponse(), + mtest.CreateSuccessResponse(), + ) - require.True(t, wasCalled) + val, err := client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) + require.Nil(t, err) + require.Equal(t, uint32(4), val) }) } -// func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { -// t.Parallel() - -// updateWasCalled := false -// findWasCalled := false -// sessionWasCalled := false - -// mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ -// DBCollectionCalled: func(dbName, collName string) mongodb.MongoDBCollection { -// require.Equal(t, string(mongodb.UsersCollectionID), collName) - -// return &testscommon.MongoDBCollectionStub{ -// FindOneCalled: func(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult { -// findWasCalled = true -// return mongo.NewSingleResultFromDocument(&testStruct{Key: "key", Value: []byte{0, 0, 0, 1}}, nil, bson.DefaultRegistry) -// }, -// UpdateOneCalled: func(ctx context.Context, filter, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { -// updateWasCalled = true -// return nil, nil -// }, -// } -// }, -// StartSessionCalled: func() (mongo.Session, error) { -// sessionWasCalled = true -// return &testscommon.MongoDBSessionStub{ -// WithTransactionCalled: func(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { -// fn(mongo.NewSessionContext(context.TODO(), mongo.SessionFromContext(context.TODO()))) -// return uint32(5), nil -// }, -// }, nil -// }, -// } - -// client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") -// require.Nil(t, err) - -// _, err = client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) -// require.Nil(t, err) - -// require.True(t, findWasCalled) -// require.True(t, updateWasCalled) -// require.True(t, sessionWasCalled) -// } - -func TestMongoDBClient_Close(t *testing.T) { +func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { t.Parallel() - wasCalled := false - mongoDBClientWrapper := &testscommon.MongoDBClientWrapperStub{ - ConnectCalled: func(ctx context.Context) error { - wasCalled = true - return nil - }, - } + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() + + mt.Run("failed to create session", func(mt *mtest.T) { + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(t, err) + + expectedErr := errors.New("expected error") + + mt.AddMockResponses( + mtest.CreateCommandErrorResponse(mtest.CommandError{ + Code: 1, + Message: expectedErr.Error(), + }), + ) + + checker := func(data interface{}) (interface{}, error) { + return nil, nil + } + + err = client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key1"), checker) + require.Equal(t, expectedErr.Error(), err.Error()) + }) - client, err := mongodb.NewClient(mongoDBClientWrapper, "dbName") - require.Nil(t, err) + mt.Run("should work", func(mt *mtest.T) { + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(t, err) - require.Nil(t, client.Close()) - require.True(t, wasCalled) + mt.AddMockResponses( + mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ + {"_id", "key1"}, + {"value", []byte("data1")}, + }), + ) + + checker := func(data interface{}) (interface{}, error) { + if bytes.Equal(data.([]byte), []byte("data1")) { + return []byte("data2"), nil + } + return nil, errors.New("error") + } + + err = client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key1"), checker) + require.Nil(t, err) + }) } diff --git a/mongodb/mongoDBClientFactory.go b/mongodb/mongoDBClientFactory.go index 196a959b..3b12a8a1 100644 --- a/mongodb/mongoDBClientFactory.go +++ b/mongodb/mongoDBClientFactory.go @@ -25,7 +25,5 @@ func CreateMongoDBClient(cfg config.MongoDBConfig) (MongoDBClient, error) { return nil, err } - clientWrapper := newMongoDBClientWrapper(client) - - return NewClient(clientWrapper, cfg.DBName) + return NewClient(client, cfg.DBName) } diff --git a/mongodb/mongodbClientWrapper.go b/mongodb/mongodbClientWrapper.go deleted file mode 100644 index 02ad30a3..00000000 --- a/mongodb/mongodbClientWrapper.go +++ /dev/null @@ -1,42 +0,0 @@ -package mongodb - -import ( - "context" - - "go.mongodb.org/mongo-driver/mongo" -) - -type mongoDBClientWrapper struct { - client *mongo.Client -} - -func newMongoDBClientWrapper(client *mongo.Client) *mongoDBClientWrapper { - return &mongoDBClientWrapper{ - client: client, - } -} - -// Connect will try to connect the db client -func (m *mongoDBClientWrapper) Connect(ctx context.Context) error { - return m.client.Connect(ctx) -} - -// Disconnect will disconnect the db client -func (m *mongoDBClientWrapper) Disconnect(ctx context.Context) error { - return m.client.Disconnect(ctx) -} - -// DBCollection will return the specified collection object -func (m *mongoDBClientWrapper) DBCollection(dbName string, coll string) MongoDBCollection { - return m.client.Database(dbName).Collection(coll) -} - -// DBCollection will return the specified collection object -func (m *mongoDBClientWrapper) StartSession() (mongo.Session, error) { - return m.client.StartSession() -} - -// IsInterfaceNil returns true if there is no value under the interface -func (m *mongoDBClientWrapper) IsInterfaceNil() bool { - return m == nil -} diff --git a/testscommon/mongoDBClientStub.go b/testscommon/mongoDBClientStub.go deleted file mode 100644 index 3e5a7602..00000000 --- a/testscommon/mongoDBClientStub.go +++ /dev/null @@ -1,115 +0,0 @@ -package testscommon - -import ( - "context" - - "github.com/multiversx/multi-factor-auth-go-service/mongodb" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" -) - -// MongoDBClientWrapperStub - -type MongoDBClientWrapperStub struct { - DBCollectionCalled func(dbName string, collName string) mongodb.MongoDBCollection - ConnectCalled func(ctx context.Context) error - DisconnectCalled func(ctx context.Context) error - StartSessionCalled func() (mongodb.MongoDBSession, error) -} - -// DBCollection - -func (m *MongoDBClientWrapperStub) DBCollection(dbName string, collName string) mongodb.MongoDBCollection { - if m.DBCollectionCalled != nil { - return m.DBCollectionCalled(dbName, collName) - } - - return &MongoDBCollectionStub{} -} - -// Connect - -func (m *MongoDBClientWrapperStub) Connect(ctx context.Context) error { - if m.ConnectCalled != nil { - return m.ConnectCalled(ctx) - } - - return nil -} - -// Disconnect - -func (m *MongoDBClientWrapperStub) Disconnect(ctx context.Context) error { - if m.DisconnectCalled != nil { - return m.DisconnectCalled(ctx) - } - - return nil -} - -// StartSession - -func (m *MongoDBClientWrapperStub) StartSession() (mongodb.MongoDBSession, error) { - if m.StartSessionCalled != nil { - return m.StartSessionCalled() - } - - return nil, nil -} - -// IsInterfaceNil - -func (m *MongoDBClientWrapperStub) IsInterfaceNil() bool { - return m == nil -} - -// MongoDBCollectionStub - -type MongoDBCollectionStub struct { - UpdateOneCalled func(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) - FindOneCalled func(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult - - DeleteOneCalled func(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) -} - -// UpdateOne - -func (m *MongoDBCollectionStub) UpdateOne(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { - if m.UpdateOneCalled != nil { - return m.UpdateOneCalled(ctx, filter, update) - } - - return nil, nil -} - -// FindOne - -func (m *MongoDBCollectionStub) FindOne(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult { - if m.FindOneCalled != nil { - return m.FindOneCalled(ctx, filter) - } - - return nil -} - -// DeleteOne - -func (m *MongoDBCollectionStub) DeleteOne(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { - if m.DeleteOneCalled != nil { - return m.DeleteOneCalled(ctx, filter) - } - - return nil, nil -} - -// MongoDBSessionStub - -type MongoDBSessionStub struct { - WithTransactionCalled func(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) - EndSessionCalled func(_ context.Context) -} - -// WithTransaction - -func (m *MongoDBSessionStub) WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { - if m.WithTransactionCalled != nil { - return m.WithTransactionCalled(ctx, fn) - } - - return nil, nil -} - -// EndSession - -func (m *MongoDBSessionStub) EndSession(ctx context.Context) { - if m.EndSessionCalled != nil { - m.EndSessionCalled(ctx) - } -} From 49103ce5dcfba20cd7339092fb5feb6455761656 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 14 Mar 2023 15:59:12 +0200 Subject: [PATCH 14/35] mongodb unit tests - added parallel run --- mongodb/dbClient_test.go | 108 ++++++++++++++++++++++++--------------- 1 file changed, 68 insertions(+), 40 deletions(-) diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index 32936aee..5e798f46 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -49,14 +49,18 @@ func TestMongoDBClient_Put(t *testing.T) { defer mt.Close() mt.Run("collection not found", func(mt *mtest.T) { + mt.Parallel() + client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) err = client.Put("another coll", []byte("key1"), []byte("data")) - require.Equal(t, mongodb.ErrCollectionNotFound, err) + require.Equal(mt, mongodb.ErrCollectionNotFound, err) }) mt.Run("should fail", func(mt *mtest.T) { + mt.Parallel() + expectedErr := errors.New("expected error") mt.AddMockResponses(mtest.CreateCommandErrorResponse(mtest.CommandError{ Code: 1, @@ -64,24 +68,26 @@ func TestMongoDBClient_Put(t *testing.T) { })) client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) err = client.Put(mongodb.UsersCollectionID, []byte("key1"), []byte("data")) - require.Equal(t, expectedErr.Error(), err.Error()) + require.Equal(mt, expectedErr.Error(), err.Error()) }) mt.Run("should work", func(mt *mtest.T) { + mt.Parallel() + mt.AddMockResponses( mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ - {"_id", "key2"}, - {"value", []byte("value")}, + {Key: "_id", Value: "key2"}, + {Key: "value", Value: []byte("value")}, })) client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) err = client.Put(mongodb.UsersCollectionID, []byte("key1"), []byte("data")) - require.Nil(t, err) + require.Nil(mt, err) }) } @@ -92,14 +98,18 @@ func TestMongoDBClient_Get(t *testing.T) { defer mt.Close() mt.Run("collection not found", func(mt *mtest.T) { + mt.Parallel() + client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) _, err = client.Get("another coll", []byte("key1")) - require.Equal(t, mongodb.ErrCollectionNotFound, err) + require.Equal(mt, mongodb.ErrCollectionNotFound, err) }) mt.Run("find one entry failed", func(mt *mtest.T) { + mt.Parallel() + expectedErr := errors.New("expected err") mt.AddMockResponses(mtest.CreateCommandErrorResponse(mtest.CommandError{ Code: 1, @@ -107,24 +117,26 @@ func TestMongoDBClient_Get(t *testing.T) { })) client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) _, err = client.Get(mongodb.UsersCollectionID, []byte("key1")) - require.Equal(t, expectedErr.Error(), err.Error()) + require.Equal(mt, expectedErr.Error(), err.Error()) }) mt.Run("should work", func(mt *mtest.T) { + mt.Parallel() + mt.AddMockResponses( mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ - {"_id", "key2"}, - {"value", []byte("value")}, + {Key: "_id", Value: "key2"}, + {Key: "value", Value: []byte("value")}, })) client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) _, err = client.Get(mongodb.UsersCollectionID, []byte("key1")) - require.Nil(t, err) + require.Nil(mt, err) }) } @@ -135,6 +147,8 @@ func TestMongoDBClient_Has(t *testing.T) { defer mt.Close() mt.Run("should fail", func(mt *mtest.T) { + mt.Parallel() + expectedErr := errors.New("expected err") mt.AddMockResponses(mtest.CreateCommandErrorResponse(mtest.CommandError{ Code: 1, @@ -142,24 +156,26 @@ func TestMongoDBClient_Has(t *testing.T) { })) client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) err = client.Has(mongodb.UsersCollectionID, []byte("key1")) - require.Equal(t, expectedErr.Error(), err.Error()) + require.Equal(mt, expectedErr.Error(), err.Error()) }) mt.Run("should work", func(mt *mtest.T) { + mt.Parallel() + mt.AddMockResponses( mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ - {"_id", "key2"}, - {"value", []byte("value")}, + {Key: "_id", Value: "key2"}, + {Key: "value", Value: []byte("value")}, })) client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) err = client.Has(mongodb.UsersCollectionID, []byte("key1")) - require.Nil(t, err) + require.Nil(mt, err) }) } @@ -170,25 +186,29 @@ func TestMongoDBClient_Remove(t *testing.T) { defer mt.Close() mt.Run("collection not found", func(mt *mtest.T) { + mt.Parallel() + client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) err = client.Remove("another coll", []byte("key1")) - require.Equal(t, mongodb.ErrCollectionNotFound, err) + require.Equal(mt, mongodb.ErrCollectionNotFound, err) }) mt.Run("should work", func(mt *mtest.T) { + mt.Parallel() + mt.AddMockResponses( mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ - {"_id", "key1"}, + {Key: "_id", Value: "key1"}, }), ) client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) err = client.Remove(mongodb.UsersCollectionID, []byte("key1")) - require.Nil(t, err) + require.Nil(mt, err) }) } @@ -199,8 +219,10 @@ func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { defer mt.Close() mt.Run("failed to create session", func(mt *mtest.T) { + mt.Parallel() + client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) expectedErr := errors.New("expected error") @@ -212,20 +234,22 @@ func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { ) _, err = client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) - require.Equal(t, expectedErr.Error(), err.Error()) + require.Equal(mt, expectedErr.Error(), err.Error()) }) mt.Run("should work", func(mt *mtest.T) { + mt.Parallel() + client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) newIndexBytes := make([]byte, 4) binary.BigEndian.PutUint32(newIndexBytes, 3) mt.AddMockResponses( mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ - {"_id", "key2"}, - {"value", newIndexBytes}, + {Key: "_id", Value: "key2"}, + {Key: "value", Value: newIndexBytes}, }), mtest.CreateSuccessResponse(), mtest.CreateSuccessResponse(), @@ -233,8 +257,8 @@ func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { ) val, err := client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) - require.Nil(t, err) - require.Equal(t, uint32(4), val) + require.Nil(mt, err) + require.Equal(mt, uint32(4), val) }) } @@ -245,8 +269,10 @@ func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { defer mt.Close() mt.Run("failed to create session", func(mt *mtest.T) { + mt.Parallel() + client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) expectedErr := errors.New("expected error") @@ -262,17 +288,19 @@ func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { } err = client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key1"), checker) - require.Equal(t, expectedErr.Error(), err.Error()) + require.Equal(mt, expectedErr.Error(), err.Error()) }) mt.Run("should work", func(mt *mtest.T) { + mt.Parallel() + client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(t, err) + require.Nil(mt, err) mt.AddMockResponses( mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ - {"_id", "key1"}, - {"value", []byte("data1")}, + {Key: "_id", Value: "key1"}, + {Key: "value", Value: []byte("data1")}, }), ) @@ -284,6 +312,6 @@ func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { } err = client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key1"), checker) - require.Nil(t, err) + require.Nil(mt, err) }) } From 3ceb780595c1bf6d4230d4cd4347c0d833590a41 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 14 Mar 2023 16:11:58 +0200 Subject: [PATCH 15/35] add missing err check --- mongodb/dbClient.go | 3 +++ mongodb/dbClient_test.go | 5 ----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 8c6d4125..e2a95a8f 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -175,6 +175,9 @@ func (mdc *mongodbClient) ReadWriteWithCheck( } err = mdc.Put(collID, key, retValueBytes) + if err != nil { + return err + } if err = session.CommitTransaction(ctx); err != nil { return err diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index 5e798f46..2026a847 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -12,11 +12,6 @@ import ( "go.mongodb.org/mongo-driver/mongo/integration/mtest" ) -type testStruct struct { - Key string `bson:"_id"` - Value []byte `bson:"value"` -} - func TestNewMongoDBClient(t *testing.T) { t.Parallel() From 5b336f3a3ce9f93c5d445871ff401cd8f440a6b6 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 14 Mar 2023 16:12:34 +0200 Subject: [PATCH 16/35] bump golang version in golangci-lint github workflow --- .github/workflows/golangci-lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 041020d4..aed5d099 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -16,7 +16,7 @@ jobs: steps: - uses: actions/setup-go@v3 with: - go-version: 1.17.6 + go-version: 1.18.10 - uses: actions/checkout@v3 - name: golangci-lint uses: golangci/golangci-lint-action@v3 From 8a28600d4251bd7c4ae523044e304fc31c6c7c81 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 14 Mar 2023 16:19:39 +0200 Subject: [PATCH 17/35] Revert "bump golang version in golangci-lint github workflow" This reverts commit 5b336f3a3ce9f93c5d445871ff401cd8f440a6b6. --- .github/workflows/golangci-lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index aed5d099..041020d4 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -16,7 +16,7 @@ jobs: steps: - uses: actions/setup-go@v3 with: - go-version: 1.18.10 + go-version: 1.17.6 - uses: actions/checkout@v3 - name: golangci-lint uses: golangci/golangci-lint-action@v3 From d028bbd7065e79d3dad64719d638062762f4de4f Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 14 Mar 2023 16:19:55 +0200 Subject: [PATCH 18/35] go mod tidy --- go.mod | 1 + 1 file changed, 1 insertion(+) diff --git a/go.mod b/go.mod index 3cb90cfb..64546496 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( github.com/goccy/go-json v0.10.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect + github.com/google/go-cmp v0.5.5 // indirect github.com/google/uuid v1.3.0 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/json-iterator/go v1.1.12 // indirect From 7506de3d0accbb71538e8e94de8ab5ec47f1c747 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 14 Mar 2023 16:33:40 +0200 Subject: [PATCH 19/35] revert changed to dbOTPHandler --- handlers/storage/dbOTPHandler.go | 65 +++++---------------------- handlers/storage/dbOTPHandler_test.go | 27 +---------- 2 files changed, 13 insertions(+), 79 deletions(-) diff --git a/handlers/storage/dbOTPHandler.go b/handlers/storage/dbOTPHandler.go index 9b92ac90..84800584 100644 --- a/handlers/storage/dbOTPHandler.go +++ b/handlers/storage/dbOTPHandler.go @@ -17,14 +17,14 @@ const ( // ArgDBOTPHandler is the DTO used to create a new instance of dbOTPHandler type ArgDBOTPHandler struct { - DB core.StorageWithIndexChecker + DB core.StorageWithIndex TOTPHandler handlers.TOTPHandler Marshaller core.Marshaller DelayBetweenOTPUpdatesInSec int64 } type dbOTPHandler struct { - db core.StorageWithIndexChecker + db core.StorageWithIndex totpHandler handlers.TOTPHandler marshaller core.Marshaller getTimeHandler func() time.Time @@ -85,31 +85,19 @@ func (handler *dbOTPHandler) Save(account, guardian []byte, otp handlers.OTP) er return handler.saveNewOTP(key, otp) } - checker := func(data interface{}) (interface{}, error) { - otpInfoBytes, ok := data.([]byte) - if !ok { - return nil, core.ErrInvalidValue - } - - err := handler.checkOtpUpdateAllowed(otpInfoBytes) - if err != nil { - return nil, err - } - - buff, err := handler.getMarshalledOtpData(otp) - if err != nil { - return nil, err - } - - return buff, nil - } - - err = handler.db.UpdateWithCheck(key, checker) + oldOTPInfo, err := handler.getOldOTPInfo(key) if err != nil { return err } - return nil + currentTimestamp := handler.getTimeHandler().Unix() + isOTPUpdateAllowed := oldOTPInfo.LastTOTPChangeTimestamp+handler.delayBetweenOTPUpdatesInSec < currentTimestamp + if !isOTPUpdateAllowed { + return fmt.Errorf("%w, last update was %d seconds ago", + handlers.ErrRegistrationFailed, currentTimestamp-oldOTPInfo.LastTOTPChangeTimestamp) + } + + return handler.saveNewOTP(key, otp) } // Get returns the one time password @@ -138,37 +126,6 @@ func (handler *dbOTPHandler) getOldOTPInfo(key []byte) (*core.OTPInfo, error) { return otpInfo, nil } -func (handler *dbOTPHandler) getMarshalledOtpData(otp handlers.OTP) ([]byte, error) { - newOtpInfo := &core.OTPInfo{ - LastTOTPChangeTimestamp: handler.getTimeHandler().Unix(), - } - - var err error - newOtpInfo.OTP, err = otp.ToBytes() - if err != nil { - return nil, err - } - - return handler.marshaller.Marshal(newOtpInfo) -} - -func (handler *dbOTPHandler) checkOtpUpdateAllowed(otpInfoBytes []byte) error { - otpInfo := &core.OTPInfo{} - err := handler.marshaller.Unmarshal(otpInfo, otpInfoBytes) - if err != nil { - return err - } - - currentTimestamp := handler.getTimeHandler().Unix() - isOTPUpdateAllowed := otpInfo.LastTOTPChangeTimestamp+handler.delayBetweenOTPUpdatesInSec < currentTimestamp - if !isOTPUpdateAllowed { - return fmt.Errorf("%w, last update was %d seconds ago", - handlers.ErrRegistrationFailed, currentTimestamp-otpInfo.LastTOTPChangeTimestamp) - } - - return nil -} - func (handler *dbOTPHandler) saveNewOTP(key []byte, otp handlers.OTP) error { otpInfo := &core.OTPInfo{ LastTOTPChangeTimestamp: handler.getTimeHandler().Unix(), diff --git a/handlers/storage/dbOTPHandler_test.go b/handlers/storage/dbOTPHandler_test.go index bbcab4bc..48c56cdc 100644 --- a/handlers/storage/dbOTPHandler_test.go +++ b/handlers/storage/dbOTPHandler_test.go @@ -21,7 +21,7 @@ var expectedErr = errors.New("expected error") func createMockArgs() storage.ArgDBOTPHandler { return storage.ArgDBOTPHandler{ - DB: testscommon.NewShardedStorageWithIndexMock(), + DB: testscommon.NewShardedStorageWithIndexMock(), TOTPHandler: &testscommon.TOTPHandlerStub{}, Marshaller: &testscommon.MarshallerStub{}, DelayBetweenOTPUpdatesInSec: 5, @@ -201,10 +201,6 @@ func TestDBOTPHandler_Save(t *testing.T) { args := createMockArgs() args.DB = &testscommon.ShardedStorageWithIndexStub{ - UpdateWithCheckCalled: func(key []byte, fn func(data interface{}) (interface{}, error)) error { - _, err := args.DB.Get(key) - return err - }, GetCalled: func(key []byte) ([]byte, error) { return nil, expectedErr }, @@ -221,12 +217,7 @@ func TestDBOTPHandler_Save(t *testing.T) { t.Parallel() args := createMockArgs() - args.DB = &testscommon.ShardedStorageWithIndexStub{ - UpdateWithCheckCalled: func(key []byte, fn func(data interface{}) (interface{}, error)) error { - _, err := fn([]byte("badEncodedData")) - return err - }, - } + args.DB = &testscommon.ShardedStorageWithIndexStub{} args.Marshaller = &testscommon.MarshallerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { return expectedErr @@ -356,20 +347,6 @@ func TestDBOTPHandler_Save(t *testing.T) { GetCalled: func(key []byte) ([]byte, error) { return mockDB.Get(key) }, - UpdateWithCheckCalled: func(key []byte, fn func(data interface{}) (interface{}, error)) error { - data, err := mockDB.Get(key) - if err != nil { - return err - } - - newData, err := fn(data) - if err != nil { - return err - } - - atomic.AddUint32(&putCounter, 1) - return mockDB.Put(key, newData.([]byte)) - }, } args.Marshaller = &mock.MarshalizerMock{} handler, err := storage.NewDBOTPHandler(args) From 18d09781fab0348d72e6b1fd625d9f48e68f107f Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 14 Mar 2023 16:33:51 +0200 Subject: [PATCH 20/35] fix db client unit test --- mongodb/dbClient_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index 2026a847..c3fa005f 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -297,6 +297,8 @@ func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { {Key: "_id", Value: "key1"}, {Key: "value", Value: []byte("data1")}, }), + mtest.CreateSuccessResponse(), + mtest.CreateSuccessResponse(), ) checker := func(data interface{}) (interface{}, error) { From a184d5457b4c18aa026b826a31623f140a8471b0 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Wed, 15 Mar 2023 14:19:07 +0200 Subject: [PATCH 21/35] fixes after review: handlers package - renamings, adding mutex protection --- core/interface.go | 1 + handlers/interface.go | 4 ++-- handlers/storage/bucket/bucketIndexHandler.go | 21 ++++++++++++------- .../storage/bucket/mongodbIndexHandler.go | 3 ++- .../storage/factory/shardedStorageFactory.go | 19 ++++++++--------- .../factory/shardedStorageFactory_test.go | 8 +++---- 6 files changed, 31 insertions(+), 25 deletions(-) diff --git a/core/interface.go b/core/interface.go index 9a7365c6..40944138 100644 --- a/core/interface.go +++ b/core/interface.go @@ -93,6 +93,7 @@ type StorageWithIndex interface { IsInterfaceNil() bool } +// StorageWithIndexChecker defines the methods for storage with check operations type StorageWithIndexChecker interface { StorageWithIndex UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error diff --git a/handlers/interface.go b/handlers/interface.go index faf84c16..79a84529 100644 --- a/handlers/interface.go +++ b/handlers/interface.go @@ -28,8 +28,8 @@ type OTP interface { ToBytes() ([]byte, error) } -// ShardedStorageFactory defines the methods available for a sharded storage factory -type ShardedStorageFactory interface { +// StorageWithIndexFactory defines the methods available for a sharded storage factory +type StorageWithIndexFactory interface { Create() (core.StorageWithIndex, error) IsInterfaceNil() bool } diff --git a/handlers/storage/bucket/bucketIndexHandler.go b/handlers/storage/bucket/bucketIndexHandler.go index 2bccac2a..5af2b2f6 100644 --- a/handlers/storage/bucket/bucketIndexHandler.go +++ b/handlers/storage/bucket/bucketIndexHandler.go @@ -14,8 +14,9 @@ const ( ) type bucketIndexHandler struct { - bucket core.Storer - mut sync.RWMutex + bucket core.Storer + bucketMut sync.RWMutex + bucketOpMut sync.RWMutex } // NewBucketIndexHandler returns a new instance of a bucket index handler @@ -43,8 +44,8 @@ func NewBucketIndexHandler(bucket core.Storer) (*bucketIndexHandler, error) { // AllocateBucketIndex allocates a new index and returns it func (handler *bucketIndexHandler) AllocateBucketIndex() (uint32, error) { - handler.mut.Lock() - defer handler.mut.Unlock() + handler.bucketMut.Lock() + defer handler.bucketMut.Unlock() index, err := getIndex(handler.bucket) if err != nil { @@ -71,7 +72,11 @@ func (handler *bucketIndexHandler) Has(key []byte) error { return handler.bucket.Has(key) } +// UpdateWithCheck will update key value pair based on callback function func (handler *bucketIndexHandler) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { + handler.bucketOpMut.Lock() + defer handler.bucketOpMut.Unlock() + data, err := handler.bucket.Get(key) if err != nil { return nil @@ -91,16 +96,16 @@ func (handler *bucketIndexHandler) UpdateWithCheck(key []byte, fn func(data inte // GetLastIndex returns the last index that was allocated func (handler *bucketIndexHandler) GetLastIndex() (uint32, error) { - handler.mut.RLock() - defer handler.mut.RUnlock() + handler.bucketMut.RLock() + defer handler.bucketMut.RUnlock() return getIndex(handler.bucket) } // Close closes the internal bucket func (handler *bucketIndexHandler) Close() error { - handler.mut.Lock() - defer handler.mut.Unlock() + handler.bucketMut.Lock() + defer handler.bucketMut.Unlock() return handler.bucket.Close() } diff --git a/handlers/storage/bucket/mongodbIndexHandler.go b/handlers/storage/bucket/mongodbIndexHandler.go index 783aaca3..424229ce 100644 --- a/handlers/storage/bucket/mongodbIndexHandler.go +++ b/handlers/storage/bucket/mongodbIndexHandler.go @@ -16,7 +16,7 @@ type mongodbIndexHandler struct { mut sync.RWMutex } -// NewMongoDBIndexHandler returns a new instance of a bucket index handler +// NewMongoDBIndexHandler returns a new instance of a mongo db index handler func NewMongoDBIndexHandler(storer core.Storer, mongoClient mongodb.MongoDBClient) (*mongodbIndexHandler, error) { if check.IfNil(storer) { return nil, core.ErrNilStorer @@ -71,6 +71,7 @@ func (handler *mongodbIndexHandler) Has(key []byte) error { return handler.storer.Has(key) } +// UpdateWithCheck will update key value pair based on callback function func (handler *mongodbIndexHandler) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { return handler.mongodbClient.ReadWriteWithCheck(mongodb.UsersCollectionID, key, fn) } diff --git a/handlers/storage/factory/shardedStorageFactory.go b/handlers/storage/factory/shardedStorageFactory.go index f2a4a89d..78345d6f 100644 --- a/handlers/storage/factory/shardedStorageFactory.go +++ b/handlers/storage/factory/shardedStorageFactory.go @@ -12,31 +12,30 @@ import ( "github.com/multiversx/mx-chain-storage-go/storageUnit" ) -type shardedStorageFactory struct { +type storageWithIndexFactory struct { cfg config.Config } -// NewShardedStorageFactory returns a new instance of shardedStorageFactory -func NewShardedStorageFactory(config config.Config) *shardedStorageFactory { - return &shardedStorageFactory{ +// NewStorageWithIndexFactory returns a new instance of shardedStorageFactory +func NewStorageWithIndexFactory(config config.Config) *storageWithIndexFactory { + return &storageWithIndexFactory{ cfg: config, } } -// Create returns a new instance of ShardedStorageWithIndex -func (ssf *shardedStorageFactory) Create() (core.StorageWithIndexChecker, error) { +// Create returns a new instance of storage with index +func (ssf *storageWithIndexFactory) Create() (core.StorageWithIndexChecker, error) { switch ssf.cfg.ShardedStorage.DBType { case core.LevelDB: return ssf.createLocalDB() case core.MongoDB: return ssf.createMongoDB() default: - // TODO: implement other types of storage return nil, handlers.ErrInvalidConfig } } -func (ssf *shardedStorageFactory) createMongoDB() (core.StorageWithIndexChecker, error) { +func (ssf *storageWithIndexFactory) createMongoDB() (core.StorageWithIndexChecker, error) { client, err := mongodb.CreateMongoDBClient(ssf.cfg.MongoDB) if err != nil { return nil, err @@ -66,7 +65,7 @@ func (ssf *shardedStorageFactory) createMongoDB() (core.StorageWithIndexChecker, return bucket.NewShardedStorageWithIndex(argsShardedStorageWithIndex) } -func (ssf *shardedStorageFactory) createLocalDB() (core.StorageWithIndexChecker, error) { +func (ssf *storageWithIndexFactory) createLocalDB() (core.StorageWithIndexChecker, error) { numbOfBuckets := ssf.cfg.Buckets.NumberOfBuckets bucketIDProvider, err := bucket.NewBucketIDProvider(numbOfBuckets) if err != nil { @@ -102,6 +101,6 @@ func (ssf *shardedStorageFactory) createLocalDB() (core.StorageWithIndexChecker, } // IsInterfaceNil returns true if there is no value under the interface -func (ssf *shardedStorageFactory) IsInterfaceNil() bool { +func (ssf *storageWithIndexFactory) IsInterfaceNil() bool { return ssf == nil } diff --git a/handlers/storage/factory/shardedStorageFactory_test.go b/handlers/storage/factory/shardedStorageFactory_test.go index aa9c5ad2..3e14fff1 100644 --- a/handlers/storage/factory/shardedStorageFactory_test.go +++ b/handlers/storage/factory/shardedStorageFactory_test.go @@ -24,7 +24,7 @@ func TestNewShardedStorageFactory_Create(t *testing.T) { DBType: "dummy", }, } - ssf := NewShardedStorageFactory(cfg) + ssf := NewStorageWithIndexFactory(cfg) assert.False(t, check.IfNil(ssf)) shardedStorageInstance, err := ssf.Create() assert.Equal(t, handlers.ErrInvalidConfig, err) @@ -41,7 +41,7 @@ func TestNewShardedStorageFactory_Create(t *testing.T) { NumberOfBuckets: 0, }, } - ssf := NewShardedStorageFactory(cfg) + ssf := NewStorageWithIndexFactory(cfg) assert.False(t, check.IfNil(ssf)) shardedStorageInstance, err := ssf.Create() assert.NotNil(t, err) @@ -66,7 +66,7 @@ func TestNewShardedStorageFactory_Create(t *testing.T) { NumberOfBuckets: 1, }, } - ssf := NewShardedStorageFactory(cfg) + ssf := NewStorageWithIndexFactory(cfg) assert.False(t, check.IfNil(ssf)) shardedStorageInstance, err := ssf.Create() assert.NotNil(t, err) @@ -98,7 +98,7 @@ func TestNewShardedStorageFactory_Create(t *testing.T) { NumberOfBuckets: 4, }, } - ssf := NewShardedStorageFactory(cfg) + ssf := NewStorageWithIndexFactory(cfg) assert.False(t, check.IfNil(ssf)) shardedStorageInstance, err := ssf.Create() assert.Nil(t, err) From 68933e1559edd4578066b53b7db1ce1342ddf674 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Thu, 16 Mar 2023 13:56:18 +0200 Subject: [PATCH 22/35] refactor mongo operations with session and transaction --- mongodb/dbClient.go | 96 ++++++++++++++++----------------- mongodb/dbClient_test.go | 64 +++++++++++++++++----- mongodb/mongoDBClientFactory.go | 12 +++++ 3 files changed, 110 insertions(+), 62 deletions(-) diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index e2a95a8f..ef4eab60 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -22,14 +22,21 @@ const ( UsersCollectionID CollectionID = "users" ) +const initialCounterValue = 1 + type mongoEntry struct { Key string `bson:"_id"` Value []byte `bson:"value"` } +type counterMongoEntry struct { + Key string `bson:"_id"` + Value uint32 `bson:"value"` +} + type mongodbClient struct { client *mongo.Client - collections map[CollectionID]MongoDBCollection + collections map[CollectionID]*mongo.Collection ctx context.Context } @@ -49,7 +56,7 @@ func NewClient(client *mongo.Client, dbName string) (*mongodbClient, error) { return nil, err } - collections := make(map[CollectionID]MongoDBCollection) + collections := make(map[CollectionID]*mongo.Collection) collections[UsersCollectionID] = client.Database(dbName).Collection(string(UsersCollectionID)) return &mongodbClient{ @@ -160,12 +167,20 @@ func (mdc *mongodbClient) ReadWriteWithCheck( return err } - value, err := mdc.Get(collID, key) + coll, ok := mdc.collections[collID] + if !ok { + return ErrCollectionNotFound + } + + filter := bson.D{{Key: "_id", Value: string(key)}} + + entry := &mongoEntry{} + err = coll.FindOne(mdc.ctx, filter).Decode(entry) if err != nil { return err } - retValue, err := checker(value) + retValue, err := checker(entry.Value) if err != nil { return err } @@ -174,16 +189,22 @@ func (mdc *mongodbClient) ReadWriteWithCheck( return core.ErrInvalidValue } - err = mdc.Put(collID, key, retValueBytes) - if err != nil { - return err - } + filter = bson.D{{Key: "_id", Value: string(key)}} + update := bson.D{{Key: "$set", + Value: bson.D{ + {Key: "_id", Value: string(key)}, + {Key: "value", Value: retValueBytes}, + }, + }} + + opts := options.Update().SetUpsert(true) - if err = session.CommitTransaction(ctx); err != nil { + _, err = coll.UpdateOne(mdc.ctx, filter, update, opts) + if err != nil { return err } - return nil + return session.CommitTransaction(ctx) } err = mongo.WithSession(mdc.ctx, session, sessionCallback) @@ -204,51 +225,28 @@ func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []by return 0, ErrCollectionNotFound } - callback := func(sessCtx mongo.SessionContext) (interface{}, error) { - filter := bson.D{{Key: "_id", Value: string(key)}} - - entry := &mongoEntry{} - err := coll.FindOne(sessCtx, filter).Decode(entry) - if err != nil { - return nil, err - } + opts := options.FindOneAndUpdate().SetUpsert(true) - latestIndexBytes, newIndex := incrementIntegerFromBytes(entry.Value) - - filter = bson.D{{Key: "_id", Value: string(key)}} - update := bson.D{{Key: "$set", - Value: bson.D{ - {Key: "_id", Value: string(key)}, - {Key: "value", Value: latestIndexBytes}, - }, - }} - - opts := options.Update().SetUpsert(true) - - _, err = coll.UpdateOne(sessCtx, filter, update, opts) - if err != nil { - return nil, err - } - - return newIndex, nil - } + filter := bson.D{{Key: "_id", Value: string(key)}} + update := bson.D{{ + Key: "$inc", + Value: bson.D{ + {Key: "value", Value: uint32(1)}, + }, + }} - session, err := mdc.client.StartSession() + entry := &counterMongoEntry{} + res := coll.FindOneAndUpdate(mdc.ctx, filter, update, opts) + err := res.Decode(entry) if err != nil { - return 0, err - } - defer session.EndSession(mdc.ctx) + if err == mongo.ErrNoDocuments { + return initialCounterValue, nil + } - newIndex, err := session.WithTransaction(mdc.ctx, callback) - if err != nil { - return 0, err - } - index, ok := newIndex.(uint32) - if !ok { - return 0, core.ErrInvalidValue + return initialCounterValue, err } - return index, nil + return entry.Value, nil } func incrementIntegerFromBytes(value []byte) ([]byte, uint32) { diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index c3fa005f..8b7e1a3e 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -2,11 +2,13 @@ package mongodb_test import ( "bytes" - "encoding/binary" "errors" + "sync" "testing" + "github.com/multiversx/multi-factor-auth-go-service/config" "github.com/multiversx/multi-factor-auth-go-service/mongodb" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/integration/mtest" @@ -235,25 +237,19 @@ func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { mt.Run("should work", func(mt *mtest.T) { mt.Parallel() - client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(mt, err) - - newIndexBytes := make([]byte, 4) - binary.BigEndian.PutUint32(newIndexBytes, 3) - mt.AddMockResponses( mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ - {Key: "_id", Value: "key2"}, - {Key: "value", Value: newIndexBytes}, + {Key: "_id", Value: "key1"}, + {Key: "value", Value: 1}, }), - mtest.CreateSuccessResponse(), - mtest.CreateSuccessResponse(), - mtest.CreateSuccessResponse(), ) + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(mt, err) + val, err := client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) require.Nil(mt, err) - require.Equal(mt, uint32(4), val) + require.Equal(mt, uint32(1), val) }) } @@ -312,3 +308,45 @@ func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { require.Nil(mt, err) }) } + +func TestMongoDBClient_ConcurrentCalls(t *testing.T) { + t.Parallel() + + client, err := mongodb.CreateMongoDBClient(config.MongoDBConfig{ + URI: "mongodb://127.0.0.1:27017/?directConnection=true&serverSelectionTimeoutMS=2000", + DBName: "main", + }) + require.Nil(t, err) + + // checker := func(data interface{}) (interface{}, error) { + // return []byte("newData"), nil + // } + + numCalls := 100 + + var wg sync.WaitGroup + wg.Add(numCalls) + for i := 0; i < numCalls; i++ { + go func(idx int) { + switch idx % 5 { + case 0: + require.Nil(t, client.Put(mongodb.UsersCollectionID, []byte("key4"), []byte("data"))) + case 1: + _, err := client.Get(mongodb.UsersCollectionID, []byte("key")) + require.Nil(t, err) + case 2: + require.Nil(t, client.Has(mongodb.UsersCollectionID, []byte("key"))) + case 3: + // err = client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key4"), checker) + // require.Nil(t, err) + case 4: + _, err := client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) + require.Nil(t, err) + default: + assert.Fail(t, "should not hit default") + } + wg.Done() + }(i) + } + wg.Wait() +} diff --git a/mongodb/mongoDBClientFactory.go b/mongodb/mongoDBClientFactory.go index 3b12a8a1..554c05fb 100644 --- a/mongodb/mongoDBClientFactory.go +++ b/mongodb/mongoDBClientFactory.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/multi-factor-auth-go-service/config" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/mongo/writeconcern" ) const ( @@ -20,6 +22,16 @@ func CreateMongoDBClient(cfg config.MongoDBConfig) (MongoDBClient, error) { opts.SetTimeout(operationTimeoutSec * time.Second) opts.ApplyURI(cfg.URI) + writeConcern := writeconcern.New(writeconcern.WMajority()) + opts.SetWriteConcern(writeConcern) + + readPref, err := readpref.New(readpref.SecondaryPreferredMode) + if err != nil { + return nil, err + } + + opts.SetReadPreference(readPref) + client, err := mongo.NewClient(opts) if err != nil { return nil, err From bf57f7c9c8a223701028676473d80206c2ba1492 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Mon, 20 Mar 2023 20:07:48 +0200 Subject: [PATCH 23/35] added mongo operations for otp db handler --- mongodb/dbClient.go | 144 +++++++++++++++++++++++++++++--- mongodb/dbClient_test.go | 68 +++++++++------ mongodb/interface.go | 10 +++ mongodb/mongoDBClientFactory.go | 2 +- 4 files changed, 183 insertions(+), 41 deletions(-) diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index ef4eab60..3ec389cd 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -3,12 +3,14 @@ package mongodb import ( "context" "encoding/binary" + "time" "github.com/multiversx/multi-factor-auth-go-service/core" logger "github.com/multiversx/mx-chain-logger-go" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" ) @@ -34,6 +36,11 @@ type counterMongoEntry struct { Value uint32 `bson:"value"` } +type otpInfoWrapper struct { + Key string `bson:"_id"` + *core.OTPInfo +} + type mongodbClient struct { client *mongo.Client collections map[CollectionID]*mongo.Collection @@ -93,6 +100,34 @@ func (mdc *mongodbClient) Put(collID CollectionID, key []byte, data []byte) erro return nil } +func (mdc *mongodbClient) PutStruct(collID CollectionID, key []byte, data *core.OTPInfo) error { + coll, ok := mdc.collections[collID] + if !ok { + return ErrCollectionNotFound + } + + otpInfo := &otpInfoWrapper{ + Key: string(key), + OTPInfo: data, + } + + filter := bson.M{"_id": string(key)} + update := bson.M{ + "$set": otpInfo, + } + + opts := options.Update().SetUpsert(true) + + log.Trace("PutStruct", "key", string(key), "value", data.LastTOTPChangeTimestamp) + + _, err := coll.UpdateOne(mdc.ctx, filter, update, opts) + if err != nil { + return err + } + + return nil +} + func (mdc *mongodbClient) findOne(collID CollectionID, key []byte) (*mongoEntry, error) { coll, ok := mdc.collections[collID] if !ok { @@ -122,6 +157,23 @@ func (mdc *mongodbClient) Get(collID CollectionID, key []byte) ([]byte, error) { return entry.Value, nil } +func (mdc *mongodbClient) GetStruct(collID CollectionID, key []byte) (*core.OTPInfo, error) { + coll, ok := mdc.collections[collID] + if !ok { + return nil, ErrCollectionNotFound + } + + filter := bson.D{{Key: "_id", Value: string(key)}} + + entry := &otpInfoWrapper{} + err := coll.FindOne(mdc.ctx, filter).Decode(entry) + if err != nil { + return nil, err + } + + return entry.OTPInfo, nil +} + // Has will return true if the provided key exists in the collection func (mdc *mongodbClient) Has(collID CollectionID, key []byte) error { _, err := mdc.findOne(collID, key) @@ -129,6 +181,23 @@ func (mdc *mongodbClient) Has(collID CollectionID, key []byte) error { return err } +func (mdc *mongodbClient) HasStruct(collID CollectionID, key []byte) error { + coll, ok := mdc.collections[collID] + if !ok { + return ErrCollectionNotFound + } + + filter := bson.D{{Key: "_id", Value: string(key)}} + + entry := &otpInfoWrapper{} + err := coll.FindOne(mdc.ctx, filter).Decode(entry) + if err != nil { + return err + } + + return err +} + // Remove will remove the provided key from the collection func (mdc *mongodbClient) Remove(collID CollectionID, key []byte) error { coll, ok := mdc.collections[collID] @@ -146,6 +215,38 @@ func (mdc *mongodbClient) Remove(collID CollectionID, key []byte) error { return nil } +func (mdc *mongodbClient) UpdateTimestamp(collID CollectionID, key []byte, interval int64) (int64, error) { + coll, ok := mdc.collections[collID] + if !ok { + return 0, ErrCollectionNotFound + } + + opts := options.FindOneAndUpdate().SetUpsert(false) + + currentTimestamp := time.Now().Unix() + compareValue := currentTimestamp - interval + + filter := bson.M{"_id": string(key), "otpinfo.lasttotpchangetimestamp": bson.M{"$lt": compareValue}} + update := bson.M{ + "$set": bson.M{ + "otpinfo.lasttotpchangetimestamp": time.Now().Unix(), + }, + } + + entry := &core.OTPInfo{} + res := coll.FindOneAndUpdate(mdc.ctx, filter, update, opts) + err := res.Decode(entry) + if err != nil { + if err == mongo.ErrNoDocuments { + return currentTimestamp, nil + } + + return currentTimestamp, err + } + + return entry.LastTOTPChangeTimestamp, nil +} + // ReadWriteWithCheck will perform read and write operation with a provided checker func (mdc *mongodbClient) ReadWriteWithCheck( collID CollectionID, @@ -160,6 +261,7 @@ func (mdc *mongodbClient) ReadWriteWithCheck( wc := writeconcern.New(writeconcern.WMajority()) txnOptions := options.Transaction().SetWriteConcern(wc) + txnOptions.SetReadPreference(readpref.Primary()) sessionCallback := func(ctx mongo.SessionContext) error { err := session.StartTransaction(txnOptions) @@ -172,30 +274,35 @@ func (mdc *mongodbClient) ReadWriteWithCheck( return ErrCollectionNotFound } - filter := bson.D{{Key: "_id", Value: string(key)}} + filter := bson.M{"_id": string(key)} - entry := &mongoEntry{} - err = coll.FindOne(mdc.ctx, filter).Decode(entry) + entry := &otpInfoWrapper{} + err = coll.FindOne(ctx, filter).Decode(entry) if err != nil { return err } - retValue, err := checker(entry.Value) + checker = func(data interface{}) (interface{}, error) { + return &core.OTPInfo{}, nil + } + + retValue, err := checker(entry.OTPInfo) if err != nil { return err } - retValueBytes, ok := retValue.([]byte) + retValueBytes, ok := retValue.(*core.OTPInfo) if !ok { return core.ErrInvalidValue } - filter = bson.D{{Key: "_id", Value: string(key)}} - update := bson.D{{Key: "$set", - Value: bson.D{ - {Key: "_id", Value: string(key)}, - {Key: "value", Value: retValueBytes}, - }, - }} + otpInfo := &otpInfoWrapper{ + Key: string(key), + OTPInfo: retValueBytes, + } + + update := bson.M{ + "$set": otpInfo, + } opts := options.Update().SetUpsert(true) @@ -209,9 +316,20 @@ func (mdc *mongodbClient) ReadWriteWithCheck( err = mongo.WithSession(mdc.ctx, session, sessionCallback) if err != nil { - if err := session.AbortTransaction(mdc.ctx); err != nil { + abortErr := session.AbortTransaction(mdc.ctx) + if abortErr != nil { + return abortErr + } + + cmdErr, ok := err.(mongo.CommandError) + if !ok { return err } + if cmdErr.HasErrorLabel("TransientTransactionError") { + log.Error(err.Error()) + return cmdErr + } + return err } diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index 8b7e1a3e..e81e791e 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -1,12 +1,12 @@ package mongodb_test import ( - "bytes" "errors" "sync" "testing" "github.com/multiversx/multi-factor-auth-go-service/config" + "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/multi-factor-auth-go-service/mongodb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -59,10 +59,12 @@ func TestMongoDBClient_Put(t *testing.T) { mt.Parallel() expectedErr := errors.New("expected error") - mt.AddMockResponses(mtest.CreateCommandErrorResponse(mtest.CommandError{ - Code: 1, - Message: expectedErr.Error(), - })) + mt.AddMockResponses( + mtest.CreateCommandErrorResponse(mtest.CommandError{ + Code: 1, + Message: expectedErr.Error(), + }), + ) client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(mt, err) @@ -78,7 +80,8 @@ func TestMongoDBClient_Put(t *testing.T) { mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ {Key: "_id", Value: "key2"}, {Key: "value", Value: []byte("value")}, - })) + }), + ) client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(mt, err) @@ -108,10 +111,12 @@ func TestMongoDBClient_Get(t *testing.T) { mt.Parallel() expectedErr := errors.New("expected err") - mt.AddMockResponses(mtest.CreateCommandErrorResponse(mtest.CommandError{ - Code: 1, - Message: expectedErr.Error(), - })) + mt.AddMockResponses( + mtest.CreateCommandErrorResponse(mtest.CommandError{ + Code: 1, + Message: expectedErr.Error(), + }), + ) client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(mt, err) @@ -127,7 +132,8 @@ func TestMongoDBClient_Get(t *testing.T) { mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ {Key: "_id", Value: "key2"}, {Key: "value", Value: []byte("value")}, - })) + }), + ) client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(mt, err) @@ -147,10 +153,12 @@ func TestMongoDBClient_Has(t *testing.T) { mt.Parallel() expectedErr := errors.New("expected err") - mt.AddMockResponses(mtest.CreateCommandErrorResponse(mtest.CommandError{ - Code: 1, - Message: expectedErr.Error(), - })) + mt.AddMockResponses( + mtest.CreateCommandErrorResponse(mtest.CommandError{ + Code: 1, + Message: expectedErr.Error(), + }), + ) client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(mt, err) @@ -166,7 +174,8 @@ func TestMongoDBClient_Has(t *testing.T) { mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ {Key: "_id", Value: "key2"}, {Key: "value", Value: []byte("value")}, - })) + }), + ) client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(mt, err) @@ -291,15 +300,16 @@ func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { mt.AddMockResponses( mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ {Key: "_id", Value: "key1"}, - {Key: "value", Value: []byte("data1")}, + {Key: "otpinfo", Value: &core.OTPInfo{LastTOTPChangeTimestamp: 101}}, }), mtest.CreateSuccessResponse(), mtest.CreateSuccessResponse(), + mtest.CreateSuccessResponse(), ) checker := func(data interface{}) (interface{}, error) { - if bytes.Equal(data.([]byte), []byte("data1")) { - return []byte("data2"), nil + if data.(*core.OTPInfo).LastTOTPChangeTimestamp == 101 { + return &core.OTPInfo{}, nil } return nil, errors.New("error") } @@ -319,10 +329,10 @@ func TestMongoDBClient_ConcurrentCalls(t *testing.T) { require.Nil(t, err) // checker := func(data interface{}) (interface{}, error) { - // return []byte("newData"), nil + // return &core.OTPInfo{}, nil // } - numCalls := 100 + numCalls := 600 var wg sync.WaitGroup wg.Add(numCalls) @@ -330,18 +340,22 @@ func TestMongoDBClient_ConcurrentCalls(t *testing.T) { go func(idx int) { switch idx % 5 { case 0: - require.Nil(t, client.Put(mongodb.UsersCollectionID, []byte("key4"), []byte("data"))) + //require.Nil(t, client.Put(mongodb.UsersCollectionID, []byte("key4"), []byte("data"))) + err := client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) + require.Nil(t, err) case 1: - _, err := client.Get(mongodb.UsersCollectionID, []byte("key")) + _, err := client.GetStruct(mongodb.UsersCollectionID, []byte("key")) require.Nil(t, err) case 2: - require.Nil(t, client.Has(mongodb.UsersCollectionID, []byte("key"))) + require.Nil(t, client.HasStruct(mongodb.UsersCollectionID, []byte("key"))) case 3: - // err = client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key4"), checker) - // require.Nil(t, err) + err := client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key"), nil) + require.Nil(t, err) case 4: - _, err := client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) + _, err := client.UpdateTimestamp(mongodb.UsersCollectionID, []byte("key"), 0) require.Nil(t, err) + // _, err := client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key")) + // require.Nil(t, err) default: assert.Fail(t, "should not hit default") } diff --git a/mongodb/interface.go b/mongodb/interface.go index bda97387..a1203e2b 100644 --- a/mongodb/interface.go +++ b/mongodb/interface.go @@ -3,6 +3,7 @@ package mongodb import ( "context" + "github.com/multiversx/multi-factor-auth-go-service/core" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) @@ -39,6 +40,15 @@ type MongoDBClient interface { IsInterfaceNil() bool } +// MongoDBUsersHandler defines the behaviour of a mongo users handler component +type MongoDBUsersHandler interface { + MongoDBClient + UpdateTimestamp(collID CollectionID, key []byte, interval int64) (int64, error) + PutStruct(collID CollectionID, key []byte, data *core.OTPInfo) error + GetStruct(collID CollectionID, key []byte) (*core.OTPInfo, error) + HasStruct(collID CollectionID, key []byte) error +} + // MongoDBSession defines what a mongodb session should do type MongoDBSession interface { StartTransaction(...*options.TransactionOptions) error diff --git a/mongodb/mongoDBClientFactory.go b/mongodb/mongoDBClientFactory.go index 554c05fb..11cf8b80 100644 --- a/mongodb/mongoDBClientFactory.go +++ b/mongodb/mongoDBClientFactory.go @@ -16,7 +16,7 @@ const ( ) // CreateMongoDBClient will create a new mongo db client instance -func CreateMongoDBClient(cfg config.MongoDBConfig) (MongoDBClient, error) { +func CreateMongoDBClient(cfg config.MongoDBConfig) (MongoDBUsersHandler, error) { opts := options.Client() opts.SetConnectTimeout(connectTimeoutSec * time.Second) opts.SetTimeout(operationTimeoutSec * time.Second) From 42589f242e25a5034a3fc722dce7275aec90b9bf Mon Sep 17 00:00:00 2001 From: ssd04 Date: Mon, 20 Mar 2023 20:24:33 +0200 Subject: [PATCH 24/35] add run tx with retry --- mongodb/dbClient.go | 41 ++++++++++++++++++++++++---------------- mongodb/dbClient_test.go | 9 ++++----- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 3ec389cd..2d1e67f0 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -279,19 +279,18 @@ func (mdc *mongodbClient) ReadWriteWithCheck( entry := &otpInfoWrapper{} err = coll.FindOne(ctx, filter).Decode(entry) if err != nil { + session.AbortTransaction(ctx) return err } - checker = func(data interface{}) (interface{}, error) { - return &core.OTPInfo{}, nil - } - retValue, err := checker(entry.OTPInfo) if err != nil { + session.AbortTransaction(ctx) return err } retValueBytes, ok := retValue.(*core.OTPInfo) if !ok { + session.AbortTransaction(ctx) return core.ErrInvalidValue } @@ -308,32 +307,42 @@ func (mdc *mongodbClient) ReadWriteWithCheck( _, err = coll.UpdateOne(mdc.ctx, filter, update, opts) if err != nil { + session.AbortTransaction(ctx) return err } return session.CommitTransaction(ctx) } - err = mongo.WithSession(mdc.ctx, session, sessionCallback) + err = mongo.WithSession(mdc.ctx, session, + func(sctx mongo.SessionContext) error { + return runTxWithRetry(sctx, sessionCallback) + }, + ) if err != nil { - abortErr := session.AbortTransaction(mdc.ctx) - if abortErr != nil { - return abortErr + return err + } + + return nil +} + +func runTxWithRetry(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) error) error { + for { + err := txnFn(sctx) + if err == nil { + return nil } + log.Trace("Transaction aborted. Caught exception during transaction.") + cmdErr, ok := err.(mongo.CommandError) - if !ok { - return err - } - if cmdErr.HasErrorLabel("TransientTransactionError") { - log.Error(err.Error()) - return cmdErr + if ok && cmdErr.HasErrorLabel("TransientTransactionError") { + log.Trace("TransientTransactionError, retrying transaction...") + continue } return err } - - return nil } // IncrementWithTransaction will increment the value for the provided key, within a transaction diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index e81e791e..2ab930ec 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -328,9 +328,9 @@ func TestMongoDBClient_ConcurrentCalls(t *testing.T) { }) require.Nil(t, err) - // checker := func(data interface{}) (interface{}, error) { - // return &core.OTPInfo{}, nil - // } + checker := func(data interface{}) (interface{}, error) { + return &core.OTPInfo{}, nil + } numCalls := 600 @@ -340,7 +340,6 @@ func TestMongoDBClient_ConcurrentCalls(t *testing.T) { go func(idx int) { switch idx % 5 { case 0: - //require.Nil(t, client.Put(mongodb.UsersCollectionID, []byte("key4"), []byte("data"))) err := client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) require.Nil(t, err) case 1: @@ -349,7 +348,7 @@ func TestMongoDBClient_ConcurrentCalls(t *testing.T) { case 2: require.Nil(t, client.HasStruct(mongodb.UsersCollectionID, []byte("key"))) case 3: - err := client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key"), nil) + err := client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key"), checker) require.Nil(t, err) case 4: _, err := client.UpdateTimestamp(mongodb.UsersCollectionID, []byte("key"), 0) From 80572d62adfbf619af70a553dfd078c9970585ac Mon Sep 17 00:00:00 2001 From: ssd04 Date: Mon, 20 Mar 2023 21:05:12 +0200 Subject: [PATCH 25/35] added integration test for mongo with in memory testing db --- go.mod | 5 +- go.sum | 13 +++++ mongodb/dbClient.go | 11 ----- mongodb/dbClient_test.go | 48 ------------------ mongodb/integrationTests/mongo_test.go | 67 ++++++++++++++++++++++++++ 5 files changed, 84 insertions(+), 60 deletions(-) create mode 100644 mongodb/integrationTests/mongo_test.go diff --git a/go.mod b/go.mod index 64546496..24b99bc1 100644 --- a/go.mod +++ b/go.mod @@ -17,11 +17,12 @@ require ( github.com/sec51/twofactor v1.0.0 github.com/stretchr/testify v1.8.1 github.com/urfave/cli v1.22.10 - go.mongodb.org/mongo-driver v1.11.2 + go.mongodb.org/mongo-driver v1.11.3 ) require ( filippo.io/edwards25519 v1.0.0 // indirect + github.com/acobaugh/osrelease v0.0.0-20181218015638-a93a0a55a249 // indirect github.com/beevik/ntp v0.3.0 // indirect github.com/btcsuite/btcd/btcutil v1.1.3 // indirect github.com/bytedance/sonic v1.8.0 // indirect @@ -62,7 +63,9 @@ require ( github.com/sec51/gf256 v0.0.0-20160126143050-2454accbeb9e // indirect github.com/sec51/qrcode v0.0.0-20160126144534-b7779abbcaf1 // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect + github.com/spf13/afero v1.6.0 // indirect github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect + github.com/tryvium-travels/memongo v0.9.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/tyler-smith/go-bip39 v1.1.0 // indirect github.com/ugorji/go/codec v1.2.9 // indirect diff --git a/go.sum b/go.sum index 5ef98c1a..f2fade9c 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek= filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/acobaugh/osrelease v0.0.0-20181218015638-a93a0a55a249 h1:fMi9ZZ/it4orHj3xWrM6cLkVFcCbkXQALFUiNtHtCPs= +github.com/acobaugh/osrelease v0.0.0-20181218015638-a93a0a55a249/go.mod h1:iU1PxQMQwoHZZWmMKrMkrNlY+3+p9vxIjpZOVyxWa0g= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/beevik/ntp v0.3.0 h1:xzVrPrE4ziasFXgBVBZJDP0Wg/KpMwk2KHJ4Ba8GrDw= github.com/beevik/ntp v0.3.0/go.mod h1:hIHWr+l3+/clUnF44zdK+CWW7fO8dR5cIylAQ76NRpg= @@ -114,6 +116,7 @@ github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47e github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.1.0 h1:eyi1Ad2aNJMW95zcSbmGg7Cg6cq3ADwLpMAP96d8rF0= github.com/klauspost/cpuid/v2 v2.1.0/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= @@ -175,8 +178,10 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= @@ -196,10 +201,13 @@ github.com/sec51/twofactor v1.0.0 h1:1BTbzPhyMyB0YvcWxgNxEkI7WDNsBLvR+z699YWGMC8 github.com/sec51/twofactor v1.0.0/go.mod h1:CjtKwpvQSs9SYzLUsRH7gML+TgKeIofT8uxoy7RTLQI= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/spf13/afero v1.6.0 h1:xoax2sJ2DT8S8xA2paPFjDCScCNeWsg75VG0DLRreiY= +github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -210,6 +218,8 @@ github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tryvium-travels/memongo v0.9.0 h1:k5kuTHSDdITN+aMaNoKr1nzTyNvhwOXtb9SnrMjRg+I= +github.com/tryvium-travels/memongo v0.9.0/go.mod h1:riRUHKRQ5JbeX2ryzFfmr7P2EYXIkNwgloSQJPpBikA= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/tyler-smith/go-bip39 v1.1.0 h1:5eUemwrMargf3BSLRRCalXT93Ns6pQJIjYQN2nyfOP8= @@ -232,10 +242,13 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.mongodb.org/mongo-driver v1.11.2 h1:+1v2rDQUWNcGW7/7E0Jvdz51V38XXxJfhzbV17aNHCw= go.mongodb.org/mongo-driver v1.11.2/go.mod h1:s7p5vEtfbeR1gYi6pnj3c3/urpbLv2T5Sfd6Rp2HBB8= +go.mongodb.org/mongo-driver v1.11.3 h1:Ql6K6qYHEzB6xvu4+AU0BoRoqf9vFPcc4o7MUIdPW8Y= +go.mongodb.org/mongo-driver v1.11.3/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 2d1e67f0..62daff5c 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -2,7 +2,6 @@ package mongodb import ( "context" - "encoding/binary" "time" "github.com/multiversx/multi-factor-auth-go-service/core" @@ -376,16 +375,6 @@ func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []by return entry.Value, nil } -func incrementIntegerFromBytes(value []byte) ([]byte, uint32) { - uint32Value := binary.BigEndian.Uint32(value) - newIndex := uint32Value + 1 - - newIndexBytes := make([]byte, 4) - binary.BigEndian.PutUint32(newIndexBytes, newIndex) - - return newIndexBytes, newIndex -} - // Close will close the mongodb client func (mdc *mongodbClient) Close() error { return mdc.client.Disconnect(mdc.ctx) diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index 2ab930ec..b1e2b090 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -2,13 +2,10 @@ package mongodb_test import ( "errors" - "sync" "testing" - "github.com/multiversx/multi-factor-auth-go-service/config" "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/multi-factor-auth-go-service/mongodb" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/integration/mtest" @@ -318,48 +315,3 @@ func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { require.Nil(mt, err) }) } - -func TestMongoDBClient_ConcurrentCalls(t *testing.T) { - t.Parallel() - - client, err := mongodb.CreateMongoDBClient(config.MongoDBConfig{ - URI: "mongodb://127.0.0.1:27017/?directConnection=true&serverSelectionTimeoutMS=2000", - DBName: "main", - }) - require.Nil(t, err) - - checker := func(data interface{}) (interface{}, error) { - return &core.OTPInfo{}, nil - } - - numCalls := 600 - - var wg sync.WaitGroup - wg.Add(numCalls) - for i := 0; i < numCalls; i++ { - go func(idx int) { - switch idx % 5 { - case 0: - err := client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) - require.Nil(t, err) - case 1: - _, err := client.GetStruct(mongodb.UsersCollectionID, []byte("key")) - require.Nil(t, err) - case 2: - require.Nil(t, client.HasStruct(mongodb.UsersCollectionID, []byte("key"))) - case 3: - err := client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key"), checker) - require.Nil(t, err) - case 4: - _, err := client.UpdateTimestamp(mongodb.UsersCollectionID, []byte("key"), 0) - require.Nil(t, err) - // _, err := client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key")) - // require.Nil(t, err) - default: - assert.Fail(t, "should not hit default") - } - wg.Done() - }(i) - } - wg.Wait() -} diff --git a/mongodb/integrationTests/mongo_test.go b/mongodb/integrationTests/mongo_test.go new file mode 100644 index 00000000..e1e6f101 --- /dev/null +++ b/mongodb/integrationTests/mongo_test.go @@ -0,0 +1,67 @@ +package integrationtests + +import ( + "sync" + "testing" + + "github.com/multiversx/multi-factor-auth-go-service/config" + "github.com/multiversx/multi-factor-auth-go-service/core" + "github.com/multiversx/multi-factor-auth-go-service/mongodb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tryvium-travels/memongo" +) + +func TestMongoDBClient_ConcurrentCalls(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.Skip("this is not a short test") + } + + inMemoryMongoDB, err := memongo.StartWithOptions(&memongo.Options{MongoVersion: "4.4.0", ShouldUseReplica: true}) + require.Nil(t, err) + defer inMemoryMongoDB.Stop() + + client, err := mongodb.CreateMongoDBClient(config.MongoDBConfig{ + URI: inMemoryMongoDB.URI(), + DBName: memongo.RandomDatabase(), + }) + require.Nil(t, err) + + checker := func(data interface{}) (interface{}, error) { + return &core.OTPInfo{}, nil + } + + err = client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) + require.Nil(t, err) + + numCalls := 6 + + var wg sync.WaitGroup + wg.Add(numCalls) + for i := 0; i < numCalls; i++ { + go func(idx int) { + switch idx % 5 { + case 0: + err := client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) + require.Nil(t, err) + case 1: + _, err := client.GetStruct(mongodb.UsersCollectionID, []byte("key")) + require.Nil(t, err) + case 2: + require.Nil(t, client.HasStruct(mongodb.UsersCollectionID, []byte("key"))) + case 3: + err := client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key"), checker) + require.Nil(t, err) + case 4: + _, err := client.UpdateTimestamp(mongodb.UsersCollectionID, []byte("key"), 0) + require.Nil(t, err) + default: + assert.Fail(t, "should not hit default") + } + wg.Done() + }(i) + } + wg.Wait() +} From d1a080b92346bad3cff64e06bf7b41776d3b854c Mon Sep 17 00:00:00 2001 From: ssd04 Date: Mon, 20 Mar 2023 21:14:10 +0200 Subject: [PATCH 26/35] exclude integration test from CI run, fix linter issue --- mongodb/dbClient.go | 8 ++++---- mongodb/integrationTests/mongo_test.go | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 62daff5c..711b96bb 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -278,18 +278,18 @@ func (mdc *mongodbClient) ReadWriteWithCheck( entry := &otpInfoWrapper{} err = coll.FindOne(ctx, filter).Decode(entry) if err != nil { - session.AbortTransaction(ctx) + _ = session.AbortTransaction(ctx) return err } retValue, err := checker(entry.OTPInfo) if err != nil { - session.AbortTransaction(ctx) + _ = session.AbortTransaction(ctx) return err } retValueBytes, ok := retValue.(*core.OTPInfo) if !ok { - session.AbortTransaction(ctx) + _ = session.AbortTransaction(ctx) return core.ErrInvalidValue } @@ -306,7 +306,7 @@ func (mdc *mongodbClient) ReadWriteWithCheck( _, err = coll.UpdateOne(mdc.ctx, filter, update, opts) if err != nil { - session.AbortTransaction(ctx) + _ = session.AbortTransaction(ctx) return err } diff --git a/mongodb/integrationTests/mongo_test.go b/mongodb/integrationTests/mongo_test.go index e1e6f101..5f6bd117 100644 --- a/mongodb/integrationTests/mongo_test.go +++ b/mongodb/integrationTests/mongo_test.go @@ -1,6 +1,7 @@ package integrationtests import ( + "os" "sync" "testing" @@ -15,8 +16,8 @@ import ( func TestMongoDBClient_ConcurrentCalls(t *testing.T) { t.Parallel() - if testing.Short() { - t.Skip("this is not a short test") + if os.Getenv("CI") != "" { + t.Skip("Skipping testing in CI environment") } inMemoryMongoDB, err := memongo.StartWithOptions(&memongo.Options{MongoVersion: "4.4.0", ShouldUseReplica: true}) From 5b90ba948594583c5f187a338a307fb2e9ff831a Mon Sep 17 00:00:00 2001 From: ssd04 Date: Mon, 20 Mar 2023 21:18:06 +0200 Subject: [PATCH 27/35] fix renaming in main --- cmd/multi-factor-auth/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/multi-factor-auth/main.go b/cmd/multi-factor-auth/main.go index 116a8d3e..34ebd741 100644 --- a/cmd/multi-factor-auth/main.go +++ b/cmd/multi-factor-auth/main.go @@ -141,7 +141,7 @@ func startService(ctx *cli.Context, version string) error { return err } - shardedStorageFactory := storageFactory.NewShardedStorageFactory(cfg) + shardedStorageFactory := storageFactory.NewStorageWithIndexFactory(cfg) registeredUsersDB, err := shardedStorageFactory.Create() if err != nil { return err @@ -169,7 +169,7 @@ func startService(ctx *cli.Context, version string) error { twoFactorHandler := handlers.NewTwoFactorHandler(cfg.TwoFactor.Digits, cfg.TwoFactor.Issuer) argsStorageHandler := storage.ArgDBOTPHandler{ - DB: registeredUsersDB, + DB: registeredUsersDB, TOTPHandler: twoFactorHandler, Marshaller: gogoMarshaller, DelayBetweenOTPUpdatesInSec: cfg.ShardedStorage.DelayBetweenWritesInSec, From 0be0aab3739b598500d82c15b63d66b6536e832d Mon Sep 17 00:00:00 2001 From: ssd04 Date: Fri, 24 Mar 2023 11:22:31 +0200 Subject: [PATCH 28/35] renamings, refactorings, added consts to config --- cmd/multi-factor-auth/config/config.toml | 2 ++ config/config.go | 6 ++-- core/interface.go | 4 +-- .../storage/bucket/shardedStorageWithIndex.go | 6 ++-- .../bucket/shardedStorageWithIndex_test.go | 32 +++++++++---------- .../storage/factory/shardedStorageFactory.go | 4 +-- mongodb/dbClient_test.go | 9 ++---- mongodb/mongoDBClientFactory.go | 9 ++---- ...entWrapperStub.go => mongoDBClientStub.go} | 0 9 files changed, 33 insertions(+), 39 deletions(-) rename testscommon/{mongoDBClientWrapperStub.go => mongoDBClientStub.go} (100%) diff --git a/cmd/multi-factor-auth/config/config.toml b/cmd/multi-factor-auth/config/config.toml index f43e33f7..e69acac5 100644 --- a/cmd/multi-factor-auth/config/config.toml +++ b/cmd/multi-factor-auth/config/config.toml @@ -63,3 +63,5 @@ [MongoDB] URI = "mongodb://mongodb0:27017,mongodb1:27018,mongodb2:27019/?replicaSet=mongoReplSet" DBName = "main" + ConnectTimeoutInSec = 60 + OperationTimeoutInSec = 60 diff --git a/config/config.go b/config/config.go index 456592d4..e531e1ed 100644 --- a/config/config.go +++ b/config/config.go @@ -139,6 +139,8 @@ type TwoFactorConfig struct { // MongoDBConfig maps the mongodb configuration type MongoDBConfig struct { - URI string - DBName string + URI string + DBName string + ConnectTimeoutInSec uint32 + OperationTimeoutInSec uint32 } diff --git a/core/interface.go b/core/interface.go index 40944138..c9513643 100644 --- a/core/interface.go +++ b/core/interface.go @@ -70,8 +70,8 @@ type BucketIDProvider interface { IsInterfaceNil() bool } -// BucketIndexHandler defines the methods for a component which handles a bucket -type BucketIndexHandler interface { +// IndexHandler defines the methods for a component which handles a bucket +type IndexHandler interface { Put(key, data []byte) error Get(key []byte) ([]byte, error) Has(key []byte) error diff --git a/handlers/storage/bucket/shardedStorageWithIndex.go b/handlers/storage/bucket/shardedStorageWithIndex.go index d8925a80..84d4d4f6 100644 --- a/handlers/storage/bucket/shardedStorageWithIndex.go +++ b/handlers/storage/bucket/shardedStorageWithIndex.go @@ -17,12 +17,12 @@ const ( // ArgShardedStorageWithIndex is the DTO used to create a new instance of sharded storage with index type ArgShardedStorageWithIndex struct { BucketIDProvider core.BucketIDProvider - BucketHandlers map[uint32]core.BucketIndexHandler + BucketHandlers map[uint32]core.IndexHandler } type shardedStorageWithIndex struct { bucketIDProvider core.BucketIDProvider - bucketHandlers map[uint32]core.BucketIndexHandler + bucketHandlers map[uint32]core.IndexHandler } // NewShardedStorageWithIndex returns a new instance of sharded storage with index @@ -143,7 +143,7 @@ func (sswi *shardedStorageWithIndex) getBucketIDAndBaseIndex(address []byte) (ui return bucketID, index, err } -func (sswi *shardedStorageWithIndex) getBucketForKey(key []byte) (core.BucketIndexHandler, uint32, error) { +func (sswi *shardedStorageWithIndex) getBucketForKey(key []byte) (core.IndexHandler, uint32, error) { bucketID := sswi.bucketIDProvider.GetBucketForAddress(key) bucket, found := sswi.bucketHandlers[bucketID] if !found { diff --git a/handlers/storage/bucket/shardedStorageWithIndex_test.go b/handlers/storage/bucket/shardedStorageWithIndex_test.go index daa5fa0c..8f2d9d68 100644 --- a/handlers/storage/bucket/shardedStorageWithIndex_test.go +++ b/handlers/storage/bucket/shardedStorageWithIndex_test.go @@ -44,7 +44,7 @@ func TestNewShardedStorageWithIndex(t *testing.T) { args := ArgShardedStorageWithIndex{ BucketIDProvider: &testscommon.BucketIDProviderStub{}, - BucketHandlers: make(map[uint32]core.BucketIndexHandler), + BucketHandlers: make(map[uint32]core.IndexHandler), } sswi, err := NewShardedStorageWithIndex(args) assert.Equal(t, core.ErrInvalidBucketHandlers, err) @@ -53,7 +53,7 @@ func TestNewShardedStorageWithIndex(t *testing.T) { t.Run("nil BucketHandler should error", func(t *testing.T) { t.Parallel() - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ 0: &testscommon.BucketIndexHandlerStub{}, 1: &testscommon.BucketIndexHandlerStub{}, 2: nil, @@ -70,7 +70,7 @@ func TestNewShardedStorageWithIndex(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ 0: &testscommon.BucketIndexHandlerStub{}, 1: &testscommon.BucketIndexHandlerStub{}, } @@ -98,7 +98,7 @@ func TestShardedStorageWithIndex_getBucketForKey(t *testing.T) { return providedIdx }, } - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ providedIdx - 1: &testscommon.BucketIndexHandlerStub{}, } args := ArgShardedStorageWithIndex{ @@ -124,7 +124,7 @@ func TestShardedStorageWithIndex_getBucketForKey(t *testing.T) { return providedIdx }, } - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ providedIdx: &testscommon.BucketIndexHandlerStub{}, } args := ArgShardedStorageWithIndex{ @@ -151,7 +151,7 @@ func TestShardedStorageWithIndex_getBucketForKey(t *testing.T) { return providedIdx }, } - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ providedIdx: &testscommon.BucketIndexHandlerStub{}, } args := ArgShardedStorageWithIndex{ @@ -176,7 +176,7 @@ func TestIndexHandler_AllocateIndex(t *testing.T) { t.Run("get base index returns error", func(t *testing.T) { t.Parallel() - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ 0: &testscommon.BucketIndexHandlerStub{}, 1: &testscommon.BucketIndexHandlerStub{}, } @@ -201,7 +201,7 @@ func TestIndexHandler_AllocateIndex(t *testing.T) { providedBucketID := uint32(5) providedIndex := uint32(100) - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ 0: &testscommon.BucketIndexHandlerStub{}, 1: &testscommon.BucketIndexHandlerStub{}, 2: &testscommon.BucketIndexHandlerStub{}, @@ -267,7 +267,7 @@ func TestShardedStorageWithIndex_Close(t *testing.T) { t.Parallel() calledCounter := 0 - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ 0: &testscommon.BucketIndexHandlerStub{ CloseCalled: func() error { calledCounter++ @@ -300,7 +300,7 @@ func TestShardedStorageWithIndex_Close(t *testing.T) { t.Parallel() calledCounter := 0 - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ 0: &testscommon.BucketIndexHandlerStub{ CloseCalled: func() error { calledCounter++ @@ -337,7 +337,7 @@ func TestShardedStorageWithIndex_Count(t *testing.T) { t.Run("one bucked returns error should error", func(t *testing.T) { t.Parallel() - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ 0: &testscommon.BucketIndexHandlerStub{ GetLastIndexCalled: func() (uint32, error) { return uint32(100), nil @@ -363,7 +363,7 @@ func TestShardedStorageWithIndex_Count(t *testing.T) { t.Parallel() providedLastIndex0, providedLastIndex1, providedLastIndex2 := uint32(100), uint32(200), uint32(300) - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ 0: &testscommon.BucketIndexHandlerStub{ GetLastIndexCalled: func() (uint32, error) { return providedLastIndex0, nil @@ -410,7 +410,7 @@ func testGetBucketIDAndBaseIndex(shouldWork bool) func(t *testing.T) { if !shouldWork { key = providedIdx + 1 } - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ key: &testscommon.BucketIndexHandlerStub{ AllocateBucketIndexCalled: func() (uint32, error) { wasCalled = true @@ -457,7 +457,7 @@ func testHas(shouldWork bool) func(t *testing.T) { if !shouldWork { key = providedIdx + 1 } - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ key: &testscommon.BucketIndexHandlerStub{ HasCalled: func(key []byte) error { assert.Equal(t, providedAddr, key) @@ -501,7 +501,7 @@ func testGet(shouldWork bool) func(t *testing.T) { if !shouldWork { key = providedIdx + 1 } - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ key: &testscommon.BucketIndexHandlerStub{ GetCalled: func(key []byte) ([]byte, error) { assert.Equal(t, providedAddr, key) @@ -548,7 +548,7 @@ func testPut(shouldWork bool) func(t *testing.T) { if !shouldWork { key = providedIdx + 1 } - bucketHandlers := map[uint32]core.BucketIndexHandler{ + bucketHandlers := map[uint32]core.IndexHandler{ key: &testscommon.BucketIndexHandlerStub{ PutCalled: func(key, data []byte) error { assert.Equal(t, providedAddr, key) diff --git a/handlers/storage/factory/shardedStorageFactory.go b/handlers/storage/factory/shardedStorageFactory.go index 78345d6f..0d2d2d0b 100644 --- a/handlers/storage/factory/shardedStorageFactory.go +++ b/handlers/storage/factory/shardedStorageFactory.go @@ -51,7 +51,7 @@ func (ssf *storageWithIndexFactory) createMongoDB() (core.StorageWithIndexChecke return nil, err } - bucketIndexHandlers := make(map[uint32]core.BucketIndexHandler, 1) + bucketIndexHandlers := make(map[uint32]core.IndexHandler) bucketIndexHandlers[0], err = bucket.NewMongoDBIndexHandler(storer, client) if err != nil { return nil, err @@ -73,7 +73,7 @@ func (ssf *storageWithIndexFactory) createLocalDB() (core.StorageWithIndexChecke } localDBCfg := ssf.cfg.ShardedStorage.Users - bucketIndexHandlers := make(map[uint32]core.BucketIndexHandler, numbOfBuckets) + bucketIndexHandlers := make(map[uint32]core.IndexHandler, numbOfBuckets) var bucketStorer core.Storer for i := uint32(0); i < numbOfBuckets; i++ { cacheCfg := localDBCfg.Cache diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index b1e2b090..1643a2c8 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -11,6 +11,8 @@ import ( "go.mongodb.org/mongo-driver/mongo/integration/mtest" ) +var expectedErr = errors.New("expected error") + func TestNewMongoDBClient(t *testing.T) { t.Parallel() @@ -55,7 +57,6 @@ func TestMongoDBClient_Put(t *testing.T) { mt.Run("should fail", func(mt *mtest.T) { mt.Parallel() - expectedErr := errors.New("expected error") mt.AddMockResponses( mtest.CreateCommandErrorResponse(mtest.CommandError{ Code: 1, @@ -107,7 +108,6 @@ func TestMongoDBClient_Get(t *testing.T) { mt.Run("find one entry failed", func(mt *mtest.T) { mt.Parallel() - expectedErr := errors.New("expected err") mt.AddMockResponses( mtest.CreateCommandErrorResponse(mtest.CommandError{ Code: 1, @@ -149,7 +149,6 @@ func TestMongoDBClient_Has(t *testing.T) { mt.Run("should fail", func(mt *mtest.T) { mt.Parallel() - expectedErr := errors.New("expected err") mt.AddMockResponses( mtest.CreateCommandErrorResponse(mtest.CommandError{ Code: 1, @@ -227,8 +226,6 @@ func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(mt, err) - expectedErr := errors.New("expected error") - mt.AddMockResponses( mtest.CreateCommandErrorResponse(mtest.CommandError{ Code: 1, @@ -271,8 +268,6 @@ func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(mt, err) - expectedErr := errors.New("expected error") - mt.AddMockResponses( mtest.CreateCommandErrorResponse(mtest.CommandError{ Code: 1, diff --git a/mongodb/mongoDBClientFactory.go b/mongodb/mongoDBClientFactory.go index 11cf8b80..0aa6d929 100644 --- a/mongodb/mongoDBClientFactory.go +++ b/mongodb/mongoDBClientFactory.go @@ -10,16 +10,11 @@ import ( "go.mongodb.org/mongo-driver/mongo/writeconcern" ) -const ( - connectTimeoutSec = 60 - operationTimeoutSec = 60 -) - // CreateMongoDBClient will create a new mongo db client instance func CreateMongoDBClient(cfg config.MongoDBConfig) (MongoDBUsersHandler, error) { opts := options.Client() - opts.SetConnectTimeout(connectTimeoutSec * time.Second) - opts.SetTimeout(operationTimeoutSec * time.Second) + opts.SetConnectTimeout(time.Duration(cfg.ConnectTimeoutInSec) * time.Second) + opts.SetTimeout(time.Duration(cfg.OperationTimeoutInSec) * time.Second) opts.ApplyURI(cfg.URI) writeConcern := writeconcern.New(writeconcern.WMajority()) diff --git a/testscommon/mongoDBClientWrapperStub.go b/testscommon/mongoDBClientStub.go similarity index 100% rename from testscommon/mongoDBClientWrapperStub.go rename to testscommon/mongoDBClientStub.go From e7424f36b59b4bc2986318ede3a8ce96eb5c0bc6 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Sat, 25 Mar 2023 21:21:41 +0200 Subject: [PATCH 29/35] refactor to use mongodb index operations --- .../storage/bucket/mongodbIndexHandler.go | 9 +- mongodb/dbClient.go | 63 ++++++++++++- mongodb/dbClient_test.go | 93 +++++++++++++++++++ mongodb/integrationTests/mongo_test.go | 2 +- mongodb/interface.go | 4 +- 5 files changed, 160 insertions(+), 11 deletions(-) diff --git a/handlers/storage/bucket/mongodbIndexHandler.go b/handlers/storage/bucket/mongodbIndexHandler.go index 424229ce..545363b8 100644 --- a/handlers/storage/bucket/mongodbIndexHandler.go +++ b/handlers/storage/bucket/mongodbIndexHandler.go @@ -30,12 +30,7 @@ func NewMongoDBIndexHandler(storer core.Storer, mongoClient mongodb.MongoDBClien mongodbClient: mongoClient, } - err := storer.Has([]byte(lastIndexKey)) - if err == nil { - return handler, nil - } - - err = saveNewIndex(handler.storer, initialIndexValue) + err := handler.mongodbClient.PutIndex(mongodb.UsersCollectionID, []byte(lastIndexKey), initialIndexValue) if err != nil { return nil, err } @@ -48,7 +43,7 @@ func (handler *mongodbIndexHandler) AllocateBucketIndex() (uint32, error) { handler.mut.Lock() defer handler.mut.Unlock() - newIndex, err := handler.mongodbClient.IncrementWithTransaction(mongodb.UsersCollectionID, []byte(lastIndexKey)) + newIndex, err := handler.mongodbClient.IncrementIndex(mongodb.UsersCollectionID, []byte(lastIndexKey)) if err != nil { return 0, err } diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 711b96bb..0a8fa367 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -99,6 +99,32 @@ func (mdc *mongodbClient) Put(collID CollectionID, key []byte, data []byte) erro return nil } +func (mdc *mongodbClient) PutIfNotExists(collID CollectionID, key []byte, data []byte) error { + coll, ok := mdc.collections[collID] + if !ok { + return ErrCollectionNotFound + } + + filter := bson.D{{Key: "_id", Value: string(key)}} + update := bson.D{{Key: "$setOnInsert", + Value: bson.D{ + {Key: "_id", Value: string(key)}, + {Key: "value", Value: data}, + }, + }} + + opts := options.Update().SetUpsert(true) + + res, err := coll.UpdateOne(mdc.ctx, filter, update, opts) + if err != nil { + return err + } + + log.Trace("PutIfNotExists", "key", string(key), "value", string(data), "modifiedCount", res.ModifiedCount) + + return nil +} + func (mdc *mongodbClient) PutStruct(collID CollectionID, key []byte, data *core.OTPInfo) error { coll, ok := mdc.collections[collID] if !ok { @@ -344,8 +370,34 @@ func runTxWithRetry(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) } } -// IncrementWithTransaction will increment the value for the provided key, within a transaction -func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []byte) (uint32, error) { +func (mdc *mongodbClient) PutIndex(collID CollectionID, key []byte, index uint32) error { + coll, ok := mdc.collections[collID] + if !ok { + return ErrCollectionNotFound + } + + filter := bson.D{{Key: "_id", Value: string(key)}} + update := bson.D{{Key: "$setOnInsert", + Value: bson.D{ + {Key: "_id", Value: string(key)}, + {Key: "value", Value: index}, + }, + }} + + opts := options.Update().SetUpsert(true) + + res, err := coll.UpdateOne(mdc.ctx, filter, update, opts) + if err != nil { + return err + } + + log.Trace("PutIfNotExists", "key", string(key), "value", index, "modifiedCount", res.ModifiedCount) + + return nil +} + +// IncrementIndex will increment the value for the provided key +func (mdc *mongodbClient) IncrementIndex(collID CollectionID, key []byte) (uint32, error) { coll, ok := mdc.collections[collID] if !ok { return 0, ErrCollectionNotFound @@ -366,12 +418,19 @@ func (mdc *mongodbClient) IncrementWithTransaction(collID CollectionID, key []by err := res.Decode(entry) if err != nil { if err == mongo.ErrNoDocuments { + log.Trace( + "IncrementIndex: no document found, will return initial counter value", + "key", string(key), + "value", entry.Value, + ) return initialCounterValue, nil } return initialCounterValue, err } + log.Trace("IncrementIndex", "key", string(key), "value", entry.Value) + return entry.Value, nil } diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index 1643a2c8..c00e9e6f 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -3,6 +3,7 @@ package mongodb_test import ( "errors" "testing" + "time" "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/multi-factor-auth-go-service/mongodb" @@ -310,3 +311,95 @@ func TestMongoDBClient_ReadWriteWithCheck(t *testing.T) { require.Nil(mt, err) }) } + +func TestMongoDBClient_UpdateTimestamp(t *testing.T) { + t.Parallel() + + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() + + mt.Run("internal not passed, should not update timestamp", func(mt *mtest.T) { + mt.Parallel() + + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(mt, err) + + type otpInfoWrapper struct { + Key string `bson:"_id"` + *core.OTPInfo + } + + currentTimestamp := time.Now().Unix() + otpInfo := &core.OTPInfo{ + OTP: []byte("otpInfo1"), + LastTOTPChangeTimestamp: currentTimestamp, + } + + mt.AddMockResponses( + mtest.CreateSuccessResponse(), + bson.D{ + {Key: "ok", Value: 1}, + {Key: "value", Value: otpInfo}, + }, + ) + + err = client.PutStruct(mongodb.UsersCollectionID, []byte("key1"), otpInfo) + require.Nil(t, err) + + timestamp, err := client.UpdateTimestamp(mongodb.UsersCollectionID, []byte("key1"), 10) + require.Nil(mt, err) + require.Equal(t, currentTimestamp, timestamp) + }) + + mt.Run("internal passed, should update timestamp", func(mt *mtest.T) { + mt.Parallel() + + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(mt, err) + + interval := int64(10) + + currentTimestamp := time.Now().Unix() + otpInfo := &core.OTPInfo{ + OTP: []byte("otpInfo1"), + LastTOTPChangeTimestamp: currentTimestamp, + } + + otpInfo.LastTOTPChangeTimestamp += interval + + mt.AddMockResponses( + bson.D{ + {Key: "ok", Value: 1}, + {Key: "value", Value: otpInfo}, + }, + ) + + newTimestamp, err := client.UpdateTimestamp(mongodb.UsersCollectionID, []byte("key1"), interval) + require.Nil(mt, err) + require.Greater(t, newTimestamp, currentTimestamp) + }) +} + +func TestMongoDBClient_PutIfNotExists(t *testing.T) { + t.Parallel() + + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() + + mt.Run("should work", func(mt *mtest.T) { + mt.Parallel() + + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(mt, err) + + mt.AddMockResponses( + bson.D{ + {Key: "ok", Value: 1}, + {Key: "value", Value: []byte("data1")}, + }, + ) + + err = client.PutIfNotExists(mongodb.UsersCollectionID, []byte("key1"), []byte("data1")) + require.Nil(mt, err) + }) +} diff --git a/mongodb/integrationTests/mongo_test.go b/mongodb/integrationTests/mongo_test.go index 5f6bd117..7df3174a 100644 --- a/mongodb/integrationTests/mongo_test.go +++ b/mongodb/integrationTests/mongo_test.go @@ -37,7 +37,7 @@ func TestMongoDBClient_ConcurrentCalls(t *testing.T) { err = client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) require.Nil(t, err) - numCalls := 6 + numCalls := 60 var wg sync.WaitGroup wg.Add(numCalls) diff --git a/mongodb/interface.go b/mongodb/interface.go index a1203e2b..b09c50f5 100644 --- a/mongodb/interface.go +++ b/mongodb/interface.go @@ -27,10 +27,12 @@ type MongoDBCollection interface { // MongoDBClient defines what a mongodb client should do type MongoDBClient interface { Put(coll CollectionID, key []byte, data []byte) error + PutIfNotExists(collID CollectionID, key []byte, data []byte) error Get(coll CollectionID, key []byte) ([]byte, error) Has(coll CollectionID, key []byte) error Remove(coll CollectionID, key []byte) error - IncrementWithTransaction(collID CollectionID, key []byte) (uint32, error) + PutIndex(collID CollectionID, key []byte, index uint32) error + IncrementIndex(collID CollectionID, key []byte) (uint32, error) ReadWriteWithCheck( collID CollectionID, key []byte, From 0dd72cfd4593918f0886569142dc5aace6e86dc7 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Wed, 29 Mar 2023 18:15:51 +0300 Subject: [PATCH 30/35] add more mongodb opearation + separate index collection --- .../storage/bucket/mongodbIndexHandler.go | 4 +- .../bucket/mongodbIndexHandler_test.go | 13 +-- mongodb/dbClient.go | 64 ++++++++------- mongodb/dbClient_test.go | 79 +++++++++++++------ mongodb/interface.go | 3 +- testscommon/mongoDBClientStub.go | 32 +++++--- 6 files changed, 117 insertions(+), 78 deletions(-) diff --git a/handlers/storage/bucket/mongodbIndexHandler.go b/handlers/storage/bucket/mongodbIndexHandler.go index 545363b8..db714066 100644 --- a/handlers/storage/bucket/mongodbIndexHandler.go +++ b/handlers/storage/bucket/mongodbIndexHandler.go @@ -30,7 +30,7 @@ func NewMongoDBIndexHandler(storer core.Storer, mongoClient mongodb.MongoDBClien mongodbClient: mongoClient, } - err := handler.mongodbClient.PutIndex(mongodb.UsersCollectionID, []byte(lastIndexKey), initialIndexValue) + err := handler.mongodbClient.PutIndexIfNotExists(mongodb.IndexCollectionID, []byte(lastIndexKey), initialIndexValue) if err != nil { return nil, err } @@ -43,7 +43,7 @@ func (handler *mongodbIndexHandler) AllocateBucketIndex() (uint32, error) { handler.mut.Lock() defer handler.mut.Unlock() - newIndex, err := handler.mongodbClient.IncrementIndex(mongodb.UsersCollectionID, []byte(lastIndexKey)) + newIndex, err := handler.mongodbClient.IncrementIndex(mongodb.IndexCollectionID, []byte(lastIndexKey)) if err != nil { return 0, err } diff --git a/handlers/storage/bucket/mongodbIndexHandler_test.go b/handlers/storage/bucket/mongodbIndexHandler_test.go index b715eff1..80a7125e 100644 --- a/handlers/storage/bucket/mongodbIndexHandler_test.go +++ b/handlers/storage/bucket/mongodbIndexHandler_test.go @@ -64,16 +64,11 @@ func TestNewMongoDBIndexHandler(t *testing.T) { t.Run("empty bucket and put lastIndexKey fails", func(t *testing.T) { t.Parallel() - handler, err := NewMongoDBIndexHandler(&testscommon.StorerStub{ - HasCalled: func(key []byte) error { - assert.Equal(t, []byte(lastIndexKey), key) - return expectedErr - }, - PutCalled: func(key, data []byte) error { - assert.Equal(t, []byte(lastIndexKey), key) + handler, err := NewMongoDBIndexHandler(&testscommon.StorerStub{}, &testscommon.MongoDBClientStub{ + PutIndexIfNotExistsCalled: func(collID mongodb.CollectionID, key []byte, index uint32) error { return expectedErr }, - }, &testscommon.MongoDBClientStub{}) + }) assert.Equal(t, expectedErr, err) assert.True(t, check.IfNil(handler)) }) @@ -96,7 +91,7 @@ func TestMongoDBIndexHandler_Operations(t *testing.T) { return index, nil }, }, &testscommon.MongoDBClientStub{ - IncrementWithTransactionCalled: func(coll mongodb.CollectionID, key []byte) (uint32, error) { + IncrementIndexCalled: func(collID mongodb.CollectionID, key []byte) (uint32, error) { return 1, nil }, }) diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 0a8fa367..661c0b4d 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -2,6 +2,7 @@ package mongodb import ( "context" + "fmt" "time" "github.com/multiversx/multi-factor-auth-go-service/core" @@ -21,9 +22,13 @@ type CollectionID string const ( // UsersCollectionID specifies mongodb collection for users UsersCollectionID CollectionID = "users" + + // IndexCollectionID specifies mongodb collection for global index + IndexCollectionID CollectionID = "index" ) const initialCounterValue = 1 +const numInitialShardChunks = 4 type mongoEntry struct { Key string `bson:"_id"` @@ -35,6 +40,7 @@ type counterMongoEntry struct { Value uint32 `bson:"value"` } +// TODO: change with merged structures type otpInfoWrapper struct { Key string `bson:"_id"` *core.OTPInfo @@ -42,6 +48,7 @@ type otpInfoWrapper struct { type mongodbClient struct { client *mongo.Client + db *mongo.Database collections map[CollectionID]*mongo.Collection ctx context.Context } @@ -62,11 +69,15 @@ func NewClient(client *mongo.Client, dbName string) (*mongodbClient, error) { return nil, err } + database := client.Database(dbName) + collections := make(map[CollectionID]*mongo.Collection) - collections[UsersCollectionID] = client.Database(dbName).Collection(string(UsersCollectionID)) + collections[UsersCollectionID] = database.Collection(string(UsersCollectionID)) + collections[IndexCollectionID] = database.Collection(string(IndexCollectionID)) return &mongodbClient{ client: client, + db: database, collections: collections, ctx: ctx, }, nil @@ -99,32 +110,6 @@ func (mdc *mongodbClient) Put(collID CollectionID, key []byte, data []byte) erro return nil } -func (mdc *mongodbClient) PutIfNotExists(collID CollectionID, key []byte, data []byte) error { - coll, ok := mdc.collections[collID] - if !ok { - return ErrCollectionNotFound - } - - filter := bson.D{{Key: "_id", Value: string(key)}} - update := bson.D{{Key: "$setOnInsert", - Value: bson.D{ - {Key: "_id", Value: string(key)}, - {Key: "value", Value: data}, - }, - }} - - opts := options.Update().SetUpsert(true) - - res, err := coll.UpdateOne(mdc.ctx, filter, update, opts) - if err != nil { - return err - } - - log.Trace("PutIfNotExists", "key", string(key), "value", string(data), "modifiedCount", res.ModifiedCount) - - return nil -} - func (mdc *mongodbClient) PutStruct(collID CollectionID, key []byte, data *core.OTPInfo) error { coll, ok := mdc.collections[collID] if !ok { @@ -370,7 +355,7 @@ func runTxWithRetry(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) } } -func (mdc *mongodbClient) PutIndex(collID CollectionID, key []byte, index uint32) error { +func (mdc *mongodbClient) PutIndexIfNotExists(collID CollectionID, key []byte, index uint32) error { coll, ok := mdc.collections[collID] if !ok { return ErrCollectionNotFound @@ -434,6 +419,29 @@ func (mdc *mongodbClient) IncrementIndex(collID CollectionID, key []byte) (uint3 return entry.Value, nil } +// ShardHashedCollection will shard collection with a hashed shard key +func (mdc *mongodbClient) ShardHashedCollection(collID CollectionID) error { + coll, ok := mdc.collections[collID] + if !ok { + return ErrCollectionNotFound + } + + collectionPath := fmt.Sprintf("%s.%s", mdc.db.Name(), coll.Name()) + + cmd := bson.D{ + {Key: "shardCollection", Value: collectionPath}, + {Key: "key", Value: bson.D{{Key: "_id", Value: "hashed"}}}, + {Key: "numInitialChunks", Value: numInitialShardChunks}, + } + + err := mdc.db.RunCommand(mdc.ctx, cmd).Err() + if err != nil { + return err + } + + return nil +} + // Close will close the mongodb client func (mdc *mongodbClient) Close() error { return mdc.client.Disconnect(mdc.ctx) diff --git a/mongodb/dbClient_test.go b/mongodb/dbClient_test.go index c00e9e6f..79fa787c 100644 --- a/mongodb/dbClient_test.go +++ b/mongodb/dbClient_test.go @@ -90,6 +90,57 @@ func TestMongoDBClient_Put(t *testing.T) { }) } +func TestMongoDBClient_PutIndexIfNotExists(t *testing.T) { + t.Parallel() + + mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) + defer mt.Close() + + mt.Run("collection not found", func(mt *mtest.T) { + mt.Parallel() + + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(mt, err) + + err = client.PutIndexIfNotExists("another coll", []byte("key1"), 1) + require.Equal(mt, mongodb.ErrCollectionNotFound, err) + }) + + mt.Run("should fail", func(mt *mtest.T) { + mt.Parallel() + + mt.AddMockResponses( + mtest.CreateCommandErrorResponse(mtest.CommandError{ + Code: 1, + Message: expectedErr.Error(), + }), + ) + + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(mt, err) + + err = client.PutIndexIfNotExists(mongodb.UsersCollectionID, []byte("key1"), 1) + require.Equal(mt, expectedErr.Error(), err.Error()) + }) + + mt.Run("should work", func(mt *mtest.T) { + mt.Parallel() + + mt.AddMockResponses( + mtest.CreateCursorResponse(1, "foo.bar", mtest.FirstBatch, bson.D{ + {Key: "_id", Value: "key2"}, + {Key: "value", Value: []byte("value")}, + }), + ) + + client, err := mongodb.NewClient(mt.Client, "dbName") + require.Nil(mt, err) + + err = client.PutIndexIfNotExists(mongodb.UsersCollectionID, []byte("key1"), 1) + require.Nil(mt, err) + }) +} + func TestMongoDBClient_Get(t *testing.T) { t.Parallel() @@ -234,7 +285,7 @@ func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { }), ) - _, err = client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) + _, err = client.IncrementIndex(mongodb.UsersCollectionID, []byte("key1")) require.Equal(mt, expectedErr.Error(), err.Error()) }) @@ -251,7 +302,7 @@ func TestMongoDBClient_IncrementWithTransaction(t *testing.T) { client, err := mongodb.NewClient(mt.Client, "dbName") require.Nil(mt, err) - val, err := client.IncrementWithTransaction(mongodb.UsersCollectionID, []byte("key1")) + val, err := client.IncrementIndex(mongodb.UsersCollectionID, []byte("key1")) require.Nil(mt, err) require.Equal(mt, uint32(1), val) }) @@ -379,27 +430,3 @@ func TestMongoDBClient_UpdateTimestamp(t *testing.T) { require.Greater(t, newTimestamp, currentTimestamp) }) } - -func TestMongoDBClient_PutIfNotExists(t *testing.T) { - t.Parallel() - - mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) - defer mt.Close() - - mt.Run("should work", func(mt *mtest.T) { - mt.Parallel() - - client, err := mongodb.NewClient(mt.Client, "dbName") - require.Nil(mt, err) - - mt.AddMockResponses( - bson.D{ - {Key: "ok", Value: 1}, - {Key: "value", Value: []byte("data1")}, - }, - ) - - err = client.PutIfNotExists(mongodb.UsersCollectionID, []byte("key1"), []byte("data1")) - require.Nil(mt, err) - }) -} diff --git a/mongodb/interface.go b/mongodb/interface.go index b09c50f5..2c711cfb 100644 --- a/mongodb/interface.go +++ b/mongodb/interface.go @@ -27,11 +27,10 @@ type MongoDBCollection interface { // MongoDBClient defines what a mongodb client should do type MongoDBClient interface { Put(coll CollectionID, key []byte, data []byte) error - PutIfNotExists(collID CollectionID, key []byte, data []byte) error Get(coll CollectionID, key []byte) ([]byte, error) Has(coll CollectionID, key []byte) error Remove(coll CollectionID, key []byte) error - PutIndex(collID CollectionID, key []byte, index uint32) error + PutIndexIfNotExists(collID CollectionID, key []byte, index uint32) error IncrementIndex(collID CollectionID, key []byte) (uint32, error) ReadWriteWithCheck( collID CollectionID, diff --git a/testscommon/mongoDBClientStub.go b/testscommon/mongoDBClientStub.go index 38a002de..59ae6ef4 100644 --- a/testscommon/mongoDBClientStub.go +++ b/testscommon/mongoDBClientStub.go @@ -6,17 +6,18 @@ import ( // MongoDBClientStub implemented mongodb client wraper interface type MongoDBClientStub struct { - PutCalled func(coll mongodb.CollectionID, key []byte, data []byte) error - GetCalled func(coll mongodb.CollectionID, key []byte) ([]byte, error) - HasCalled func(coll mongodb.CollectionID, key []byte) error - RemoveCalled func(coll mongodb.CollectionID, key []byte) error - IncrementWithTransactionCalled func(coll mongodb.CollectionID, key []byte) (uint32, error) - CloseCalled func() error - ReadWriteWithCheckCalled func( + PutCalled func(coll mongodb.CollectionID, key []byte, data []byte) error + GetCalled func(coll mongodb.CollectionID, key []byte) ([]byte, error) + HasCalled func(coll mongodb.CollectionID, key []byte) error + RemoveCalled func(coll mongodb.CollectionID, key []byte) error + IncrementIndexCalled func(collID mongodb.CollectionID, key []byte) (uint32, error) + CloseCalled func() error + ReadWriteWithCheckCalled func( collID mongodb.CollectionID, key []byte, checker func(data interface{}) (interface{}, error), ) error + PutIndexIfNotExistsCalled func(collID mongodb.CollectionID, key []byte, index uint32) error } // Put - @@ -55,10 +56,10 @@ func (m *MongoDBClientStub) Remove(coll mongodb.CollectionID, key []byte) error return nil } -// IncrementWithTransaction - -func (m *MongoDBClientStub) IncrementWithTransaction(coll mongodb.CollectionID, key []byte) (uint32, error) { - if m.IncrementWithTransactionCalled != nil { - return m.IncrementWithTransactionCalled(coll, key) +// IncrementIndex - +func (m *MongoDBClientStub) IncrementIndex(coll mongodb.CollectionID, key []byte) (uint32, error) { + if m.IncrementIndexCalled != nil { + return m.IncrementIndexCalled(coll, key) } return 0, nil @@ -77,6 +78,15 @@ func (m *MongoDBClientStub) ReadWriteWithCheck( return nil } +// PutIndexIfNotExists - +func (m *MongoDBClientStub) PutIndexIfNotExists(collID mongodb.CollectionID, key []byte, index uint32) error { + if m.PutIndexIfNotExistsCalled != nil { + return m.PutIndexIfNotExistsCalled(collID, key, index) + } + + return nil +} + // Close - func (m *MongoDBClientStub) Close() error { if m.CloseCalled != nil { From da235848106028ec596d7f79b480ad3074416a61 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Tue, 4 Apr 2023 13:59:01 +0300 Subject: [PATCH 31/35] manual session with transaction handling --- .../storage/mongo/mongodbStorerHandler.go | 58 +++++++- mongodb/dbClient.go | 134 +++++++++++++++++- mongodb/integrationTests/mongo_test.go | 24 +++- mongodb/interface.go | 27 ++-- 4 files changed, 223 insertions(+), 20 deletions(-) diff --git a/handlers/storage/mongo/mongodbStorerHandler.go b/handlers/storage/mongo/mongodbStorerHandler.go index 9ac59f18..8c27e8fc 100644 --- a/handlers/storage/mongo/mongodbStorerHandler.go +++ b/handlers/storage/mongo/mongodbStorerHandler.go @@ -1,6 +1,9 @@ package mongo import ( + "fmt" + "sync" + "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/multi-factor-auth-go-service/mongodb" logger "github.com/multiversx/mx-chain-logger-go" @@ -11,6 +14,9 @@ var log = logger.GetOrCreate("handlers/storage/mongo") type mongodbStorerHandler struct { client mongodb.MongoDBClient collection mongodb.CollectionID + sessions map[string]mongodb.Session + sessionCtx map[string]mongodb.SessionContext + mutSess sync.RWMutex } // NewMongoDBStorerHandler will create a new storer handler instance @@ -22,17 +28,65 @@ func NewMongoDBStorerHandler(client mongodb.MongoDBClient, collection mongodb.Co return &mongodbStorerHandler{ client: client, collection: collection, + sessions: make(map[string]mongodb.Session), + sessionCtx: make(map[string]mongodb.SessionContext), }, nil } // Put will set key value pair func (msh *mongodbStorerHandler) Put(key []byte, data []byte) error { - return msh.client.Put(msh.collection, key, data) + // return msh.client.Put(msh.collection, key, data) + log.Debug("StorerM: put", "key", key, "data", data) + + msh.mutSess.Lock() + defer msh.mutSess.Unlock() + + session, ok := msh.sessions[string(key)] + if !ok { + log.Trace("%w: could not find session for key %s", core.ErrInvalidValue, string(key)) + return fmt.Errorf("%w: could not find session for key %s", core.ErrInvalidValue, string(key)) + } + if session == nil { + log.Trace("nil session for key %s", string(key)) + return fmt.Errorf("nil session for key %s", string(key)) + } + + sessionCtx, ok := msh.sessionCtx[string(key)] + if !ok { + log.Trace("%w: could not find session context for key %s", core.ErrInvalidValue, string(key)) + return fmt.Errorf("%w: could not find session context for key %s", core.ErrInvalidValue, string(key)) + } + if sessionCtx == nil { + log.Trace("nil session context for key %s", string(key)) + return fmt.Errorf("nil session context for key %s", string(key)) + } + + err := msh.client.WriteWithTx(msh.collection, key, data, session, sessionCtx) + if err != nil { + log.Trace("StorerM: put", "key", key, "err", err.Error()) + return err + } + + return nil } // Get will return the value for the provided key func (msh *mongodbStorerHandler) Get(key []byte) ([]byte, error) { - return msh.client.Get(msh.collection, key) + //return msh.client.Get(msh.collection, key) + msh.mutSess.Lock() + defer msh.mutSess.Unlock() + + data, session, sessionCtx, err := msh.client.ReadWithTx(msh.collection, key) + msh.sessions[string(key)] = session + msh.sessionCtx[string(key)] = sessionCtx + if err != nil { + log.Debug("StorerM: get", "key", key, "err", err.Error()) + return nil, err + } + + log.Debug("StorerM: get", "key", key, "data", data) + + return data, nil } // Has will return true if the provided key exists in the database collection diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 661c0b4d..8805500a 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -6,12 +6,14 @@ import ( "time" "github.com/multiversx/multi-factor-auth-go-service/core" + "github.com/multiversx/multi-factor-auth-go-service/handlers/storage" logger "github.com/multiversx/mx-chain-logger-go" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" + "go.mongodb.org/mongo-driver/x/mongo/driver" ) var log = logger.GetOrCreate("mongodb") @@ -30,6 +32,8 @@ const ( const initialCounterValue = 1 const numInitialShardChunks = 4 +var withTransactionTimeout = 10 * time.Second + type mongoEntry struct { Key string `bson:"_id"` Value []byte `bson:"value"` @@ -336,21 +340,149 @@ func (mdc *mongodbClient) ReadWriteWithCheck( return nil } +func (mdc *mongodbClient) ReadWithTx( + collID CollectionID, + key []byte, +) ([]byte, Session, SessionContext, error) { + coll, ok := mdc.collections[collID] + if !ok { + return nil, nil, nil, ErrCollectionNotFound + } + + session, err := mdc.client.StartSession() + if err != nil { + return nil, nil, nil, err + } + + log.Trace("started session", "ID", session.ID()) + + sessionCtx := mongo.NewSessionContext(mdc.ctx, session) + + wc := writeconcern.New(writeconcern.WMajority()) + txnOptions := options.Transaction().SetWriteConcern(wc) + txnOptions.SetReadPreference(readpref.Primary()) + + err = session.StartTransaction(txnOptions) + if err != nil { + log.Trace("ReadWithTx: StartTransaction", "err", err.Error()) + return nil, nil, nil, err + } + + filter := bson.M{"_id": string(key)} + + entry := &mongoEntry{} + err = coll.FindOne(sessionCtx, filter).Decode(entry) + if err != nil { + // TODO: abort transaction and create a new one on write + //_ = session.AbortTransaction(sessionCtx) + log.Trace("ReadWithTx", "err", err.Error()) + return nil, session, sessionCtx, storage.ErrKeyNotFound + } + + log.Trace("ReadWithTx", "key", string(key), "value", entry.Value) + + return entry.Value, session, sessionCtx, nil +} + +func (mdc *mongodbClient) WriteWithTx( + collID CollectionID, + key []byte, + value []byte, + session Session, + sessionCtx SessionContext, +) error { + + txCallback := func(ctx mongo.SessionContext) error { + coll, ok := mdc.collections[collID] + if !ok { + return ErrCollectionNotFound + } + + filter := bson.M{"_id": string(key)} + update := bson.M{ + "$set": bson.M{ + "_id": string(key), + "value": value, + }, + } + + // filter := bson.D{{Key: "_id", Value: string(key)}} + // update := bson.D{{Key: "$set", + // Value: bson.D{ + // {Key: "_id", Value: string(key)}, + // {Key: "value", Value: value}, + // }, + // }} + + opts := options.Update().SetUpsert(true) + + _, err := coll.UpdateOne(sessionCtx, filter, update, opts) + if err != nil { + log.Trace("WriteWithTx: UpdateOne", "err", err.Error()) + //_ = session.AbortTransaction(sessionCtx) + return err + } + + log.Trace("WriteWithTx before commit", "key", string(key), "value", value) + + err = session.CommitTransaction(sessionCtx) + if err != nil { + log.Trace("WriteWithTx: CommitTransaction", "err", err.Error()) + //_ = session.AbortTransaction(mdc.ctx) + return err + } + + log.Trace("WriteWithTx", "key", string(key), "value", value) + + return nil + } + + err := txCallback(sessionCtx) + if err != nil { + abortErr := session.AbortTransaction(mdc.ctx) + if abortErr != nil { + return abortErr + } + + log.Trace("ended session", "ID", session.ID()) + session.EndSession(mdc.ctx) + return err + } + + log.Trace("ended session", "ID", session.ID()) + session.EndSession(mdc.ctx) + + return nil +} + func runTxWithRetry(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) error) error { + timeout := time.NewTimer(withTransactionTimeout) + defer timeout.Stop() + for { err := txnFn(sctx) if err == nil { return nil } + time.Sleep(2 * time.Second) + log.Trace("Transaction aborted. Caught exception during transaction.") + select { + case <-timeout.C: + log.Trace("Transaction timeout reached.") + return err + default: + } + cmdErr, ok := err.(mongo.CommandError) - if ok && cmdErr.HasErrorLabel("TransientTransactionError") { + if ok && cmdErr.HasErrorLabel(driver.TransientTransactionError) { log.Trace("TransientTransactionError, retrying transaction...") continue } + log.Trace("other transaction error: %s", err.Error()) return err } } diff --git a/mongodb/integrationTests/mongo_test.go b/mongodb/integrationTests/mongo_test.go index 7df3174a..4f768a5c 100644 --- a/mongodb/integrationTests/mongo_test.go +++ b/mongodb/integrationTests/mongo_test.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/multi-factor-auth-go-service/config" "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/multi-factor-auth-go-service/mongodb" + logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tryvium-travels/memongo" @@ -16,6 +17,8 @@ import ( func TestMongoDBClient_ConcurrentCalls(t *testing.T) { t.Parallel() + logger.SetLogLevel("*:TRACE") + if os.Getenv("CI") != "" { t.Skip("Skipping testing in CI environment") } @@ -37,27 +40,34 @@ func TestMongoDBClient_ConcurrentCalls(t *testing.T) { err = client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) require.Nil(t, err) - numCalls := 60 + numCalls := 20 var wg sync.WaitGroup wg.Add(numCalls) for i := 0; i < numCalls; i++ { go func(idx int) { - switch idx % 5 { + switch idx % 6 { case 0: err := client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) - require.Nil(t, err) + assert.Nil(t, err) case 1: _, err := client.GetStruct(mongodb.UsersCollectionID, []byte("key")) - require.Nil(t, err) + assert.Nil(t, err) case 2: - require.Nil(t, client.HasStruct(mongodb.UsersCollectionID, []byte("key"))) + assert.Nil(t, client.HasStruct(mongodb.UsersCollectionID, []byte("key"))) case 3: err := client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key"), checker) - require.Nil(t, err) + assert.Nil(t, err) case 4: _, err := client.UpdateTimestamp(mongodb.UsersCollectionID, []byte("key"), 0) - require.Nil(t, err) + assert.Nil(t, err) + case 5: + _ = client.Put(mongodb.UsersCollectionID, []byte("key2"), []byte{1, 2, 3}) + _, sess, sessCtx, err := client.ReadWithTx(mongodb.UsersCollectionID, []byte("key2")) + assert.Nil(t, err) + + err = client.WriteWithTx(mongodb.UsersCollectionID, []byte("key"), []byte("data"), sess, sessCtx) + assert.Nil(t, err) default: assert.Fail(t, "should not hit default") } diff --git a/mongodb/interface.go b/mongodb/interface.go index 2c711cfb..a7b33e0d 100644 --- a/mongodb/interface.go +++ b/mongodb/interface.go @@ -8,6 +8,12 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) +// Session defines the behaviour of a mongodb session +type Session mongo.Session + +// SessionContext defines the behaviour of a mongodb session context +type SessionContext mongo.SessionContext + // MongoDBClientWrapper defines the methods for mongo db client wrapper type MongoDBClientWrapper interface { Connect(ctx context.Context) error @@ -37,6 +43,17 @@ type MongoDBClient interface { key []byte, checker func(data interface{}) (interface{}, error), ) error + ReadWithTx( + collID CollectionID, + key []byte, + ) ([]byte, Session, SessionContext, error) + WriteWithTx( + collID CollectionID, + key []byte, + value []byte, + session Session, + sessionCtx SessionContext, + ) error Close() error IsInterfaceNil() bool } @@ -49,13 +66,3 @@ type MongoDBUsersHandler interface { GetStruct(collID CollectionID, key []byte) (*core.OTPInfo, error) HasStruct(collID CollectionID, key []byte) error } - -// MongoDBSession defines what a mongodb session should do -type MongoDBSession interface { - StartTransaction(...*options.TransactionOptions) error - AbortTransaction(context.Context) error - CommitTransaction(context.Context) error - WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), - opts ...*options.TransactionOptions) (interface{}, error) - EndSession(context.Context) -} From 0660bd0602e19f5fcc94233e748a50945830a574 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Thu, 13 Apr 2023 10:38:38 +0300 Subject: [PATCH 32/35] Revert "manual session with transaction handling" This reverts commit da235848106028ec596d7f79b480ad3074416a61. --- .../storage/mongo/mongodbStorerHandler.go | 58 +------- mongodb/dbClient.go | 134 +----------------- mongodb/integrationTests/mongo_test.go | 24 +--- mongodb/interface.go | 27 ++-- 4 files changed, 20 insertions(+), 223 deletions(-) diff --git a/handlers/storage/mongo/mongodbStorerHandler.go b/handlers/storage/mongo/mongodbStorerHandler.go index 8c27e8fc..9ac59f18 100644 --- a/handlers/storage/mongo/mongodbStorerHandler.go +++ b/handlers/storage/mongo/mongodbStorerHandler.go @@ -1,9 +1,6 @@ package mongo import ( - "fmt" - "sync" - "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/multi-factor-auth-go-service/mongodb" logger "github.com/multiversx/mx-chain-logger-go" @@ -14,9 +11,6 @@ var log = logger.GetOrCreate("handlers/storage/mongo") type mongodbStorerHandler struct { client mongodb.MongoDBClient collection mongodb.CollectionID - sessions map[string]mongodb.Session - sessionCtx map[string]mongodb.SessionContext - mutSess sync.RWMutex } // NewMongoDBStorerHandler will create a new storer handler instance @@ -28,65 +22,17 @@ func NewMongoDBStorerHandler(client mongodb.MongoDBClient, collection mongodb.Co return &mongodbStorerHandler{ client: client, collection: collection, - sessions: make(map[string]mongodb.Session), - sessionCtx: make(map[string]mongodb.SessionContext), }, nil } // Put will set key value pair func (msh *mongodbStorerHandler) Put(key []byte, data []byte) error { - // return msh.client.Put(msh.collection, key, data) - log.Debug("StorerM: put", "key", key, "data", data) - - msh.mutSess.Lock() - defer msh.mutSess.Unlock() - - session, ok := msh.sessions[string(key)] - if !ok { - log.Trace("%w: could not find session for key %s", core.ErrInvalidValue, string(key)) - return fmt.Errorf("%w: could not find session for key %s", core.ErrInvalidValue, string(key)) - } - if session == nil { - log.Trace("nil session for key %s", string(key)) - return fmt.Errorf("nil session for key %s", string(key)) - } - - sessionCtx, ok := msh.sessionCtx[string(key)] - if !ok { - log.Trace("%w: could not find session context for key %s", core.ErrInvalidValue, string(key)) - return fmt.Errorf("%w: could not find session context for key %s", core.ErrInvalidValue, string(key)) - } - if sessionCtx == nil { - log.Trace("nil session context for key %s", string(key)) - return fmt.Errorf("nil session context for key %s", string(key)) - } - - err := msh.client.WriteWithTx(msh.collection, key, data, session, sessionCtx) - if err != nil { - log.Trace("StorerM: put", "key", key, "err", err.Error()) - return err - } - - return nil + return msh.client.Put(msh.collection, key, data) } // Get will return the value for the provided key func (msh *mongodbStorerHandler) Get(key []byte) ([]byte, error) { - //return msh.client.Get(msh.collection, key) - msh.mutSess.Lock() - defer msh.mutSess.Unlock() - - data, session, sessionCtx, err := msh.client.ReadWithTx(msh.collection, key) - msh.sessions[string(key)] = session - msh.sessionCtx[string(key)] = sessionCtx - if err != nil { - log.Debug("StorerM: get", "key", key, "err", err.Error()) - return nil, err - } - - log.Debug("StorerM: get", "key", key, "data", data) - - return data, nil + return msh.client.Get(msh.collection, key) } // Has will return true if the provided key exists in the database collection diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 8805500a..661c0b4d 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -6,14 +6,12 @@ import ( "time" "github.com/multiversx/multi-factor-auth-go-service/core" - "github.com/multiversx/multi-factor-auth-go-service/handlers/storage" logger "github.com/multiversx/mx-chain-logger-go" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" - "go.mongodb.org/mongo-driver/x/mongo/driver" ) var log = logger.GetOrCreate("mongodb") @@ -32,8 +30,6 @@ const ( const initialCounterValue = 1 const numInitialShardChunks = 4 -var withTransactionTimeout = 10 * time.Second - type mongoEntry struct { Key string `bson:"_id"` Value []byte `bson:"value"` @@ -340,149 +336,21 @@ func (mdc *mongodbClient) ReadWriteWithCheck( return nil } -func (mdc *mongodbClient) ReadWithTx( - collID CollectionID, - key []byte, -) ([]byte, Session, SessionContext, error) { - coll, ok := mdc.collections[collID] - if !ok { - return nil, nil, nil, ErrCollectionNotFound - } - - session, err := mdc.client.StartSession() - if err != nil { - return nil, nil, nil, err - } - - log.Trace("started session", "ID", session.ID()) - - sessionCtx := mongo.NewSessionContext(mdc.ctx, session) - - wc := writeconcern.New(writeconcern.WMajority()) - txnOptions := options.Transaction().SetWriteConcern(wc) - txnOptions.SetReadPreference(readpref.Primary()) - - err = session.StartTransaction(txnOptions) - if err != nil { - log.Trace("ReadWithTx: StartTransaction", "err", err.Error()) - return nil, nil, nil, err - } - - filter := bson.M{"_id": string(key)} - - entry := &mongoEntry{} - err = coll.FindOne(sessionCtx, filter).Decode(entry) - if err != nil { - // TODO: abort transaction and create a new one on write - //_ = session.AbortTransaction(sessionCtx) - log.Trace("ReadWithTx", "err", err.Error()) - return nil, session, sessionCtx, storage.ErrKeyNotFound - } - - log.Trace("ReadWithTx", "key", string(key), "value", entry.Value) - - return entry.Value, session, sessionCtx, nil -} - -func (mdc *mongodbClient) WriteWithTx( - collID CollectionID, - key []byte, - value []byte, - session Session, - sessionCtx SessionContext, -) error { - - txCallback := func(ctx mongo.SessionContext) error { - coll, ok := mdc.collections[collID] - if !ok { - return ErrCollectionNotFound - } - - filter := bson.M{"_id": string(key)} - update := bson.M{ - "$set": bson.M{ - "_id": string(key), - "value": value, - }, - } - - // filter := bson.D{{Key: "_id", Value: string(key)}} - // update := bson.D{{Key: "$set", - // Value: bson.D{ - // {Key: "_id", Value: string(key)}, - // {Key: "value", Value: value}, - // }, - // }} - - opts := options.Update().SetUpsert(true) - - _, err := coll.UpdateOne(sessionCtx, filter, update, opts) - if err != nil { - log.Trace("WriteWithTx: UpdateOne", "err", err.Error()) - //_ = session.AbortTransaction(sessionCtx) - return err - } - - log.Trace("WriteWithTx before commit", "key", string(key), "value", value) - - err = session.CommitTransaction(sessionCtx) - if err != nil { - log.Trace("WriteWithTx: CommitTransaction", "err", err.Error()) - //_ = session.AbortTransaction(mdc.ctx) - return err - } - - log.Trace("WriteWithTx", "key", string(key), "value", value) - - return nil - } - - err := txCallback(sessionCtx) - if err != nil { - abortErr := session.AbortTransaction(mdc.ctx) - if abortErr != nil { - return abortErr - } - - log.Trace("ended session", "ID", session.ID()) - session.EndSession(mdc.ctx) - return err - } - - log.Trace("ended session", "ID", session.ID()) - session.EndSession(mdc.ctx) - - return nil -} - func runTxWithRetry(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) error) error { - timeout := time.NewTimer(withTransactionTimeout) - defer timeout.Stop() - for { err := txnFn(sctx) if err == nil { return nil } - time.Sleep(2 * time.Second) - log.Trace("Transaction aborted. Caught exception during transaction.") - select { - case <-timeout.C: - log.Trace("Transaction timeout reached.") - return err - default: - } - cmdErr, ok := err.(mongo.CommandError) - if ok && cmdErr.HasErrorLabel(driver.TransientTransactionError) { + if ok && cmdErr.HasErrorLabel("TransientTransactionError") { log.Trace("TransientTransactionError, retrying transaction...") continue } - log.Trace("other transaction error: %s", err.Error()) return err } } diff --git a/mongodb/integrationTests/mongo_test.go b/mongodb/integrationTests/mongo_test.go index 4f768a5c..7df3174a 100644 --- a/mongodb/integrationTests/mongo_test.go +++ b/mongodb/integrationTests/mongo_test.go @@ -8,7 +8,6 @@ import ( "github.com/multiversx/multi-factor-auth-go-service/config" "github.com/multiversx/multi-factor-auth-go-service/core" "github.com/multiversx/multi-factor-auth-go-service/mongodb" - logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tryvium-travels/memongo" @@ -17,8 +16,6 @@ import ( func TestMongoDBClient_ConcurrentCalls(t *testing.T) { t.Parallel() - logger.SetLogLevel("*:TRACE") - if os.Getenv("CI") != "" { t.Skip("Skipping testing in CI environment") } @@ -40,34 +37,27 @@ func TestMongoDBClient_ConcurrentCalls(t *testing.T) { err = client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) require.Nil(t, err) - numCalls := 20 + numCalls := 60 var wg sync.WaitGroup wg.Add(numCalls) for i := 0; i < numCalls; i++ { go func(idx int) { - switch idx % 6 { + switch idx % 5 { case 0: err := client.PutStruct(mongodb.UsersCollectionID, []byte("key"), &core.OTPInfo{LastTOTPChangeTimestamp: 101}) - assert.Nil(t, err) + require.Nil(t, err) case 1: _, err := client.GetStruct(mongodb.UsersCollectionID, []byte("key")) - assert.Nil(t, err) + require.Nil(t, err) case 2: - assert.Nil(t, client.HasStruct(mongodb.UsersCollectionID, []byte("key"))) + require.Nil(t, client.HasStruct(mongodb.UsersCollectionID, []byte("key"))) case 3: err := client.ReadWriteWithCheck(mongodb.UsersCollectionID, []byte("key"), checker) - assert.Nil(t, err) + require.Nil(t, err) case 4: _, err := client.UpdateTimestamp(mongodb.UsersCollectionID, []byte("key"), 0) - assert.Nil(t, err) - case 5: - _ = client.Put(mongodb.UsersCollectionID, []byte("key2"), []byte{1, 2, 3}) - _, sess, sessCtx, err := client.ReadWithTx(mongodb.UsersCollectionID, []byte("key2")) - assert.Nil(t, err) - - err = client.WriteWithTx(mongodb.UsersCollectionID, []byte("key"), []byte("data"), sess, sessCtx) - assert.Nil(t, err) + require.Nil(t, err) default: assert.Fail(t, "should not hit default") } diff --git a/mongodb/interface.go b/mongodb/interface.go index a7b33e0d..2c711cfb 100644 --- a/mongodb/interface.go +++ b/mongodb/interface.go @@ -8,12 +8,6 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -// Session defines the behaviour of a mongodb session -type Session mongo.Session - -// SessionContext defines the behaviour of a mongodb session context -type SessionContext mongo.SessionContext - // MongoDBClientWrapper defines the methods for mongo db client wrapper type MongoDBClientWrapper interface { Connect(ctx context.Context) error @@ -43,17 +37,6 @@ type MongoDBClient interface { key []byte, checker func(data interface{}) (interface{}, error), ) error - ReadWithTx( - collID CollectionID, - key []byte, - ) ([]byte, Session, SessionContext, error) - WriteWithTx( - collID CollectionID, - key []byte, - value []byte, - session Session, - sessionCtx SessionContext, - ) error Close() error IsInterfaceNil() bool } @@ -66,3 +49,13 @@ type MongoDBUsersHandler interface { GetStruct(collID CollectionID, key []byte) (*core.OTPInfo, error) HasStruct(collID CollectionID, key []byte) error } + +// MongoDBSession defines what a mongodb session should do +type MongoDBSession interface { + StartTransaction(...*options.TransactionOptions) error + AbortTransaction(context.Context) error + CommitTransaction(context.Context) error + WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), + opts ...*options.TransactionOptions) (interface{}, error) + EndSession(context.Context) +} From 9a9f030e6bf24394eee6f1a854a00bf6f168b716 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Thu, 13 Apr 2023 11:09:56 +0300 Subject: [PATCH 33/35] remove duplicated files --- handlers/storage/dbOTPHandler.go.orig | 198 -------- handlers/storage/dbOTPHandler_test.go.orig | 455 ------------------- handlers/storage/mongodbStorerWrapper.go.tmp | 77 ---- 3 files changed, 730 deletions(-) delete mode 100644 handlers/storage/dbOTPHandler.go.orig delete mode 100644 handlers/storage/dbOTPHandler_test.go.orig delete mode 100644 handlers/storage/mongodbStorerWrapper.go.tmp diff --git a/handlers/storage/dbOTPHandler.go.orig b/handlers/storage/dbOTPHandler.go.orig deleted file mode 100644 index 9b92ac90..00000000 --- a/handlers/storage/dbOTPHandler.go.orig +++ /dev/null @@ -1,198 +0,0 @@ -package storage - -import ( - "fmt" - "sync" - "time" - - "github.com/multiversx/multi-factor-auth-go-service/core" - "github.com/multiversx/multi-factor-auth-go-service/handlers" - "github.com/multiversx/mx-chain-core-go/core/check" -) - -const ( - keySeparator = "_" - minDelayBetweenOTPUpdates = 1 -) - -// ArgDBOTPHandler is the DTO used to create a new instance of dbOTPHandler -type ArgDBOTPHandler struct { - DB core.StorageWithIndexChecker - TOTPHandler handlers.TOTPHandler - Marshaller core.Marshaller - DelayBetweenOTPUpdatesInSec int64 -} - -type dbOTPHandler struct { - db core.StorageWithIndexChecker - totpHandler handlers.TOTPHandler - marshaller core.Marshaller - getTimeHandler func() time.Time - delayBetweenOTPUpdatesInSec int64 - mut sync.RWMutex -} - -// NewDBOTPHandler returns a new instance of dbOTPHandler -func NewDBOTPHandler(args ArgDBOTPHandler) (*dbOTPHandler, error) { - err := checkArgDBOTPHandler(args) - if err != nil { - return nil, err - } - - handler := &dbOTPHandler{ - db: args.DB, - totpHandler: args.TOTPHandler, - getTimeHandler: time.Now, - marshaller: args.Marshaller, - delayBetweenOTPUpdatesInSec: args.DelayBetweenOTPUpdatesInSec, - } - - return handler, nil -} - -func checkArgDBOTPHandler(args ArgDBOTPHandler) error { - if check.IfNil(args.DB) { - return handlers.ErrNilDB - } - if check.IfNil(args.TOTPHandler) { - return handlers.ErrNilTOTPHandler - } - if check.IfNil(args.Marshaller) { - return handlers.ErrNilMarshaller - } - if args.DelayBetweenOTPUpdatesInSec < minDelayBetweenOTPUpdates { - return fmt.Errorf("%w for DelayBetweenOTPUpdatesInSec, got %d, min expected %d", - handlers.ErrInvalidValue, args.DelayBetweenOTPUpdatesInSec, minDelayBetweenOTPUpdates) - } - - return nil -} - -// Save saves the one time password if possible, otherwise returns an error -func (handler *dbOTPHandler) Save(account, guardian []byte, otp handlers.OTP) error { - if otp == nil { - return handlers.ErrNilOTP - } - - key := computeKey(account, guardian) - - // critical section, do not allow a second Put until this is done - handler.mut.Lock() - defer handler.mut.Unlock() - - err := handler.db.Has(key) - if err != nil { - return handler.saveNewOTP(key, otp) - } - - checker := func(data interface{}) (interface{}, error) { - otpInfoBytes, ok := data.([]byte) - if !ok { - return nil, core.ErrInvalidValue - } - - err := handler.checkOtpUpdateAllowed(otpInfoBytes) - if err != nil { - return nil, err - } - - buff, err := handler.getMarshalledOtpData(otp) - if err != nil { - return nil, err - } - - return buff, nil - } - - err = handler.db.UpdateWithCheck(key, checker) - if err != nil { - return err - } - - return nil -} - -// Get returns the one time password -func (handler *dbOTPHandler) Get(account, guardian []byte) (handlers.OTP, error) { - key := computeKey(account, guardian) - oldOTPInfo, err := handler.getOldOTPInfo(key) - if err != nil { - return nil, fmt.Errorf("%w, account %s and guardian %s", err, account, guardian) - } - - return handler.totpHandler.TOTPFromBytes(oldOTPInfo.OTP) -} - -func (handler *dbOTPHandler) getOldOTPInfo(key []byte) (*core.OTPInfo, error) { - oldOTPInfo, err := handler.db.Get(key) - if err != nil { - return nil, err - } - - otpInfo := &core.OTPInfo{} - err = handler.marshaller.Unmarshal(otpInfo, oldOTPInfo) - if err != nil { - return nil, err - } - - return otpInfo, nil -} - -func (handler *dbOTPHandler) getMarshalledOtpData(otp handlers.OTP) ([]byte, error) { - newOtpInfo := &core.OTPInfo{ - LastTOTPChangeTimestamp: handler.getTimeHandler().Unix(), - } - - var err error - newOtpInfo.OTP, err = otp.ToBytes() - if err != nil { - return nil, err - } - - return handler.marshaller.Marshal(newOtpInfo) -} - -func (handler *dbOTPHandler) checkOtpUpdateAllowed(otpInfoBytes []byte) error { - otpInfo := &core.OTPInfo{} - err := handler.marshaller.Unmarshal(otpInfo, otpInfoBytes) - if err != nil { - return err - } - - currentTimestamp := handler.getTimeHandler().Unix() - isOTPUpdateAllowed := otpInfo.LastTOTPChangeTimestamp+handler.delayBetweenOTPUpdatesInSec < currentTimestamp - if !isOTPUpdateAllowed { - return fmt.Errorf("%w, last update was %d seconds ago", - handlers.ErrRegistrationFailed, currentTimestamp-otpInfo.LastTOTPChangeTimestamp) - } - - return nil -} - -func (handler *dbOTPHandler) saveNewOTP(key []byte, otp handlers.OTP) error { - otpInfo := &core.OTPInfo{ - LastTOTPChangeTimestamp: handler.getTimeHandler().Unix(), - } - - var err error - otpInfo.OTP, err = otp.ToBytes() - if err != nil { - return err - } - - buff, err := handler.marshaller.Marshal(otpInfo) - if err != nil { - return err - } - - return handler.db.Put(key, buff) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (handler *dbOTPHandler) IsInterfaceNil() bool { - return handler == nil -} - -func computeKey(account, guardian []byte) []byte { - return []byte(fmt.Sprintf("%s%s%s", guardian, keySeparator, account)) -} diff --git a/handlers/storage/dbOTPHandler_test.go.orig b/handlers/storage/dbOTPHandler_test.go.orig deleted file mode 100644 index bbcab4bc..00000000 --- a/handlers/storage/dbOTPHandler_test.go.orig +++ /dev/null @@ -1,455 +0,0 @@ -package storage_test - -import ( - "errors" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/multiversx/multi-factor-auth-go-service/core" - "github.com/multiversx/multi-factor-auth-go-service/handlers" - "github.com/multiversx/multi-factor-auth-go-service/handlers/storage" - "github.com/multiversx/multi-factor-auth-go-service/testscommon" - "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-core-go/data/mock" - "github.com/stretchr/testify/assert" -) - -var expectedErr = errors.New("expected error") - -func createMockArgs() storage.ArgDBOTPHandler { - return storage.ArgDBOTPHandler{ - DB: testscommon.NewShardedStorageWithIndexMock(), - TOTPHandler: &testscommon.TOTPHandlerStub{}, - Marshaller: &testscommon.MarshallerStub{}, - DelayBetweenOTPUpdatesInSec: 5, - } -} - -func TestNewDBOTPHandler(t *testing.T) { - t.Parallel() - - t.Run("nil db should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.DB = nil - handler, err := storage.NewDBOTPHandler(args) - assert.Equal(t, handlers.ErrNilDB, err) - assert.True(t, check.IfNil(handler)) - }) - t.Run("nil totp handler should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.TOTPHandler = nil - handler, err := storage.NewDBOTPHandler(args) - assert.Equal(t, handlers.ErrNilTOTPHandler, err) - assert.True(t, check.IfNil(handler)) - }) - t.Run("nil marshaller should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.Marshaller = nil - handler, err := storage.NewDBOTPHandler(args) - assert.Equal(t, handlers.ErrNilMarshaller, err) - assert.True(t, check.IfNil(handler)) - }) - t.Run("invalid delay should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.DelayBetweenOTPUpdatesInSec = 0 - handler, err := storage.NewDBOTPHandler(args) - assert.True(t, errors.Is(err, handlers.ErrInvalidValue)) - assert.True(t, strings.Contains(err.Error(), "DelayBetweenOTPUpdatesInSec")) - assert.True(t, check.IfNil(handler)) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - handler, err := storage.NewDBOTPHandler(createMockArgs()) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - }) -} - -func TestDBOTPHandler_Save(t *testing.T) { - t.Parallel() - - t.Run("nil otp should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - err = handler.Save([]byte("account"), []byte("guardian"), nil) - assert.Equal(t, handlers.ErrNilOTP, err) - }) - t.Run("ToBytes returns error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - providedOTP := &testscommon.TotpStub{ - ToBytesCalled: func() ([]byte, error) { - return nil, expectedErr - }, - } - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.Equal(t, expectedErr, err) - }) - t.Run("new account but marshal fails", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.Marshaller = &testscommon.MarshallerStub{ - MarshalCalled: func(obj interface{}) ([]byte, error) { - return nil, expectedErr - }, - } - args.DB = &testscommon.ShardedStorageWithIndexStub{ - PutCalled: func(key, data []byte) error { - assert.Fail(t, "should have not been called") - return nil - }, - HasCalled: func(key []byte) error { - return errors.New("new account") - }, - } - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - providedOTPBytes := []byte("provided otp") - providedOTP := &testscommon.TotpStub{ - ToBytesCalled: func() ([]byte, error) { - return providedOTPBytes, nil - }, - } - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.Equal(t, expectedErr, err) - }) - t.Run("new account but save to db fails", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.DB = &testscommon.ShardedStorageWithIndexStub{ - PutCalled: func(key, data []byte) error { - return expectedErr - }, - HasCalled: func(key []byte) error { - return errors.New("new account") - }, - } - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - providedOTPBytes := []byte("provided otp") - providedOTP := &testscommon.TotpStub{ - ToBytesCalled: func() ([]byte, error) { - return providedOTPBytes, nil - }, - } - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.Equal(t, expectedErr, err) - }) - t.Run("new account should save to db", func(t *testing.T) { - t.Parallel() - - providedOTPBytes := []byte("provided otp") - args := createMockArgs() - wasCalled := false - args.Marshaller = &mock.MarshalizerMock{} - args.DB = &testscommon.ShardedStorageWithIndexStub{ - PutCalled: func(key, val []byte) error { - assert.Equal(t, []byte("guardian_account"), key) - otpInfo := &core.OTPInfo{} - assert.Nil(t, args.Marshaller.Unmarshal(otpInfo, val)) - assert.Equal(t, providedOTPBytes, otpInfo.OTP) - wasCalled = true - return nil - }, - HasCalled: func(key []byte) error { - return errors.New("new account") - }, - } - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - providedOTP := &testscommon.TotpStub{ - ToBytesCalled: func() ([]byte, error) { - return providedOTPBytes, nil - }, - } - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.Nil(t, err) - assert.True(t, wasCalled) - }) - t.Run("old account, get old otp fails", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.DB = &testscommon.ShardedStorageWithIndexStub{ - UpdateWithCheckCalled: func(key []byte, fn func(data interface{}) (interface{}, error)) error { - _, err := args.DB.Get(key) - return err - }, - GetCalled: func(key []byte) ([]byte, error) { - return nil, expectedErr - }, - } - args.Marshaller = &mock.MarshalizerMock{} - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - err = handler.Save([]byte("account"), []byte("guardian"), &testscommon.TotpStub{}) - assert.Equal(t, expectedErr, err) - }) - t.Run("old account, get old otp unmarshal fails", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.DB = &testscommon.ShardedStorageWithIndexStub{ - UpdateWithCheckCalled: func(key []byte, fn func(data interface{}) (interface{}, error)) error { - _, err := fn([]byte("badEncodedData")) - return err - }, - } - args.Marshaller = &testscommon.MarshallerStub{ - UnmarshalCalled: func(obj interface{}, buff []byte) error { - return expectedErr - }, - } - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - err = handler.Save([]byte("account"), []byte("guardian"), &testscommon.TotpStub{}) - assert.Equal(t, expectedErr, err) - }) - t.Run("old account, same guardian, different otp fails - too early", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - args.DelayBetweenOTPUpdatesInSec = 10 - providedOTPBytes := []byte("provided otp") - providedNewOTPBytes := []byte("provided new otp") - toBytesCounter := 0 - providedOTP := &testscommon.TotpStub{ - ToBytesCalled: func() ([]byte, error) { - toBytesCounter++ - if toBytesCounter > 1 { - return providedNewOTPBytes, nil - } - return providedOTPBytes, nil - }, - } - args.DB = testscommon.NewShardedStorageWithIndexMock() - args.Marshaller = &mock.MarshalizerMock{} - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.Nil(t, err) - otpInfoBuff, err := args.DB.Get([]byte("guardian_account")) - assert.Nil(t, err) - otpInfo := &core.OTPInfo{} - err = args.Marshaller.Unmarshal(otpInfo, otpInfoBuff) - assert.Nil(t, err) - assert.Equal(t, providedOTPBytes, otpInfo.OTP) - currentTime := time.Now().Unix() - timeDiff := currentTime - otpInfo.LastTOTPChangeTimestamp - assert.LessOrEqual(t, timeDiff, int64(1)) - - time.Sleep(time.Second) - // second call too early fails - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.True(t, errors.Is(err, handlers.ErrRegistrationFailed)) - otpInfoBuff, err = args.DB.Get([]byte("guardian_account")) - assert.Nil(t, err) - err = args.Marshaller.Unmarshal(otpInfo, otpInfoBuff) - assert.Nil(t, err) - assert.Equal(t, providedOTPBytes, otpInfo.OTP) - currentTime = time.Now().Unix() - timeDiff = currentTime - otpInfo.LastTOTPChangeTimestamp - assert.GreaterOrEqual(t, timeDiff, int64(1)) - }) - t.Run("old account, same guardian, different otp should update and save", func(t *testing.T) { - t.Parallel() - - providedOTPBytes := []byte("provided otp") - providedNewOTPBytes := []byte("provided new otp") - args := createMockArgs() - args.DelayBetweenOTPUpdatesInSec = 1 - args.DB = testscommon.NewShardedStorageWithIndexMock() - args.Marshaller = &mock.MarshalizerMock{} - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - counter := 0 - providedOTP := &testscommon.TotpStub{ - ToBytesCalled: func() ([]byte, error) { - counter++ - if counter > 1 { - return providedNewOTPBytes, nil - } - return providedOTPBytes, nil - }, - } - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.Nil(t, err) - otpInfoBuff, err := args.DB.Get([]byte("guardian_account")) - assert.Nil(t, err) - otpInfo := &core.OTPInfo{} - err = args.Marshaller.Unmarshal(otpInfo, otpInfoBuff) - assert.Nil(t, err) - assert.Equal(t, providedOTPBytes, otpInfo.OTP) - currentTime := time.Now().Unix() - timeDiff := currentTime - otpInfo.LastTOTPChangeTimestamp - assert.LessOrEqual(t, timeDiff, int64(1)) - - time.Sleep(time.Duration(args.DelayBetweenOTPUpdatesInSec+1) * time.Second) - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.Nil(t, err) - otpInfoBuff, err = args.DB.Get([]byte("guardian_account")) - assert.Nil(t, err) - err = args.Marshaller.Unmarshal(otpInfo, otpInfoBuff) - assert.Nil(t, err) - assert.Equal(t, providedNewOTPBytes, otpInfo.OTP) - currentTime = time.Now().Unix() - timeDiff = currentTime - otpInfo.LastTOTPChangeTimestamp - assert.LessOrEqual(t, timeDiff, int64(1)) - }) - t.Run("multiple concurrent calls should work", func(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - t.Parallel() - - args := createMockArgs() - args.DelayBetweenOTPUpdatesInSec = 5 - mockDB := testscommon.NewShardedStorageWithIndexMock() - putCounter := uint32(0) - args.DB = &testscommon.ShardedStorageWithIndexStub{ - HasCalled: func(key []byte) error { - return mockDB.Has(key) - }, - PutCalled: func(key, data []byte) error { - atomic.AddUint32(&putCounter, 1) - return mockDB.Put(key, data) - }, - GetCalled: func(key []byte) ([]byte, error) { - return mockDB.Get(key) - }, - UpdateWithCheckCalled: func(key []byte, fn func(data interface{}) (interface{}, error)) error { - data, err := mockDB.Get(key) - if err != nil { - return err - } - - newData, err := fn(data) - if err != nil { - return err - } - - atomic.AddUint32(&putCounter, 1) - return mockDB.Put(key, newData.([]byte)) - }, - } - args.Marshaller = &mock.MarshalizerMock{} - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - - numCalls := 120 - var wg sync.WaitGroup - wg.Add(numCalls) - for i := 0; i < numCalls; i++ { - go func() { - defer wg.Done() - _ = handler.Save([]byte("account"), []byte("guardian"), &testscommon.TotpStub{}) - }() - // 50 calls/5 sec => 3 times Put called - time.Sleep(time.Millisecond * 100) - } - - wg.Wait() - assert.Equal(t, uint32(3), atomic.LoadUint32(&putCounter)) - }) -} - -func TestDBOTPHandler_Get(t *testing.T) { - t.Parallel() - - t.Run("missing account should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - otp, err := handler.Get([]byte("account2"), []byte("guardian")) - assert.NotNil(t, err) - assert.Nil(t, otp) - }) - t.Run("missing guardian for account should error", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - providedOTPBytes := []byte("provided otp") - providedOTP := &testscommon.TotpStub{ - ToBytesCalled: func() ([]byte, error) { - return providedOTPBytes, nil - }, - } - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.Nil(t, err) - otp, err := handler.Get([]byte("account"), []byte("guardian2")) - assert.NotNil(t, err) - assert.Nil(t, otp) - }) - t.Run("should work", func(t *testing.T) { - t.Parallel() - - args := createMockArgs() - providedOTPBytes := []byte("provided otp") - providedOTP := &testscommon.TotpStub{ - ToBytesCalled: func() ([]byte, error) { - return providedOTPBytes, nil - }, - } - args.TOTPHandler = &testscommon.TOTPHandlerStub{ - TOTPFromBytesCalled: func(encryptedMessage []byte) (handlers.OTP, error) { - return providedOTP, nil - }, - } - handler, err := storage.NewDBOTPHandler(args) - assert.Nil(t, err) - assert.False(t, check.IfNil(handler)) - - err = handler.Save([]byte("account"), []byte("guardian"), providedOTP) - assert.Nil(t, err) - otp, err := handler.Get([]byte("account"), []byte("guardian")) - assert.Nil(t, err) - assert.Equal(t, providedOTP, otp) - }) -} diff --git a/handlers/storage/mongodbStorerWrapper.go.tmp b/handlers/storage/mongodbStorerWrapper.go.tmp deleted file mode 100644 index 1d7a5d01..00000000 --- a/handlers/storage/mongodbStorerWrapper.go.tmp +++ /dev/null @@ -1,77 +0,0 @@ -package storage - -import ( - "github.com/multiversx/multi-factor-auth-go-service/core" - "github.com/multiversx/multi-factor-auth-go-service/mongodb" -) - -// ArgMongoDBStorerWrapper defines the fields needed to create a new storer wrapper -type ArgMongoDBStorerWrapper struct { - Storer core.StorageWithIndex - Client mongodb.MongoDBClient - Marshaller core.Marshaller -} - -type mongodbStorerWrapper struct { - marshaller core.Marshaller - storer core.StorageWithIndex -} - -func NewMongoDBStorerWrapper(args ArgMongoDBStorerWrapper) (*mongodbStorerWrapper, error) { - return &mongodbStorerWrapper{ - marshaller: args.Marshaller, - storer: args.Storer, - }, nil -} - -func (usw *mongodbStorerWrapper) Load(key []byte) (*core.OTPInfo, error) { - otpInfo, err := usw.getFromStorage(key) - if err != nil { - return nil, err - } - - return usw.decrypt(otpInfo) -} - -func (usw *mongodbStorerWrapper) Save(key []byte, otpInfo *core.OTPInfo) error { - encryptedOTPInfo, err := usw.encrypt(otpInfo) - if err != nil { - return err - } - - buff, err := usw.marshaller.Marshal(encryptedOTPInfo) - if err != nil { - return err - } - - return usw.storer.Put(key, buff) -} - -func (usw *mongodbStorerWrapper) getFromStorage(key []byte) (*core.OTPInfo, error) { - oldOTPInfo, err := usw.storer.Get(key) - if err != nil { - return nil, err - } - - otpInfo := &core.OTPInfo{} - err = usw.marshaller.Unmarshal(otpInfo, oldOTPInfo) - if err != nil { - return nil, err - } - - return otpInfo, nil -} - -// TODO: implement encryption -func (usw *mongodbStorerWrapper) decrypt(otpInfo *core.OTPInfo) (*core.OTPInfo, error) { - return otpInfo, nil -} - -func (usw *mongodbStorerWrapper) encrypt(otpInfo *core.OTPInfo) (*core.OTPInfo, error) { - return otpInfo, nil -} - -// IsInterfaceNil return true if there is no value under the interface -func (usw *mongodbStorerWrapper) IsInterfaceNil() bool { - return usw == nil -} From e5651b8704dfbe5df82de4be05bd1167d082c325 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Thu, 13 Apr 2023 11:12:25 +0300 Subject: [PATCH 34/35] restore bucket index handler changes --- handlers/storage/bucket/bucketIndexHandler.go | 55 ++++++------------- 1 file changed, 16 insertions(+), 39 deletions(-) diff --git a/handlers/storage/bucket/bucketIndexHandler.go b/handlers/storage/bucket/bucketIndexHandler.go index 5af2b2f6..ac7fc822 100644 --- a/handlers/storage/bucket/bucketIndexHandler.go +++ b/handlers/storage/bucket/bucketIndexHandler.go @@ -14,9 +14,8 @@ const ( ) type bucketIndexHandler struct { - bucket core.Storer - bucketMut sync.RWMutex - bucketOpMut sync.RWMutex + bucket core.Storer + mut sync.RWMutex } // NewBucketIndexHandler returns a new instance of a bucket index handler @@ -34,7 +33,7 @@ func NewBucketIndexHandler(bucket core.Storer) (*bucketIndexHandler, error) { return handler, nil } - err = saveNewIndex(handler.bucket, 0) + err = handler.saveNewIndex(0) if err != nil { return nil, err } @@ -44,17 +43,17 @@ func NewBucketIndexHandler(bucket core.Storer) (*bucketIndexHandler, error) { // AllocateBucketIndex allocates a new index and returns it func (handler *bucketIndexHandler) AllocateBucketIndex() (uint32, error) { - handler.bucketMut.Lock() - defer handler.bucketMut.Unlock() + handler.mut.Lock() + defer handler.mut.Unlock() - index, err := getIndex(handler.bucket) + index, err := handler.getIndex() if err != nil { return 0, err } index++ - return index, saveNewIndex(handler.bucket, index) + return index, handler.saveNewIndex(index) } // Put adds data to the bucket @@ -72,47 +71,25 @@ func (handler *bucketIndexHandler) Has(key []byte) error { return handler.bucket.Has(key) } -// UpdateWithCheck will update key value pair based on callback function -func (handler *bucketIndexHandler) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { - handler.bucketOpMut.Lock() - defer handler.bucketOpMut.Unlock() - - data, err := handler.bucket.Get(key) - if err != nil { - return nil - } - - newData, err := fn(data) - if err != nil { - return err - } - newDataBytes, ok := newData.([]byte) - if !ok { - return core.ErrInvalidValue - } - - return handler.bucket.Put(key, newDataBytes) -} - // GetLastIndex returns the last index that was allocated func (handler *bucketIndexHandler) GetLastIndex() (uint32, error) { - handler.bucketMut.RLock() - defer handler.bucketMut.RUnlock() + handler.mut.RLock() + defer handler.mut.RUnlock() - return getIndex(handler.bucket) + return handler.getIndex() } // Close closes the internal bucket func (handler *bucketIndexHandler) Close() error { - handler.bucketMut.Lock() - defer handler.bucketMut.Unlock() + handler.mut.Lock() + defer handler.mut.Unlock() return handler.bucket.Close() } // must be called under mutex protection -func getIndex(storer core.Storer) (uint32, error) { - lastIndexBytes, err := storer.Get([]byte(lastIndexKey)) +func (handler *bucketIndexHandler) getIndex() (uint32, error) { + lastIndexBytes, err := handler.bucket.Get([]byte(lastIndexKey)) if err != nil { return 0, err } @@ -121,10 +98,10 @@ func getIndex(storer core.Storer) (uint32, error) { } // must be called under mutex protection -func saveNewIndex(storer core.Storer, newIndex uint32) error { +func (handler *bucketIndexHandler) saveNewIndex(newIndex uint32) error { latestIndexBytes := make([]byte, uint32Bytes) binary.BigEndian.PutUint32(latestIndexBytes, newIndex) - return storer.Put([]byte(lastIndexKey), latestIndexBytes) + return handler.bucket.Put([]byte(lastIndexKey), latestIndexBytes) } // IsInterfaceNil returns true if there is no value under the interface From 2e20459c873639533aace8d8e6f3f1ebb8429d09 Mon Sep 17 00:00:00 2001 From: ssd04 Date: Thu, 13 Apr 2023 11:18:28 +0300 Subject: [PATCH 35/35] cleanup unused code --- mongodb/dbClient.go | 1 - mongodb/interface.go | 27 ---------------------- testscommon/bucketIndexHandlerStub.go | 10 -------- testscommon/shardedStorageWithIndexMock.go | 14 ----------- testscommon/shardedStorageWithIndexStub.go | 10 -------- 5 files changed, 62 deletions(-) diff --git a/mongodb/dbClient.go b/mongodb/dbClient.go index 41cce350..93889ff4 100644 --- a/mongodb/dbClient.go +++ b/mongodb/dbClient.go @@ -28,7 +28,6 @@ const ( IndexCollectionID CollectionID = "index" ) -const initialCounterValue = 1 const numInitialShardChunks = 4 const incrementIndexStep = 1 const minNumUsersColls = 1 diff --git a/mongodb/interface.go b/mongodb/interface.go index 1cbe7687..52a68b88 100644 --- a/mongodb/interface.go +++ b/mongodb/interface.go @@ -1,13 +1,5 @@ package mongodb -import ( - "context" - - "github.com/multiversx/multi-factor-auth-go-service/core" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" -) - // MongoDBClient defines what a mongodb client should do type MongoDBClient interface { Put(coll CollectionID, key []byte, data []byte) error @@ -21,22 +13,3 @@ type MongoDBClient interface { Close() error IsInterfaceNil() bool } - -// MongoDBUsersHandler defines the behaviour of a mongo users handler component -type MongoDBUsersHandler interface { - MongoDBClient - UpdateTimestamp(collID CollectionID, key []byte, interval int64) (int64, error) - PutStruct(collID CollectionID, key []byte, data *core.OTPInfo) error - GetStruct(collID CollectionID, key []byte) (*core.OTPInfo, error) - HasStruct(collID CollectionID, key []byte) error -} - -// MongoDBSession defines what a mongodb session should do -type MongoDBSession interface { - StartTransaction(...*options.TransactionOptions) error - AbortTransaction(context.Context) error - CommitTransaction(context.Context) error - WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), - opts ...*options.TransactionOptions) (interface{}, error) - EndSession(context.Context) -} diff --git a/testscommon/bucketIndexHandlerStub.go b/testscommon/bucketIndexHandlerStub.go index 74351094..a69413b7 100644 --- a/testscommon/bucketIndexHandlerStub.go +++ b/testscommon/bucketIndexHandlerStub.go @@ -8,7 +8,6 @@ type BucketIndexHandlerStub struct { CloseCalled func() error AllocateBucketIndexCalled func() (uint32, error) GetLastIndexCalled func() (uint32, error) - UpdateWithCheckCalled func(key []byte, fn func(data interface{}) (interface{}, error)) error } // Put - @@ -59,15 +58,6 @@ func (stub *BucketIndexHandlerStub) GetLastIndex() (uint32, error) { return 0, nil } -// UpdateWithCheck - -func (stub *BucketIndexHandlerStub) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { - if stub.UpdateWithCheckCalled != nil { - return stub.UpdateWithCheckCalled(key, fn) - } - - return nil -} - // IsInterfaceNil - func (stub *BucketIndexHandlerStub) IsInterfaceNil() bool { return stub == nil diff --git a/testscommon/shardedStorageWithIndexMock.go b/testscommon/shardedStorageWithIndexMock.go index c9b0c214..3f451ecc 100644 --- a/testscommon/shardedStorageWithIndexMock.go +++ b/testscommon/shardedStorageWithIndexMock.go @@ -66,20 +66,6 @@ func (mock *shardedStorageWithIndexMock) Count() (uint32, error) { return uint32(len(mock.cache)), nil } -func (mock *shardedStorageWithIndexMock) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { - data, err := mock.Get(key) - if err != nil { - return err - } - - newData, err := fn(data) - if err != nil { - return err - } - - return mock.Put(key, newData.([]byte)) -} - // IsInterfaceNil - func (mock *shardedStorageWithIndexMock) IsInterfaceNil() bool { return mock == nil diff --git a/testscommon/shardedStorageWithIndexStub.go b/testscommon/shardedStorageWithIndexStub.go index 6d07deb4..f13cae93 100644 --- a/testscommon/shardedStorageWithIndexStub.go +++ b/testscommon/shardedStorageWithIndexStub.go @@ -9,7 +9,6 @@ type ShardedStorageWithIndexStub struct { CloseCalled func() error AllocateBucketIndexCalled func(address []byte) (uint32, error) CountCalled func() (uint32, error) - UpdateWithCheckCalled func(key []byte, fn func(data interface{}) (interface{}, error)) error } // AllocateIndex - @@ -68,15 +67,6 @@ func (stub *ShardedStorageWithIndexStub) Count() (uint32, error) { return 0, nil } -// UpdateWithCheck - -func (stub *ShardedStorageWithIndexStub) UpdateWithCheck(key []byte, fn func(data interface{}) (interface{}, error)) error { - if stub.UpdateWithCheckCalled != nil { - return stub.UpdateWithCheckCalled(key, fn) - } - - return nil -} - // IsInterfaceNil - func (stub *ShardedStorageWithIndexStub) IsInterfaceNil() bool { return stub == nil