Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions internal/commands/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log"

"github.com/MakeNowJust/heredoc"
"github.com/checkmarx/ast-cli/internal/logger"
"github.com/checkmarx/ast-cli/internal/params"
"github.com/checkmarx/ast-cli/internal/wrappers"
"github.com/google/uuid"
Expand Down Expand Up @@ -38,7 +39,7 @@ type ClientCreated struct {
Secret string `json:"secret"`
}

func NewAuthCommand(authWrapper wrappers.AuthWrapper) *cobra.Command {
func NewAuthCommand(authWrapper wrappers.AuthWrapper, telemetryWrapper wrappers.TelemetryWrapper) *cobra.Command {
authCmd := &cobra.Command{
Use: "auth",
Short: "Validate authentication and create OAuth2 credentials",
Expand Down Expand Up @@ -110,14 +111,29 @@ func NewAuthCommand(authWrapper wrappers.AuthWrapper) *cobra.Command {
`,
),
},
RunE: validLogin(),
RunE: validLogin(telemetryWrapper),
}
authCmd.AddCommand(createClientCmd, validLoginCmd)
return authCmd
}

func validLogin() func(cmd *cobra.Command, args []string) error {
func validLogin(telemetryWrapper wrappers.TelemetryWrapper) func(cmd *cobra.Command, args []string) error {
return func(cmd *cobra.Command, args []string) error {
defer func() {
logger.PrintIfVerbose("Calling GetUniqueId func")
uniqueID := wrappers.GetUniqueID()
if uniqueID != "" {
logger.PrintIfVerbose("Set unique id: " + uniqueID)
err := telemetryWrapper.SendAIDataToLog(&wrappers.DataForAITelemetry{
UniqueID: uniqueID,
Type: "authentication",
SubType: "authentication",
})
if err != nil {
logger.PrintIfVerbose("Failed to send telemetry data: " + err.Error())
}
}
}()
clientID := viper.GetString(params.AccessKeyIDConfigKey)
clientSecret := viper.GetString(params.AccessKeySecretConfigKey)
apiKey := viper.GetString(params.AstAPIKey)
Expand Down
2 changes: 1 addition & 1 deletion internal/commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func NewAstCLI(
)

versionCmd := util.NewVersionCommand()
authCmd := NewAuthCommand(authWrapper)
authCmd := NewAuthCommand(authWrapper, telemetryWrapper)
utilsCmd := util.NewUtilsCommand(
gitHubWrapper,
azureWrapper,
Expand Down
5 changes: 4 additions & 1 deletion internal/commands/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package commands

import (
"github.com/MakeNowJust/heredoc"
"github.com/checkmarx/ast-cli/internal/logger"
"github.com/checkmarx/ast-cli/internal/params"
"github.com/checkmarx/ast-cli/internal/wrappers"
"github.com/pkg/errors"
Expand Down Expand Up @@ -58,7 +59,8 @@ func runTelemetryAI(telemetryWrapper wrappers.TelemetryWrapper) func(*cobra.Comm
scanType, _ := cmd.Flags().GetString("scan-type")
status, _ := cmd.Flags().GetString("status")
totalCount, _ := cmd.Flags().GetInt("total-count")

uniqueID := wrappers.GetUniqueID()
logger.PrintIfVerbose("unique id: " + uniqueID)
err := telemetryWrapper.SendAIDataToLog(&wrappers.DataForAITelemetry{
AIProvider: aiProvider,
ProblemSeverity: problemSeverity,
Expand All @@ -69,6 +71,7 @@ func runTelemetryAI(telemetryWrapper wrappers.TelemetryWrapper) func(*cobra.Comm
ScanType: scanType,
Status: status,
TotalCount: totalCount,
UniqueID: uniqueID,
})

if err != nil {
Expand Down
1 change: 1 addition & 0 deletions internal/params/envs.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ const (
RiskManagementPathEnv = "CX_RISK_MANAGEMENT_PATH"
ConfigFilePathEnv = "CX_CONFIG_FILE_PATH"
RealtimeScannerPathEnv = "CX_REALTIME_SCANNER_PATH"
UniqueIDEnv = "CX_UNIQUE_ID"
StartMultiPartUploadPathEnv = "CX_START_MULTIPART_UPLOAD_PATH"
MultipartPresignedPathEnv = "CX_MULTIPART_PRESIGNED_URL_PATH"
CompleteMultipartUploadPathEnv = "CX_COMPLETE_MULTIPART_UPLOAD_PATH"
Expand Down
1 change: 1 addition & 0 deletions internal/params/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ var (
RiskManagementPathKey = strings.ToLower(RiskManagementPathEnv)
ConfigFilePathKey = strings.ToLower(ConfigFilePathEnv)
RealtimeScannerPathKey = strings.ToLower(RealtimeScannerPathEnv)
UniqueIDConfigKey = strings.ToLower(UniqueIDEnv)
StartMultiPartUploadPathKey = strings.ToLower(StartMultiPartUploadPathEnv)
MultipartPresignedPathKey = strings.ToLower(MultipartPresignedPathEnv)
CompleteMultiPartUploadPathKey = strings.ToLower(CompleteMultipartUploadPathEnv)
Expand Down
21 changes: 15 additions & 6 deletions internal/wrappers/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,21 @@ func retryHTTPForIAMRequest(requestFunc func() (*http.Response, error), retries
return nil, err
}

func setAgentNameAndOrigin(req *http.Request) {
func setAgentNameAndOrigin(req *http.Request, isAuth bool) {
agentStr := viper.GetString(commonParams.AgentNameKey) + "/" + commonParams.Version
req.Header.Set("User-Agent", agentStr)

originStr := viper.GetString(commonParams.OriginKey)
req.Header.Set("Cx-Origin", originStr)
logger.PrintIfVerbose("getting unique id")

if !isAuth {
uniqueID := GetUniqueID()
if uniqueID != "" {
req.Header.Set("UniqueId", uniqueID)
logger.PrintIfVerbose("unique id: " + uniqueID)
}
}
}

func GetClient(timeout uint) *http.Client {
Expand Down Expand Up @@ -375,7 +384,7 @@ func SendHTTPRequestByFullURLContentLength(
req.ContentLength = contentLength
}
client := GetClient(timeout)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, false)
if auth {
enrichWithOath2Credentials(req, accessToken, bearerFormat)
}
Expand Down Expand Up @@ -427,7 +436,7 @@ func SendHTTPRequestPasswordAuth(method string, body io.Reader, timeout uint, us
}
req, err := http.NewRequest(method, u, body)
client := GetClient(timeout)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -464,7 +473,7 @@ func HTTPRequestWithQueryParams(
}
req, err := http.NewRequest(method, u, body)
client := GetClient(timeout)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -512,7 +521,7 @@ func SendHTTPRequestWithJSONContentType(method, path string, body io.Reader, aut
}
req, err := http.NewRequest(method, fullURL, body)
client := GetClient(timeout)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, false)
req.Header.Add("Content-Type", jsonContentType)
if err != nil {
return nil, err
Expand Down Expand Up @@ -645,7 +654,7 @@ func writeCredentialsToCache(accessToken string) {
func getNewToken(credentialsPayload, authServerURI string) (string, error) {
payload := strings.NewReader(credentialsPayload)
req, err := http.NewRequest(http.MethodPost, authServerURI, payload)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, true)
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/wrappers/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func TestSetAgentNameAndOrigin(t *testing.T) {

req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)

setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, false)

userAgent := req.Header.Get("User-Agent")
origin := req.Header.Get("origin")
Expand Down
51 changes: 51 additions & 0 deletions internal/wrappers/jwt-helper.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package wrappers

import (
"os/user"
"strings"

"github.com/checkmarx/ast-cli/internal/logger"
commonParams "github.com/checkmarx/ast-cli/internal/params"
"github.com/checkmarx/ast-cli/internal/wrappers/configuration"
"github.com/checkmarx/ast-cli/internal/wrappers/utils"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/spf13/viper"
)

// JWTStruct model used to get all jwt fields
Expand Down Expand Up @@ -158,3 +163,49 @@ func (*JWTStruct) CheckPermissionByAccessToken(requiredPermission string) (hasPe
}
return permission, nil
}

func GetUniqueID() string {
var uniqueID string
// Check License first
jwtWrapper := NewJwtWrapper()
isAllowed, err := jwtWrapper.IsAllowedEngine("Checkmarx Developer Assist")
if err != nil {
logger.PrintIfVerbose("Failed to check engine allowance: " + err.Error())
return ""
}
if !isAllowed {
logger.PrintIfVerbose("User does not have permission to standalone dev assists feature")
return ""
}

// Check if unique id is already set
uniqueID = viper.GetString(commonParams.UniqueIDConfigKey)
if uniqueID != "" {
return uniqueID
}

// Generate new unique id
logger.PrintIfVerbose("Generating new unique id")
currentUser, err := user.Current()
if err != nil {
logger.PrintIfVerbose("Failed to get user: " + err.Error())
return ""
}
username := currentUser.Username
username = strings.TrimSpace(username)
logger.PrintIfVerbose("Username to be used for unique id: " + username)
if strings.Contains(username, "\\") {
username = strings.Split(username, "\\")[1]
}
uniqueID = uuid.New().String() + "_" + username

logger.PrintIfVerbose("Unique id: " + uniqueID)
viper.Set(commonParams.UniqueIDConfigKey, uniqueID)
configFilePath, _ := configuration.GetConfigFilePath()
err = configuration.SafeWriteSingleConfigKeyString(configFilePath, commonParams.UniqueIDConfigKey, uniqueID)
if err != nil {
logger.PrintIfVerbose("Failed to write config: " + err.Error())
return ""
}
return uniqueID
}
44 changes: 44 additions & 0 deletions internal/wrappers/jwt-helper_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package wrappers

import (
"strings"
"testing"

commonParams "github.com/checkmarx/ast-cli/internal/params"
"github.com/spf13/viper"
"gotest.tools/assert"
)

Expand Down Expand Up @@ -71,3 +74,44 @@ func TestGetEnabledEngines(t *testing.T) {
})
}
}

func TestGetUniqueID(t *testing.T) {
// Save original value and restore after test
originalID := viper.GetString(commonParams.UniqueIDConfigKey)
defer viper.Set(commonParams.UniqueIDConfigKey, originalID)

t.Run("returns existing unique ID from config", func(t *testing.T) {
// Setup: set existing ID
existingID := "test-uuid-456_testuser"
viper.Set(commonParams.UniqueIDConfigKey, existingID)

result := GetUniqueID()

if result != "" {
assert.Equal(t, existingID, result)
} else {
t.Skip("Requires valid auth and 'Checkmarx Developer Assist' license")
}
})

t.Run("generates new unique ID when none exists", func(t *testing.T) {
// Setup: clear existing ID
viper.Set(commonParams.UniqueIDConfigKey, "")

result := GetUniqueID()

if result == "" {
t.Skip("Requires valid auth and 'Checkmarx Developer Assist' license")
return
}

// Verify format: UUID_username
assert.Assert(t, strings.Contains(result, "_"), "Should have UUID_username format")
assert.Assert(t, len(result) > 36, "Should contain UUID and username")

// Verify no backslash (Windows domain stripped)
parts := strings.Split(result, "_")
assert.Assert(t, len(parts) >= 2, "Should have at least 2 parts")
assert.Assert(t, !strings.Contains(parts[1], "\\"), "Username should not contain backslash")
})
}
1 change: 1 addition & 0 deletions internal/wrappers/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type DataForAITelemetry struct {
ScanType string `json:"scanType"`
Status string `json:"status"`
TotalCount int `json:"totalCount"`
UniqueID string `json:"uniqueId"`
}

type TelemetryWrapper interface {
Expand Down
Loading