diff --git a/api/groups/common_test.go b/api/groups/common_test.go index 5f6f77b8..f50e148d 100644 --- a/api/groups/common_test.go +++ b/api/groups/common_test.go @@ -48,6 +48,7 @@ func getServiceRoutesConfig() config.ApiRoutesConfig { {Name: "/sign-multiple-transactions", Open: true}, {Name: "/set-security-mode", Open: true}, {Name: "/unset-security-mode", Open: true}, + {Name: "/user-status/:address", Open: true}, {Name: "/debug", Open: true}, {Name: "/verify-code", Open: true}, {Name: "/registered-users", Open: true}, diff --git a/api/groups/guardianGroup.go b/api/groups/guardianGroup.go index 8d7768e0..bde9a35b 100644 --- a/api/groups/guardianGroup.go +++ b/api/groups/guardianGroup.go @@ -30,6 +30,7 @@ const ( signMultipleTransactionsPath = "/sign-multiple-transactions" setSecurityModeNoExpirePath = "/set-security-mode" unsetSecurityModeNoExpirePath = "/unset-security-mode" + getUserStatusPath = "/user-status/:address" registerPath = "/register" verifyCodePath = "/verify-code" registeredUsersPath = "/registered-users" @@ -103,6 +104,11 @@ func NewGuardianGroup(facade shared.FacadeHandler) (*guardianGroup, error) { Method: http.MethodGet, Handler: gg.config, }, + { + Path: getUserStatusPath, + Method: http.MethodGet, + Handler: gg.getUserStatus, + }, } gg.endpoints = endpoints @@ -237,6 +243,47 @@ func (gg *guardianGroup) unsetSecurityModeNoExpire(c *gin.Context) { returnStatus(c, nil, http.StatusOK, "", chainApiShared.ReturnCodeSuccess) } +func (gg *guardianGroup) getUserStatus(c *gin.Context) { + var debugErr error + + userIp := c.GetString(mfaMiddleware.UserIpKey) + userAgent := c.GetString(mfaMiddleware.UserAgentKey) + userAddr := c.Param("address") + + defer func() { + logUserStatusRequest(userIp, userAgent, userAddr, debugErr) + }() + + status, err := gg.facade.GetUserStatus(userAddr) + if err != nil { + debugErr = fmt.Errorf("%w while interrogating security status", err) + handleErrorAndReturn(c, status, err.Error()) + return + + } + + returnStatus(c, status, http.StatusOK, "", chainApiShared.ReturnCodeSuccess) +} + +func logUserStatusRequest(userIp string, userAgent string, userAddr string, debugErr error) { + logArgs := []interface{}{ + "route", getUserStatusPath, + "ip", userIp, + "user agent", userAgent, + "userAddr", userAddr, + } + defer func() { + guardianLog.Info("Request info", logArgs...) + }() + + if debugErr == nil { + logArgs = append(logArgs, "result", "success") + return + } + + logArgs = append(logArgs, "error", debugErr.Error()) +} + // signTransaction returns the transaction signed by the guardian if the verification passed func (gg *guardianGroup) signTransaction(c *gin.Context) { var request requests.SignTransaction diff --git a/api/groups/guardianGroup_test.go b/api/groups/guardianGroup_test.go index 87b68e16..32c831ca 100644 --- a/api/groups/guardianGroup_test.go +++ b/api/groups/guardianGroup_test.go @@ -516,6 +516,70 @@ func TestGuardianGroup_UnsetSecurityModeNoExpire(t *testing.T) { }) } +func TestGuardianGroup_getUserStatus(t *testing.T) { + t.Parallel() + + t.Run("facade returns error", func(t *testing.T) { + t.Parallel() + + facade := mockFacade.GuardianFacadeStub{ + GetUserStatusCalled: func(userAddress string) (*requests.UserStatusResponse, error) { + return &requests.UserStatusResponse{SecurityModeStatus: -1}, expectedError + }, + } + + gg, _ := groups.NewGuardianGroup(&facade) + + ws := startWebServer(gg, "guardian", getServiceRoutesConfig(), providedAddr) + + req, _ := http.NewRequest("GET", "/guardian/user-status/"+providedAddr, nil) + resp := httptest.NewRecorder() + ws.ServeHTTP(resp, req) + + statusRsp := generalResponse{} + loadResponse(resp.Body, &statusRsp) + + expectedGenResponse := createExpectedGeneralResponse(&requests.UserStatusResponse{ + SecurityModeStatus: -1, + }, "") + + assert.Equal(t, expectedGenResponse.Data, statusRsp.Data) + assert.True(t, strings.Contains(statusRsp.Error, expectedError.Error())) + require.Equal(t, http.StatusInternalServerError, resp.Code) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + facade := mockFacade.GuardianFacadeStub{ + GetUserStatusCalled: func(userAddress string) (*requests.UserStatusResponse, error) { + return &requests.UserStatusResponse{ + SecurityModeStatus: 1, + }, nil + }, + } + + gg, _ := groups.NewGuardianGroup(&facade) + + ws := startWebServer(gg, "guardian", getServiceRoutesConfig(), providedAddr) + + req, _ := http.NewRequest("GET", "/guardian/user-status/"+providedAddr, nil) + resp := httptest.NewRecorder() + ws.ServeHTTP(resp, req) + + statusRsp := generalResponse{} + loadResponse(resp.Body, &statusRsp) + + expectedGenResponse := createExpectedGeneralResponse(&requests.UserStatusResponse{ + SecurityModeStatus: 1, + }, "") + + assert.Equal(t, expectedGenResponse.Data, statusRsp.Data) + assert.Equal(t, "", statusRsp.Error) + require.Equal(t, http.StatusOK, resp.Code) + }) +} + func TestGuardianGroup_signMultipleTransaction(t *testing.T) { t.Parallel() diff --git a/api/middleware/nativeAuthWhitelistHandler.go b/api/middleware/nativeAuthWhitelistHandler.go index 74fd3b0f..9b630c55 100644 --- a/api/middleware/nativeAuthWhitelistHandler.go +++ b/api/middleware/nativeAuthWhitelistHandler.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "strings" "github.com/multiversx/mx-multi-factor-auth-go-service/config" ) @@ -20,7 +21,8 @@ func NewNativeAuthWhitelistHandler(apiPackages map[string]config.APIPackageConfi for _, route := range groupCfg.Routes { if !route.Auth { fullPath := fmt.Sprintf("%s%s", groupPath, route.Name) - whitelistedRoutes[fullPath] = struct{}{} + basePath := trimPathPlaceholder(fullPath) + whitelistedRoutes[basePath] = struct{}{} } } } @@ -31,9 +33,29 @@ func NewNativeAuthWhitelistHandler(apiPackages map[string]config.APIPackageConfi } } +func trimPathPlaceholder(path string) string { + parts := strings.Split(path, ":") + if len(parts) > 0 { + return "/" + strings.Trim(parts[0], "/") + } + + return path +} + +func extractBaseRoutePath(path string) string { + parts := strings.Split(path, "/") + + if len(parts) > 2 { + return "/" + parts[1] + "/" + parts[2] // group and base path + } + + return path +} + // IsWhitelisted returns true if the provided route is whitelisted for native authentication func (handler *nativeAuthWhitelistHandler) IsWhitelisted(route string) bool { - _, found := handler.whitelistedRoutesMap[route] + baseRoute := extractBaseRoutePath(route) + _, found := handler.whitelistedRoutesMap[baseRoute] return found } diff --git a/api/middleware/nativeAuthWhitelistHandler_test.go b/api/middleware/nativeAuthWhitelistHandler_test.go index ea08dd1c..08abfd0b 100644 --- a/api/middleware/nativeAuthWhitelistHandler_test.go +++ b/api/middleware/nativeAuthWhitelistHandler_test.go @@ -23,6 +23,11 @@ func TestNativeAuthWhitelistHandler(t *testing.T) { Open: true, Auth: false, }, + { + Name: "/user-status/:address", + Open: true, + Auth: false, + }, }, }, "status": { @@ -43,9 +48,13 @@ func TestNativeAuthWhitelistHandler(t *testing.T) { require.True(t, handler.IsWhitelisted("/guardian")) require.True(t, handler.IsWhitelisted("/status")) require.True(t, handler.IsWhitelisted("/log")) + require.True(t, handler.IsWhitelisted("/guardian/user-status/erd1")) + require.True(t, handler.IsWhitelisted("/guardian/user-status")) + require.True(t, handler.IsWhitelisted("/guardian/user-status/")) require.False(t, handler.IsWhitelisted("/guardian/register")) require.False(t, handler.IsWhitelisted("guardian/sign-transaction")) require.False(t, handler.IsWhitelisted("/sign-transaction")) + require.False(t, handler.IsWhitelisted("guardian/user-status/erd1")) require.False(t, handler.IsWhitelisted("guardian")) require.False(t, handler.IsWhitelisted("")) } diff --git a/api/shared/interface.go b/api/shared/interface.go index ada45584..110bf477 100644 --- a/api/shared/interface.go +++ b/api/shared/interface.go @@ -28,6 +28,7 @@ type FacadeHandler interface { SignMultipleTransactions(userIp string, request requests.SignMultipleTransactions) ([][]byte, *requests.OTPCodeVerifyData, error) SetSecurityModeNoExpire(userIp string, request requests.SecurityModeNoExpire) (*requests.OTPCodeVerifyData, error) UnsetSecurityModeNoExpire(userIp string, request requests.SecurityModeNoExpire) (*requests.OTPCodeVerifyData, error) + GetUserStatus(userAddress string) (*requests.UserStatusResponse, error) RegisteredUsers() (uint32, error) TcsConfig() *tcsCore.TcsConfig GetMetrics() map[string]*requests.EndpointMetricsResponse diff --git a/cmd/multi-factor-auth/config/api.toml b/cmd/multi-factor-auth/config/api.toml index ec02bf86..78d7fa05 100644 --- a/cmd/multi-factor-auth/config/api.toml +++ b/cmd/multi-factor-auth/config/api.toml @@ -21,6 +21,7 @@ RestApiInterface = ":8080" # The interface `address and port` to which the REST { Name = "/sign-multiple-transactions", Open = true, Auth = false, MaxContentLength = 1500000 }, { Name = "/set-security-mode", Open = true, Auth = false, MaxContentLength = 200 }, { Name = "/unset-security-mode", Open = true, Auth = false, MaxContentLength = 200 }, + { Name = "/user-status/:address", Open = true, Auth = false }, { Name = "/verify-code", Open = true, Auth = true, MaxContentLength = 200 }, { Name = "/registered-users", Open = true, Auth = false }, { Name = "/config", Open = true, Auth = false }, diff --git a/cmd/multi-factor-auth/swagger/data.go b/cmd/multi-factor-auth/swagger/data.go index 27aeffe0..8e1ccfa0 100644 --- a/cmd/multi-factor-auth/swagger/data.go +++ b/cmd/multi-factor-auth/swagger/data.go @@ -289,3 +289,33 @@ type _ struct { // required:true Payload requests.SecurityModeNoExpire } + +// swagger:route GET /user-status/{address} Guardian getUsersStatus +// Returns the status of the user. +// This request does not need the Authorization header +// +// responses: +// 200: userStatusResponse + +// swagger:parameters getUsersStatus +type _ struct { + // address of an account + // + // in:path + // required:true + Address string `json:"address"` +} + +// The security status of the user +// swagger:response userStatusResponse +type _ struct { + // in:body + Body struct { + // UserStatusResponse + Data requests.UserStatusResponse `json:"data"` + // HTTP status code + Status string `json:"status"` + // Internal error + Error string `json:"error"` + } +} diff --git a/cmd/multi-factor-auth/swagger/ui/swagger.json b/cmd/multi-factor-auth/swagger/ui/swagger.json index 2cc032bb..49102339 100644 --- a/cmd/multi-factor-auth/swagger/ui/swagger.json +++ b/cmd/multi-factor-auth/swagger/ui/swagger.json @@ -207,6 +207,31 @@ } } }, + "/user-status/{address}": { + "get": { + "description": "This request does not need the Authorization header", + "tags": [ + "Guardian" + ], + "summary": "Returns the status of the user.", + "operationId": "getUsersStatus", + "parameters": [ + { + "type": "string", + "x-go-name": "Address", + "description": "address of an account", + "name": "address", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "$ref": "#/responses/userStatusResponse" + } + } + } + }, "/verify-code": { "post": { "security": [ @@ -321,6 +346,14 @@ }, "x-go-name": "ReceiverUsername" }, + "relayer": { + "type": "string", + "x-go-name": "RelayerAddr" + }, + "relayerSignature": { + "type": "string", + "x-go-name": "RelayerSignature" + }, "sender": { "type": "string", "x-go-name": "Sender" @@ -599,6 +632,18 @@ }, "x-go-package": "github.com/multiversx/mx-multi-factor-auth-go-service/core/requests" }, + "UserStatusResponse": { + "description": "UserStatusResponse is the JSON response for the user status interrogation", + "type": "object", + "properties": { + "security-mode-status": { + "type": "integer", + "format": "int64", + "x-go-name": "SecurityModeStatus" + } + }, + "x-go-package": "github.com/multiversx/mx-multi-factor-auth-go-service/core/requests" + }, "VerificationPayload": { "description": "VerificationPayload represents the JSON requests a user uses to validate the authentication code", "type": "object", @@ -782,6 +827,27 @@ } } }, + "userStatusResponse": { + "description": "The security status of the user", + "schema": { + "type": "object", + "properties": { + "data": { + "$ref": "#/definitions/UserStatusResponse" + }, + "error": { + "description": "Internal error", + "type": "string", + "x-go-name": "Error" + }, + "status": { + "description": "HTTP status code", + "type": "string", + "x-go-name": "Status" + } + } + } + }, "verifyCodeResponse": { "description": "Verification result", "schema": { diff --git a/core/constants.go b/core/constants.go index a6ef2c91..b291891f 100644 --- a/core/constants.go +++ b/core/constants.go @@ -30,3 +30,15 @@ const ( // NoExpiryValue is the returned value for a persistent key expiry time const NoExpiryValue = -1 + +// EnhancedSecurityModeStatus represents the status of the security mode +type EnhancedSecurityModeStatus int + +const ( + // NotSet means that security mode is not activated + NotSet EnhancedSecurityModeStatus = iota + // ManuallySet means that security mode was activated because of failures + ManuallySet + // AutomaticallySet means that security mode was activated by user with SetSecurityModeNoExpire + AutomaticallySet +) diff --git a/core/interface.go b/core/interface.go index 8013a7fe..d19f3a48 100644 --- a/core/interface.go +++ b/core/interface.go @@ -37,6 +37,7 @@ type ServiceResolver interface { SignMessage(userIp string, request requests.SignMessage) ([]byte, *requests.OTPCodeVerifyData, error) SetSecurityModeNoExpire(userIp string, request requests.SecurityModeNoExpire) (*requests.OTPCodeVerifyData, error) UnsetSecurityModeNoExpire(userIp string, request requests.SecurityModeNoExpire) (*requests.OTPCodeVerifyData, error) + GetUserStatus(userAddr string) (*requests.UserStatusResponse, error) SignTransaction(userIp string, request requests.SignTransaction) ([]byte, *requests.OTPCodeVerifyData, error) SignMultipleTransactions(userIp string, request requests.SignMultipleTransactions) ([][]byte, *requests.OTPCodeVerifyData, error) RegisteredUsers() (uint32, error) diff --git a/core/requests/types.go b/core/requests/types.go index 99695760..54a66ab5 100644 --- a/core/requests/types.go +++ b/core/requests/types.go @@ -24,6 +24,11 @@ type SecurityModeNoExpire struct { UserAddr string `json:"user"` } +// UserStatusResponse is the JSON response for the user status interrogation +type UserStatusResponse struct { + SecurityModeStatus int `json:"security-mode-status"` +} + // SignMessageResponse is the service response to the sign message request type SignMessageResponse struct { Message string `json:"message"` diff --git a/facade/guardianFacade.go b/facade/guardianFacade.go index 6bccac7e..04693047 100644 --- a/facade/guardianFacade.go +++ b/facade/guardianFacade.go @@ -65,6 +65,11 @@ func (gf *guardianFacade) UnsetSecurityModeNoExpire(userIp string, request reque return gf.serviceResolver.UnsetSecurityModeNoExpire(userIp, request) } +// GetUserStatus returns the user's security status +func (gf *guardianFacade) GetUserStatus(userAddress string) (*requests.UserStatusResponse, error) { + return gf.serviceResolver.GetUserStatus(userAddress) +} + // SignMultipleTransactions validates user's transactions, then adds guardian signature and returns the transaction func (gf *guardianFacade) SignMultipleTransactions(userIp string, request requests.SignMultipleTransactions) ([][]byte, *requests.OTPCodeVerifyData, error) { return gf.serviceResolver.SignMultipleTransactions(userIp, request) diff --git a/facade/guardianFacade_test.go b/facade/guardianFacade_test.go index cd276ea5..3e9ee7db 100644 --- a/facade/guardianFacade_test.go +++ b/facade/guardianFacade_test.go @@ -113,6 +113,10 @@ func TestGuardianFacade_Getters(t *testing.T) { } wasUnsetSecurityModeNoExpireCalled := false + providedUserAddr := "erd1qyu5wthldzr8wx5c9ucg8kjagg0jfs53s8nr3zpz3hypefsdd8ssycr6th" + + expectedUserStatusResponse := requests.UserStatusResponse{SecurityModeStatus: 1} + args.ServiceResolver = &testscommon.ServiceResolverStub{ VerifyCodeCalled: func(userAddress sdkCore.AddressHandler, userIp string, request requests.VerificationPayload) (*requests.OTPCodeVerifyData, error) { assert.Equal(t, providedVerifyCodeReq, request) @@ -142,6 +146,10 @@ func TestGuardianFacade_Getters(t *testing.T) { wasUnsetSecurityModeNoExpireCalled = true return nil, nil }, + GetUserStatusCalled: func(userAddress string) (*requests.UserStatusResponse, error) { + assert.Equal(t, providedUserAddr, providedUserAddr) + return &expectedUserStatusResponse, nil + }, SignTransactionCalled: func(userIp string, request requests.SignTransaction) ([]byte, *requests.OTPCodeVerifyData, error) { assert.Equal(t, providedIp, userIp) assert.Equal(t, providedSignTxReq, request) @@ -207,6 +215,10 @@ func TestGuardianFacade_Getters(t *testing.T) { assert.Nil(t, err) assert.True(t, wasUnsetSecurityModeNoExpireCalled) + userStatus, err := facadeInstance.GetUserStatus(providedUserAddr) + assert.Nil(t, err) + assert.Equal(t, &expectedUserStatusResponse, userStatus) + signedTxs, _, err := facadeInstance.SignMultipleTransactions(providedIp, providedSignMultipleTxsReq) assert.Nil(t, err) assert.Equal(t, expectedSignMultipleTxsResponse, signedTxs) diff --git a/handlers/interface.go b/handlers/interface.go index a5f604df..4f3b749c 100644 --- a/handlers/interface.go +++ b/handlers/interface.go @@ -22,6 +22,7 @@ type SecureOtpHandler interface { SecurityModeMaxFailures() uint64 SetSecurityModeNoExpire(key string) error UnsetSecurityModeNoExpire(key string) error + GetSecurityStatus(key string) core.EnhancedSecurityModeStatus IsVerificationAllowedAndIncreaseTrials(account string, ip string) (*requests.OTPCodeVerifyData, error) Reset(account string, ip string) DecrementSecurityModeFailedTrials(account string) error diff --git a/handlers/secureOtp/secureOtpHandler.go b/handlers/secureOtp/secureOtpHandler.go index 8c9ed673..5d1a0114 100644 --- a/handlers/secureOtp/secureOtpHandler.go +++ b/handlers/secureOtp/secureOtpHandler.go @@ -114,6 +114,11 @@ func (totp *secureOtpHandler) UnsetSecurityModeNoExpire(key string) error { return totp.rateLimiter.UnsetSecurityModeNoExpire(key) } +// GetSecurityStatus returns the status of the security mode +func (totp *secureOtpHandler) GetSecurityStatus(key string) core.EnhancedSecurityModeStatus { + return totp.rateLimiter.GetSecurityStatus(key) +} + // Reset removes the account and ip from local cache func (totp *secureOtpHandler) Reset(account string, ip string) { key := computeVerificationKey(account, ip) diff --git a/handlers/secureOtp/secureOtpHandler_test.go b/handlers/secureOtp/secureOtpHandler_test.go index 4e140129..922afc98 100644 --- a/handlers/secureOtp/secureOtpHandler_test.go +++ b/handlers/secureOtp/secureOtpHandler_test.go @@ -412,6 +412,21 @@ func TestSecureOtpHandler_UnsetSecurityModeNoExpireShouldErr(t *testing.T) { require.Equal(t, expectedErr, err) } +func TestSecureOtpHandler_GetSecurityStatusShouldWork(t *testing.T) { + t.Parallel() + + args := createMockArgsSecureOtpHandler() + args.RateLimiter = &testscommon.RateLimiterStub{ + GetSecurityStatusCalled: func(key string) core.EnhancedSecurityModeStatus { + return core.NotSet + }, + } + totp, _ := secureOtp.NewSecureOtpHandler(args) + require.NotNil(t, totp) + + require.Equal(t, core.NotSet, totp.GetSecurityStatus(account)) +} + func TestSecureOtpHandler_Getters(t *testing.T) { t.Parallel() diff --git a/integrationTests/otpRateLimiting_test.go b/integrationTests/otpRateLimiting_test.go index f7b3a1b4..6469af49 100644 --- a/integrationTests/otpRateLimiting_test.go +++ b/integrationTests/otpRateLimiting_test.go @@ -421,12 +421,14 @@ func TestSecurityMode(t *testing.T) { securityModePeriodLimit := 86400 t.Run("test set security mode", func(t *testing.T) { - secureOtpHandler, redisServer := createRateLimiter(t, maxFailures, periodLimit, securityModeMaxFailures, securityModePeriodLimit) userAddress := "addr2" userIp := "ip2" + status := secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + otpVerifyData, err := secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData := &requests.OTPCodeVerifyData{ @@ -437,6 +439,9 @@ func TestSecurityMode(t *testing.T) { } require.Equal(t, expOtpVerifyData, otpVerifyData) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + err = secureOtpHandler.SetSecurityModeNoExpire(userAddress) require.Nil(t, err) @@ -451,6 +456,9 @@ func TestSecurityMode(t *testing.T) { } require.Equal(t, expOtpVerifyData, otpVerifyData) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.ManuallySet, status) + redisServer.FastForward(time.Second * time.Duration(expOtpVerifyData.ResetAfter)) otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) @@ -461,6 +469,10 @@ func TestSecurityMode(t *testing.T) { SecurityModeResetAfter: -1, } require.Equal(t, expOtpVerifyData, otpVerifyData) + + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.ManuallySet, status) + }) t.Run("test unset when security mode is activated by user", func(t *testing.T) { @@ -469,6 +481,9 @@ func TestSecurityMode(t *testing.T) { userAddress := "addr2" userIp := "ip2" + status := secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + otpVerifyData, err := secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData := &requests.OTPCodeVerifyData{ @@ -482,6 +497,9 @@ func TestSecurityMode(t *testing.T) { err = secureOtpHandler.SetSecurityModeNoExpire(userAddress) require.Nil(t, err) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.ManuallySet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData = &requests.OTPCodeVerifyData{ @@ -492,9 +510,15 @@ func TestSecurityMode(t *testing.T) { } require.Equal(t, expOtpVerifyData, otpVerifyData) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.ManuallySet, status) + err = secureOtpHandler.UnsetSecurityModeNoExpire(userAddress) require.Nil(t, err) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData = &requests.OTPCodeVerifyData{ @@ -513,6 +537,9 @@ func TestSecurityMode(t *testing.T) { userAddress := "addr2" userIp := "ip2" + status := secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + otpVerifyData, err := secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData := &requests.OTPCodeVerifyData{ @@ -523,6 +550,9 @@ func TestSecurityMode(t *testing.T) { } require.Equal(t, expOtpVerifyData, otpVerifyData) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData = &requests.OTPCodeVerifyData{ @@ -533,6 +563,9 @@ func TestSecurityMode(t *testing.T) { } require.Equal(t, expOtpVerifyData, otpVerifyData) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData = &requests.OTPCodeVerifyData{ @@ -543,6 +576,9 @@ func TestSecurityMode(t *testing.T) { } require.Equal(t, expOtpVerifyData, otpVerifyData) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.AutomaticallySet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.NotNil(t, otpVerifyData) require.Equal(t, core.ErrTooManyFailedAttempts, err) @@ -550,6 +586,9 @@ func TestSecurityMode(t *testing.T) { err = secureOtpHandler.UnsetSecurityModeNoExpire(userAddress) require.Nil(t, err) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.AutomaticallySet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.NotNil(t, otpVerifyData) require.Equal(t, core.ErrTooManyFailedAttempts, err) @@ -571,9 +610,15 @@ func TestSecurityMode(t *testing.T) { } require.Equal(t, expOtpVerifyData, otpVerifyData) + status := secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + err = secureOtpHandler.SetSecurityModeNoExpire(userAddress) require.Nil(t, err) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.ManuallySet, status) + redisServer.FastForward(time.Second * time.Duration(expOtpVerifyData.ResetAfter)) otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) @@ -588,6 +633,9 @@ func TestSecurityMode(t *testing.T) { err = secureOtpHandler.SetSecurityModeNoExpire(userAddress) require.Nil(t, err) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.ManuallySet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData = &requests.OTPCodeVerifyData{ @@ -597,6 +645,9 @@ func TestSecurityMode(t *testing.T) { SecurityModeResetAfter: -1, } require.Equal(t, expOtpVerifyData, otpVerifyData) + + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.ManuallySet, status) }) t.Run("test unset multiple times", func(t *testing.T) { @@ -615,9 +666,15 @@ func TestSecurityMode(t *testing.T) { } require.Equal(t, expOtpVerifyData, otpVerifyData) + status := secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + err = secureOtpHandler.SetSecurityModeNoExpire(userAddress) require.Nil(t, err) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.ManuallySet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData = &requests.OTPCodeVerifyData{ @@ -631,6 +688,9 @@ func TestSecurityMode(t *testing.T) { err = secureOtpHandler.UnsetSecurityModeNoExpire(userAddress) require.Nil(t, err) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Nil(t, err) expOtpVerifyData = &requests.OTPCodeVerifyData{ @@ -644,6 +704,9 @@ func TestSecurityMode(t *testing.T) { err = secureOtpHandler.UnsetSecurityModeNoExpire(userAddress) require.Nil(t, err) + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.NotSet, status) + otpVerifyData, err = secureOtpHandler.IsVerificationAllowedAndIncreaseTrials(userAddress, userIp) require.Equal(t, core.ErrTooManyFailedAttempts, err) expOtpVerifyData = &requests.OTPCodeVerifyData{ @@ -653,6 +716,9 @@ func TestSecurityMode(t *testing.T) { SecurityModeResetAfter: 86400, } require.Equal(t, expOtpVerifyData, otpVerifyData) + + status = secureOtpHandler.GetSecurityStatus(userAddress) + require.Equal(t, core.AutomaticallySet, status) }) } diff --git a/redis/interface.go b/redis/interface.go index f9e2a1dc..20412b89 100644 --- a/redis/interface.go +++ b/redis/interface.go @@ -3,6 +3,8 @@ package redis import ( "context" "time" + + "github.com/multiversx/mx-multi-factor-auth-go-service/core" ) type Mode int @@ -20,6 +22,7 @@ type RateLimiter interface { Reset(key string) error SetSecurityModeNoExpire(key string) error UnsetSecurityModeNoExpire(key string) error + GetSecurityStatus(key string) core.EnhancedSecurityModeStatus DecrementSecurityFailedTrials(key string) error Period(mode Mode) time.Duration Rate(mode Mode) int @@ -35,6 +38,7 @@ type RedisStorer interface { SetExpireIfNotExists(ctx context.Context, key string, ttl time.Duration) (bool, error) SetPersist(ctx context.Context, key string) (bool, error) SetGreaterExpireTTL(ctx context.Context, key string, ttl time.Duration) (bool, error) + Get(ctx context.Context, key string) (string, error) ResetCounterAndKeepTTL(ctx context.Context, key string) error ExpireTime(ctx context.Context, key string) (time.Duration, error) IsConnected(ctx context.Context) bool diff --git a/redis/redisClient.go b/redis/redisClient.go index d5963cf3..03a11e63 100644 --- a/redis/redisClient.go +++ b/redis/redisClient.go @@ -2,6 +2,7 @@ package redis import ( "context" + "errors" "time" "github.com/redis/go-redis/v9" @@ -37,6 +38,19 @@ func (r *redisClientWrapper) Decrement(ctx context.Context, key string) (int64, return r.client.Decr(ctx, key).Result() } +// Get will return the value corresponding to the specified key +func (r *redisClientWrapper) Get(ctx context.Context, key string) (string, error) { + val, err := r.client.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { + return "", ErrKeyNotExists + } + if err != nil { + return "", err + } + + return val, nil +} + // SetExpire will run expire for the specified key, setting the specified ttl func (r *redisClientWrapper) SetExpire(ctx context.Context, key string, ttl time.Duration) (bool, error) { return r.client.Expire(ctx, key, ttl).Result() diff --git a/redis/redisClient_test.go b/redis/redisClient_test.go index c0f63ff9..fb3259ae 100644 --- a/redis/redisClient_test.go +++ b/redis/redisClient_test.go @@ -86,6 +86,14 @@ func TestOperations(t *testing.T) { require.Nil(t, err) require.Equal(t, int64(2), retries) + value, err := rcw.Get(context.TODO(), "key1") + require.Nil(t, err) + require.Equal(t, "2", value) + + value, err = rcw.Get(context.TODO(), "invalidKey") + require.Equal(t, redis.ErrKeyNotExists, err) + require.Equal(t, "", value) + wasSet, err = rcw.SetPersist(context.TODO(), "key1") require.Nil(t, err) require.True(t, wasSet) diff --git a/redis/redisLimiter.go b/redis/redisLimiter.go index 0eccfd9f..9c7930a4 100644 --- a/redis/redisLimiter.go +++ b/redis/redisLimiter.go @@ -2,7 +2,9 @@ package redis import ( "context" + "errors" "fmt" + "strconv" "sync" "time" @@ -214,6 +216,45 @@ func (rl *rateLimiter) SetSecurityModeNoExpire(key string) error { return nil } +// GetSecurityStatus will return the security status based on the expiry time of the key +func (rl *rateLimiter) GetSecurityStatus(key string) core.EnhancedSecurityModeStatus { + ctx, cancel := context.WithTimeout(context.Background(), rl.operationTimeout) + defer cancel() + + return rl.getSecurityStatus(ctx, key) +} + +func (rl *rateLimiter) getSecurityStatus(ctx context.Context, key string) core.EnhancedSecurityModeStatus { + _, maxFailures := rl.getFailConfig(SecurityMode) + + dbVal, err := rl.storer.Get(ctx, key) + if errors.Is(err, ErrKeyNotExists) { + return core.NotSet + } + + expTime, err := rl.storer.ExpireTime(ctx, key) + if err != nil { + return core.NotSet + } + + trials, err := strconv.ParseInt(dbVal, 10, 64) + if err != nil { + log.Debug("error when converting security status", "err", err) + return core.NotSet + } + + if expTime == core.NoExpiryValue { + return core.ManuallySet + } + + hasTrialsLeft := trials < maxFailures + if !hasTrialsLeft { + return core.AutomaticallySet + } + + return core.NotSet +} + // UnsetSecurityModeNoExpire will set the key from persistent to volatile func (rl *rateLimiter) UnsetSecurityModeNoExpire(key string) error { ctx, cancel := context.WithTimeout(context.Background(), rl.operationTimeout) diff --git a/redis/redisLimiter_test.go b/redis/redisLimiter_test.go index dca79030..f14bff38 100644 --- a/redis/redisLimiter_test.go +++ b/redis/redisLimiter_test.go @@ -3,6 +3,7 @@ package redis_test import ( "context" "errors" + "strconv" "sync" "testing" "time" @@ -508,6 +509,121 @@ func TestUnsetSecurityModeNoExpire(t *testing.T) { }) } +func TestGetSecurityStatus(t *testing.T) { + t.Parallel() + + t.Run("should return NotSet because ErrKeyNotExists", func(t *testing.T) { + t.Parallel() + + args := createMockRateLimiterArgs() + redisClient := &testscommon.RedisClientStub{ + GetCalled: func(ctx context.Context, key string) (string, error) { + return "", redis.ErrKeyNotExists + }, + } + args.Storer = redisClient + + rl, err := redis.NewRateLimiter(args) + require.Nil(t, err) + + actualStatus := rl.GetSecurityStatus("key") + require.Equal(t, core.NotSet, actualStatus) + }) + + t.Run("should return NotSet because security mode wasn't activate neither manually nor automatically", func(t *testing.T) { + t.Parallel() + + maxFailures := 3 + maxDuration := 9 + securityModeMaxFailures := 100 + securityModeMaxDuration := 86400 + + args := createMockRateLimiterArgs() + args.FreezeFailureConfig.MaxFailures = int64(maxFailures) + args.FreezeFailureConfig.LimitPeriodInSec = uint64(maxDuration) + args.SecurityModeFailureConfig.MaxFailures = int64(securityModeMaxFailures) + args.SecurityModeFailureConfig.LimitPeriodInSec = uint64(securityModeMaxDuration) + + redisClient := &testscommon.RedisClientStub{ + ExpireTimeCalled: func(ctx context.Context, key string) (time.Duration, error) { + return time.Duration(securityModeMaxDuration), nil + }, + GetCalled: func(ctx context.Context, key string) (string, error) { + return strconv.Itoa(securityModeMaxFailures - 1), nil + }, + } + args.Storer = redisClient + + rl, err := redis.NewRateLimiter(args) + require.Nil(t, err) + + actualStatus := rl.GetSecurityStatus("key") + require.Equal(t, core.NotSet, actualStatus) + }) + + t.Run("should return Automatically set when failures is exceeded", func(t *testing.T) { + t.Parallel() + + maxFailures := 3 + maxDuration := 9 + securityModeMaxFailures := 100 + securityModeMaxDuration := 86400 + + args := createMockRateLimiterArgs() + args.FreezeFailureConfig.MaxFailures = int64(maxFailures) + args.FreezeFailureConfig.LimitPeriodInSec = uint64(maxDuration) + args.SecurityModeFailureConfig.MaxFailures = int64(securityModeMaxFailures) + args.SecurityModeFailureConfig.LimitPeriodInSec = uint64(securityModeMaxDuration) + + redisClient := &testscommon.RedisClientStub{ + ExpireTimeCalled: func(ctx context.Context, key string) (time.Duration, error) { + return time.Duration(securityModeMaxDuration), nil + }, + GetCalled: func(ctx context.Context, key string) (string, error) { + return strconv.Itoa(securityModeMaxFailures), nil + }, + } + args.Storer = redisClient + + rl, err := redis.NewRateLimiter(args) + require.Nil(t, err) + + actualStatus := rl.GetSecurityStatus("key") + require.Equal(t, core.AutomaticallySet, actualStatus) + }) + + t.Run("should return Manual set when key is persistent", func(t *testing.T) { + t.Parallel() + + maxFailures := 3 + maxDuration := 9 + securityModeMaxFailures := 100 + securityModeMaxDuration := 86400 + + args := createMockRateLimiterArgs() + args.FreezeFailureConfig.MaxFailures = int64(maxFailures) + args.FreezeFailureConfig.LimitPeriodInSec = uint64(maxDuration) + args.SecurityModeFailureConfig.MaxFailures = int64(securityModeMaxFailures) + args.SecurityModeFailureConfig.LimitPeriodInSec = uint64(securityModeMaxDuration) + + redisClient := &testscommon.RedisClientStub{ + ExpireTimeCalled: func(ctx context.Context, key string) (time.Duration, error) { + return -1, nil + }, + GetCalled: func(ctx context.Context, key string) (string, error) { + return strconv.Itoa(securityModeMaxFailures), nil + }, + } + args.Storer = redisClient + + rl, err := redis.NewRateLimiter(args) + require.Nil(t, err) + + actualStatus := rl.GetSecurityStatus("key") + require.Equal(t, core.ManuallySet, actualStatus) + }) +} + func TestDecrementSecurityFailedTrials(t *testing.T) { t.Parallel() diff --git a/resolver/serviceResolver.go b/resolver/serviceResolver.go index e6af204a..57214b30 100644 --- a/resolver/serviceResolver.go +++ b/resolver/serviceResolver.go @@ -289,6 +289,15 @@ func (resolver *serviceResolver) UnsetSecurityModeNoExpire(userIp string, reques return verifyCodeData, resolver.secureOtpHandler.UnsetSecurityModeNoExpire(request.UserAddr) } +// GetUserStatus gets the user's status +func (resolver *serviceResolver) GetUserStatus(userAddress string) (*requests.UserStatusResponse, error) { + status, err := resolver.getUserStatus(userAddress) + + return &requests.UserStatusResponse{ + SecurityModeStatus: int(status), + }, err +} + // SignTransaction validates user's transaction, then adds guardian signature and returns the transaction func (resolver *serviceResolver) SignTransaction(userIp string, request requests.SignTransaction) ([]byte, *requests.OTPCodeVerifyData, error) { guardian, otpCodeVerifyData, err := resolver.validateTxRequestReturningGuardian(userIp, request.Code, request.SecondCode, []transaction.FrontendTransaction{request.Tx}) @@ -485,6 +494,10 @@ func (resolver *serviceResolver) validateTxRequestReturningGuardian( return resolver.verifyCodesReturningGuardian(userAddress, txs[0].GuardianAddr, userIp, code, secondCode) } +func (resolver *serviceResolver) getUserStatus(userAddr string) (core.EnhancedSecurityModeStatus, error) { + return resolver.secureOtpHandler.GetSecurityStatus(userAddr), nil +} + func (resolver *serviceResolver) verifyCodesReturningGuardian( userAddress sdkCore.AddressHandler, guardianAddr string, diff --git a/resolver/serviceResolver_test.go b/resolver/serviceResolver_test.go index d1bd50c2..526234c4 100644 --- a/resolver/serviceResolver_test.go +++ b/resolver/serviceResolver_test.go @@ -2735,6 +2735,105 @@ func TestServiceResolver_UnsetSecurityModeNoExpire(t *testing.T) { } +func TestServiceResolver_GetUserStatus(t *testing.T) { + t.Parallel() + + providedSender := "erd1qyu5wthldzr8wx5c9ucg8kjagg0jfs53s8nr3zpz3hypefsdd8ssycr6th" + + t.Run("should return 0", func(t *testing.T) { + t.Parallel() + providedUserInfoCopy := *providedUserInfo + + args := createMockArgs() + args.RegisteredUsersDB = &testscommon.ShardedStorageWithIndexStub{ + GetCalled: func(key []byte) ([]byte, error) { + encryptedUser, err := args.UserEncryptor.EncryptUserInfo(&providedUserInfoCopy) + require.Nil(t, err) + return args.UserDataMarshaller.Marshal(encryptedUser) + }, + } + + args.SecureOtpHandler = &testscommon.SecureOtpHandlerStub{ + GetSecurityStatusCalled: func(key string) core.EnhancedSecurityModeStatus { + return core.NotSet + }, + } + + resolver, _ := NewServiceResolver(args) + assert.NotNil(t, resolver) + + expectedStatus := &requests.UserStatusResponse{ + SecurityModeStatus: 0, + } + statusReturned, err := resolver.GetUserStatus(providedSender) + + assert.Nil(t, err) + assert.Equal(t, expectedStatus, statusReturned) + }) + + t.Run("should return 1", func(t *testing.T) { + t.Parallel() + providedUserInfoCopy := *providedUserInfo + + args := createMockArgs() + args.RegisteredUsersDB = &testscommon.ShardedStorageWithIndexStub{ + GetCalled: func(key []byte) ([]byte, error) { + encryptedUser, err := args.UserEncryptor.EncryptUserInfo(&providedUserInfoCopy) + require.Nil(t, err) + return args.UserDataMarshaller.Marshal(encryptedUser) + }, + } + + args.SecureOtpHandler = &testscommon.SecureOtpHandlerStub{ + GetSecurityStatusCalled: func(key string) core.EnhancedSecurityModeStatus { + return core.ManuallySet + }, + } + + resolver, _ := NewServiceResolver(args) + assert.NotNil(t, resolver) + + expectedStatus := &requests.UserStatusResponse{ + SecurityModeStatus: 1, + } + statusReturned, err := resolver.GetUserStatus(providedSender) + + assert.Nil(t, err) + assert.Equal(t, expectedStatus, statusReturned) + }) + + t.Run("should return 2", func(t *testing.T) { + t.Parallel() + providedUserInfoCopy := *providedUserInfo + + args := createMockArgs() + args.RegisteredUsersDB = &testscommon.ShardedStorageWithIndexStub{ + GetCalled: func(key []byte) ([]byte, error) { + encryptedUser, err := args.UserEncryptor.EncryptUserInfo(&providedUserInfoCopy) + require.Nil(t, err) + return args.UserDataMarshaller.Marshal(encryptedUser) + }, + } + + args.SecureOtpHandler = &testscommon.SecureOtpHandlerStub{ + GetSecurityStatusCalled: func(key string) core.EnhancedSecurityModeStatus { + return core.AutomaticallySet + }, + } + + resolver, _ := NewServiceResolver(args) + assert.NotNil(t, resolver) + + expectedStatus := &requests.UserStatusResponse{ + SecurityModeStatus: 2, + } + statusReturned, err := resolver.GetUserStatus(providedSender) + + assert.Nil(t, err) + assert.Equal(t, expectedStatus, statusReturned) + }) +} + func TestServiceResolver_SignMultipleTransactions(t *testing.T) { t.Parallel() diff --git a/testscommon/facade/guardianFacadeStub.go b/testscommon/facade/guardianFacadeStub.go index a8fe4845..18bfb21f 100644 --- a/testscommon/facade/guardianFacadeStub.go +++ b/testscommon/facade/guardianFacadeStub.go @@ -14,6 +14,7 @@ type GuardianFacadeStub struct { SignMessageCalled func(userIp string, request requests.SignMessage) ([]byte, *requests.OTPCodeVerifyData, error) SetSecurityModeNoExpireCalled func(userIp string, request requests.SecurityModeNoExpire) (*requests.OTPCodeVerifyData, error) UnsetSecurityModeNoExpireCalled func(userIp string, request requests.SecurityModeNoExpire) (*requests.OTPCodeVerifyData, error) + GetUserStatusCalled func(userAddress string) (*requests.UserStatusResponse, error) SignTransactionCalled func(userIp string, request requests.SignTransaction) ([]byte, *requests.OTPCodeVerifyData, error) SignMultipleTransactionsCalled func(userIp string, request requests.SignMultipleTransactions) ([][]byte, *requests.OTPCodeVerifyData, error) RegisteredUsersCalled func() (uint32, error) @@ -62,6 +63,14 @@ func (stub *GuardianFacadeStub) UnsetSecurityModeNoExpire(userIp string, request return nil, nil } +// GetUserStatus - +func (stub *GuardianFacadeStub) GetUserStatus(userAddress string) (*requests.UserStatusResponse, error) { + if stub.GetUserStatusCalled != nil { + return stub.GetUserStatusCalled(userAddress) + } + return &requests.UserStatusResponse{}, nil +} + // SignTransaction - func (stub *GuardianFacadeStub) SignTransaction(userIp string, request requests.SignTransaction) ([]byte, *requests.OTPCodeVerifyData, error) { if stub.SignTransactionCalled != nil { diff --git a/testscommon/rateLimiterMock.go b/testscommon/rateLimiterMock.go index 4df6fc16..c9927639 100644 --- a/testscommon/rateLimiterMock.go +++ b/testscommon/rateLimiterMock.go @@ -4,6 +4,7 @@ import ( "sync" "time" + "github.com/multiversx/mx-multi-factor-auth-go-service/core" "github.com/multiversx/mx-multi-factor-auth-go-service/redis" ) @@ -59,6 +60,11 @@ func (r *RateLimiterMock) UnsetSecurityModeNoExpire(key string) error { return nil } +// GetSecurityStatus - +func (r *RateLimiterMock) GetSecurityStatus(key string) core.EnhancedSecurityModeStatus { + return core.NotSet +} + // Reset - func (r *RateLimiterMock) Reset(key string) error { r.mutTrials.Lock() diff --git a/testscommon/rateLimiterStub.go b/testscommon/rateLimiterStub.go index e10945a4..04821813 100644 --- a/testscommon/rateLimiterStub.go +++ b/testscommon/rateLimiterStub.go @@ -3,6 +3,7 @@ package testscommon import ( "time" + "github.com/multiversx/mx-multi-factor-auth-go-service/core" "github.com/multiversx/mx-multi-factor-auth-go-service/redis" ) @@ -15,6 +16,7 @@ type RateLimiterStub struct { RateCalled func(mode redis.Mode) int SetSecurityModeNoExpireCalled func(key string) error UnsetSecurityModeNoExpireCalled func(key string) error + GetSecurityStatusCalled func(key string) core.EnhancedSecurityModeStatus ExtendSecurityModeCalled func(key string) error } @@ -52,6 +54,14 @@ func (r *RateLimiterStub) UnsetSecurityModeNoExpire(key string) error { return nil } +// GetSecurityStatus - +func (r *RateLimiterStub) GetSecurityStatus(key string) core.EnhancedSecurityModeStatus { + if r.GetSecurityStatusCalled != nil { + return r.GetSecurityStatusCalled(key) + } + return core.NotSet +} + // Reset - func (r *RateLimiterStub) Reset(key string) error { if r.ResetCalled != nil { diff --git a/testscommon/redisClientStub.go b/testscommon/redisClientStub.go index cde47dd1..0743f8ec 100644 --- a/testscommon/redisClientStub.go +++ b/testscommon/redisClientStub.go @@ -9,6 +9,7 @@ import ( type RedisClientStub struct { IncrementCalled func(ctx context.Context, key string) (int64, error) DecrementCalled func(ctx context.Context, key string) (int64, error) + GetCalled func(ctx context.Context, key string) (string, error) SetExpireCalled func(ctx context.Context, key string, ttl time.Duration) (bool, error) SetExpireIfNotExistsCalled func(ctx context.Context, key string, ttl time.Duration) (bool, error) SetPersistCalled func(ctx context.Context, key string) (bool, error) @@ -35,6 +36,14 @@ func (r *RedisClientStub) Decrement(ctx context.Context, key string) (int64, err return 0, nil } +// Get - +func (r *RedisClientStub) Get(ctx context.Context, key string) (string, error) { + if r.GetCalled != nil { + return r.GetCalled(ctx, key) + } + return "", nil +} + // SetExpire - func (r *RedisClientStub) SetExpire(ctx context.Context, key string, ttl time.Duration) (bool, error) { if r.SetExpireCalled != nil { diff --git a/testscommon/secureOtpHandlerStub.go b/testscommon/secureOtpHandlerStub.go index 3d067221..a30fc86f 100644 --- a/testscommon/secureOtpHandlerStub.go +++ b/testscommon/secureOtpHandlerStub.go @@ -1,6 +1,9 @@ package testscommon -import "github.com/multiversx/mx-multi-factor-auth-go-service/core/requests" +import ( + "github.com/multiversx/mx-multi-factor-auth-go-service/core" + "github.com/multiversx/mx-multi-factor-auth-go-service/core/requests" +) // SecureOtpHandlerStub is a stub implementation of the SecureOtpHandler interface type SecureOtpHandlerStub struct { @@ -9,6 +12,7 @@ type SecureOtpHandlerStub struct { DecrementSecurityModeFailedTrialsCalled func(account string) error SetSecurityModeNoExpireCalled func(key string) error UnsetSecurityModeNoExpireCalled func(key string) error + GetSecurityStatusCalled func(key string) core.EnhancedSecurityModeStatus FreezeBackoffTimeCalled func() uint64 FreezeMaxFailuresCalled func() uint64 SecurityModeBackOffTimeCalled func() uint64 @@ -41,6 +45,14 @@ func (stub *SecureOtpHandlerStub) UnsetSecurityModeNoExpire(key string) error { return nil } +// GetSecurityStatus - +func (stub *SecureOtpHandlerStub) GetSecurityStatus(key string) core.EnhancedSecurityModeStatus { + if stub.GetSecurityStatusCalled != nil { + return stub.GetSecurityStatusCalled(key) + } + return core.NotSet +} + // Reset removes the account and ip from local cache func (stub *SecureOtpHandlerStub) Reset(account string, ip string) { if stub.ResetCalled != nil { diff --git a/testscommon/serviceResolverStub.go b/testscommon/serviceResolverStub.go index 0dbadf2f..3ecd3a41 100644 --- a/testscommon/serviceResolverStub.go +++ b/testscommon/serviceResolverStub.go @@ -14,6 +14,7 @@ type ServiceResolverStub struct { VerifyCodeCalled func(userAddress core.AddressHandler, userIp string, request requests.VerificationPayload) (*requests.OTPCodeVerifyData, error) SetSecurityModeNoExpireCalled func(userIp string, request requests.SecurityModeNoExpire) (*requests.OTPCodeVerifyData, error) UnsetSecurityModeNoExpireCalled func(userIp string, request requests.SecurityModeNoExpire) (*requests.OTPCodeVerifyData, error) + GetUserStatusCalled func(userAddress string) (*requests.UserStatusResponse, error) SignMessageCalled func(userIp string, request requests.SignMessage) ([]byte, *requests.OTPCodeVerifyData, error) SignTransactionCalled func(userIp string, request requests.SignTransaction) ([]byte, *requests.OTPCodeVerifyData, error) SignMultipleTransactionsCalled func(userIp string, request requests.SignMultipleTransactions) ([][]byte, *requests.OTPCodeVerifyData, error) @@ -61,6 +62,15 @@ func (stub *ServiceResolverStub) UnsetSecurityModeNoExpire(userIp string, reques return nil, nil } +// GetUserStatus - +func (stub *ServiceResolverStub) GetUserStatus(userAddress string) (*requests.UserStatusResponse, error) { + if stub.GetUserStatusCalled != nil { + return stub.GetUserStatusCalled(userAddress) + } + + return &requests.UserStatusResponse{}, nil +} + // SignTransaction - func (stub *ServiceResolverStub) SignTransaction(userIp string, request requests.SignTransaction) ([]byte, *requests.OTPCodeVerifyData, error) { if stub.SignTransactionCalled != nil {