Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/wrappers/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http/httptrace"
"net/url"
"strings"
"sync"
"time"

applicationErrors "github.com/checkmarx/ast-cli/internal/constants/errors"
Expand Down Expand Up @@ -44,6 +45,10 @@ const (
jsonContentType = "application/json"
)

var (
credentialsMutex sync.Mutex
)

type ClientCredentialsInfo struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
Expand Down Expand Up @@ -478,6 +483,9 @@ func getClientCredentialsFromCache(tokenExpirySeconds int) string {
}

func writeCredentialsToCache(accessToken string) {
credentialsMutex.Lock()
defer credentialsMutex.Unlock()

logger.PrintIfVerbose("Storing API access token to cache.")
viper.Set(commonParams.AstToken, accessToken)
cachedAccessToken = accessToken
Expand Down
37 changes: 36 additions & 1 deletion internal/wrappers/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@ package wrappers

import (
"errors"
"github.com/stretchr/testify/assert"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"testing"
"time"

commonParams "github.com/checkmarx/ast-cli/internal/params"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
)

type mockReadCloser struct{}
Expand Down Expand Up @@ -78,3 +85,31 @@ func TestRetryHTTPRequest_EndWithBadGateway(t *testing.T) {
assert.NotNil(t, resp)
assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
}

func TestConcurrentWriteCredentialsToCache(t *testing.T) {
var wg sync.WaitGroup

for i := 0; i < 1000; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
writeCredentialsToCache(fmt.Sprintf("testToken_%d", i))
}(i)
}
wg.Wait()

token := viper.Get(commonParams.AstToken)
assert.NotNil(t, token, "Token should not be nil")

tokenStr, ok := token.(string)
assert.True(t, ok, "Token should be a string")

splitToken := strings.Split(tokenStr, "_")
assert.Equal(t, 2, len(splitToken), "Token should split into 2 parts")
assert.Equal(t, "testToken", splitToken[0], "Token prefix should be 'testToken'")

testTokenNumber, err := strconv.Atoi(splitToken[1])
assert.NoError(t, err, "The token suffix should be a valid number")
assert.True(t, testTokenNumber >= 0 && testTokenNumber < 1000,
"The token number should be within the expected range")
}
Loading