diff --git a/internal/commands/auth.go b/internal/commands/auth.go index 832fa66bc..bdbdb06b1 100644 --- a/internal/commands/auth.go +++ b/internal/commands/auth.go @@ -5,6 +5,7 @@ import ( "log" "github.com/MakeNowJust/heredoc" + "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/params" "github.com/checkmarx/ast-cli/internal/wrappers" "github.com/google/uuid" @@ -38,7 +39,7 @@ type ClientCreated struct { Secret string `json:"secret"` } -func NewAuthCommand(authWrapper wrappers.AuthWrapper) *cobra.Command { +func NewAuthCommand(authWrapper wrappers.AuthWrapper, telemetryWrapper wrappers.TelemetryWrapper) *cobra.Command { authCmd := &cobra.Command{ Use: "auth", Short: "Validate authentication and create OAuth2 credentials", @@ -110,14 +111,29 @@ func NewAuthCommand(authWrapper wrappers.AuthWrapper) *cobra.Command { `, ), }, - RunE: validLogin(), + RunE: validLogin(telemetryWrapper), } authCmd.AddCommand(createClientCmd, validLoginCmd) return authCmd } -func validLogin() func(cmd *cobra.Command, args []string) error { +func validLogin(telemetryWrapper wrappers.TelemetryWrapper) func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error { + defer func() { + logger.PrintIfVerbose("Calling GetUniqueId func") + uniqueID := wrappers.GetUniqueID() + if uniqueID != "" { + logger.PrintIfVerbose("Set unique id: " + uniqueID) + err := telemetryWrapper.SendAIDataToLog(&wrappers.DataForAITelemetry{ + UniqueID: uniqueID, + Type: "authentication", + SubType: "authentication", + }) + if err != nil { + logger.PrintIfVerbose("Failed to send telemetry data: " + err.Error()) + } + } + }() clientID := viper.GetString(params.AccessKeyIDConfigKey) clientSecret := viper.GetString(params.AccessKeySecretConfigKey) apiKey := viper.GetString(params.AstAPIKey) diff --git a/internal/commands/root.go b/internal/commands/root.go index 9c56cb812..dc9587c52 100644 --- a/internal/commands/root.go +++ b/internal/commands/root.go @@ -205,7 +205,7 @@ func NewAstCLI( ) versionCmd := util.NewVersionCommand() - authCmd := NewAuthCommand(authWrapper) + authCmd := NewAuthCommand(authWrapper, telemetryWrapper) utilsCmd := util.NewUtilsCommand( gitHubWrapper, azureWrapper, diff --git a/internal/commands/telemetry.go b/internal/commands/telemetry.go index 3b5bbefe0..cbf3d1f3a 100644 --- a/internal/commands/telemetry.go +++ b/internal/commands/telemetry.go @@ -2,6 +2,7 @@ package commands import ( "github.com/MakeNowJust/heredoc" + "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/params" "github.com/checkmarx/ast-cli/internal/wrappers" "github.com/pkg/errors" @@ -58,7 +59,8 @@ func runTelemetryAI(telemetryWrapper wrappers.TelemetryWrapper) func(*cobra.Comm scanType, _ := cmd.Flags().GetString("scan-type") status, _ := cmd.Flags().GetString("status") totalCount, _ := cmd.Flags().GetInt("total-count") - + uniqueID := wrappers.GetUniqueID() + logger.PrintIfVerbose("unique id: " + uniqueID) err := telemetryWrapper.SendAIDataToLog(&wrappers.DataForAITelemetry{ AIProvider: aiProvider, ProblemSeverity: problemSeverity, @@ -69,6 +71,7 @@ func runTelemetryAI(telemetryWrapper wrappers.TelemetryWrapper) func(*cobra.Comm ScanType: scanType, Status: status, TotalCount: totalCount, + UniqueID: uniqueID, }) if err != nil { diff --git a/internal/params/envs.go b/internal/params/envs.go index 9698eb699..89fdc97d9 100644 --- a/internal/params/envs.go +++ b/internal/params/envs.go @@ -80,6 +80,7 @@ const ( RiskManagementPathEnv = "CX_RISK_MANAGEMENT_PATH" ConfigFilePathEnv = "CX_CONFIG_FILE_PATH" RealtimeScannerPathEnv = "CX_REALTIME_SCANNER_PATH" + UniqueIDEnv = "CX_UNIQUE_ID" StartMultiPartUploadPathEnv = "CX_START_MULTIPART_UPLOAD_PATH" MultipartPresignedPathEnv = "CX_MULTIPART_PRESIGNED_URL_PATH" CompleteMultipartUploadPathEnv = "CX_COMPLETE_MULTIPART_UPLOAD_PATH" diff --git a/internal/params/keys.go b/internal/params/keys.go index adabc95a5..baeaf046c 100644 --- a/internal/params/keys.go +++ b/internal/params/keys.go @@ -79,6 +79,7 @@ var ( RiskManagementPathKey = strings.ToLower(RiskManagementPathEnv) ConfigFilePathKey = strings.ToLower(ConfigFilePathEnv) RealtimeScannerPathKey = strings.ToLower(RealtimeScannerPathEnv) + UniqueIDConfigKey = strings.ToLower(UniqueIDEnv) StartMultiPartUploadPathKey = strings.ToLower(StartMultiPartUploadPathEnv) MultipartPresignedPathKey = strings.ToLower(MultipartPresignedPathEnv) CompleteMultiPartUploadPathKey = strings.ToLower(CompleteMultipartUploadPathEnv) diff --git a/internal/wrappers/client.go b/internal/wrappers/client.go index 032eb4e4b..f588b4d55 100644 --- a/internal/wrappers/client.go +++ b/internal/wrappers/client.go @@ -126,12 +126,21 @@ func retryHTTPForIAMRequest(requestFunc func() (*http.Response, error), retries return nil, err } -func setAgentNameAndOrigin(req *http.Request) { +func setAgentNameAndOrigin(req *http.Request, isAuth bool) { agentStr := viper.GetString(commonParams.AgentNameKey) + "/" + commonParams.Version req.Header.Set("User-Agent", agentStr) originStr := viper.GetString(commonParams.OriginKey) req.Header.Set("Cx-Origin", originStr) + logger.PrintIfVerbose("getting unique id") + + if !isAuth { + uniqueID := GetUniqueID() + if uniqueID != "" { + req.Header.Set("UniqueId", uniqueID) + logger.PrintIfVerbose("unique id: " + uniqueID) + } + } } func GetClient(timeout uint) *http.Client { @@ -375,7 +384,7 @@ func SendHTTPRequestByFullURLContentLength( req.ContentLength = contentLength } client := GetClient(timeout) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, false) if auth { enrichWithOath2Credentials(req, accessToken, bearerFormat) } @@ -427,7 +436,7 @@ func SendHTTPRequestPasswordAuth(method string, body io.Reader, timeout uint, us } req, err := http.NewRequest(method, u, body) client := GetClient(timeout) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, true) if err != nil { return nil, err } @@ -464,7 +473,7 @@ func HTTPRequestWithQueryParams( } req, err := http.NewRequest(method, u, body) client := GetClient(timeout) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, false) if err != nil { return nil, err } @@ -512,7 +521,7 @@ func SendHTTPRequestWithJSONContentType(method, path string, body io.Reader, aut } req, err := http.NewRequest(method, fullURL, body) client := GetClient(timeout) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, false) req.Header.Add("Content-Type", jsonContentType) if err != nil { return nil, err @@ -645,7 +654,7 @@ func writeCredentialsToCache(accessToken string) { func getNewToken(credentialsPayload, authServerURI string) (string, error) { payload := strings.NewReader(credentialsPayload) req, err := http.NewRequest(http.MethodPost, authServerURI, payload) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, true) if err != nil { return "", err } diff --git a/internal/wrappers/client_test.go b/internal/wrappers/client_test.go index b8a45f0d1..e75e9e9b0 100644 --- a/internal/wrappers/client_test.go +++ b/internal/wrappers/client_test.go @@ -192,7 +192,7 @@ func TestSetAgentNameAndOrigin(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, false) userAgent := req.Header.Get("User-Agent") origin := req.Header.Get("origin") diff --git a/internal/wrappers/jwt-helper.go b/internal/wrappers/jwt-helper.go index 7a3ec506f..9793e8d4e 100644 --- a/internal/wrappers/jwt-helper.go +++ b/internal/wrappers/jwt-helper.go @@ -1,12 +1,17 @@ package wrappers import ( + "os/user" "strings" + "github.com/checkmarx/ast-cli/internal/logger" commonParams "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/wrappers/configuration" "github.com/checkmarx/ast-cli/internal/wrappers/utils" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "github.com/pkg/errors" + "github.com/spf13/viper" ) // JWTStruct model used to get all jwt fields @@ -158,3 +163,49 @@ func (*JWTStruct) CheckPermissionByAccessToken(requiredPermission string) (hasPe } return permission, nil } + +func GetUniqueID() string { + var uniqueID string + // Check License first + jwtWrapper := NewJwtWrapper() + isAllowed, err := jwtWrapper.IsAllowedEngine("Checkmarx Developer Assist") + if err != nil { + logger.PrintIfVerbose("Failed to check engine allowance: " + err.Error()) + return "" + } + if !isAllowed { + logger.PrintIfVerbose("User does not have permission to standalone dev assists feature") + return "" + } + + // Check if unique id is already set + uniqueID = viper.GetString(commonParams.UniqueIDConfigKey) + if uniqueID != "" { + return uniqueID + } + + // Generate new unique id + logger.PrintIfVerbose("Generating new unique id") + currentUser, err := user.Current() + if err != nil { + logger.PrintIfVerbose("Failed to get user: " + err.Error()) + return "" + } + username := currentUser.Username + username = strings.TrimSpace(username) + logger.PrintIfVerbose("Username to be used for unique id: " + username) + if strings.Contains(username, "\\") { + username = strings.Split(username, "\\")[1] + } + uniqueID = uuid.New().String() + "_" + username + + logger.PrintIfVerbose("Unique id: " + uniqueID) + viper.Set(commonParams.UniqueIDConfigKey, uniqueID) + configFilePath, _ := configuration.GetConfigFilePath() + err = configuration.SafeWriteSingleConfigKeyString(configFilePath, commonParams.UniqueIDConfigKey, uniqueID) + if err != nil { + logger.PrintIfVerbose("Failed to write config: " + err.Error()) + return "" + } + return uniqueID +} diff --git a/internal/wrappers/jwt-helper_test.go b/internal/wrappers/jwt-helper_test.go index 8bf721983..ca51c1e19 100644 --- a/internal/wrappers/jwt-helper_test.go +++ b/internal/wrappers/jwt-helper_test.go @@ -1,8 +1,11 @@ package wrappers import ( + "strings" "testing" + commonParams "github.com/checkmarx/ast-cli/internal/params" + "github.com/spf13/viper" "gotest.tools/assert" ) @@ -71,3 +74,44 @@ func TestGetEnabledEngines(t *testing.T) { }) } } + +func TestGetUniqueID(t *testing.T) { + // Save original value and restore after test + originalID := viper.GetString(commonParams.UniqueIDConfigKey) + defer viper.Set(commonParams.UniqueIDConfigKey, originalID) + + t.Run("returns existing unique ID from config", func(t *testing.T) { + // Setup: set existing ID + existingID := "test-uuid-456_testuser" + viper.Set(commonParams.UniqueIDConfigKey, existingID) + + result := GetUniqueID() + + if result != "" { + assert.Equal(t, existingID, result) + } else { + t.Skip("Requires valid auth and 'Checkmarx Developer Assist' license") + } + }) + + t.Run("generates new unique ID when none exists", func(t *testing.T) { + // Setup: clear existing ID + viper.Set(commonParams.UniqueIDConfigKey, "") + + result := GetUniqueID() + + if result == "" { + t.Skip("Requires valid auth and 'Checkmarx Developer Assist' license") + return + } + + // Verify format: UUID_username + assert.Assert(t, strings.Contains(result, "_"), "Should have UUID_username format") + assert.Assert(t, len(result) > 36, "Should contain UUID and username") + + // Verify no backslash (Windows domain stripped) + parts := strings.Split(result, "_") + assert.Assert(t, len(parts) >= 2, "Should have at least 2 parts") + assert.Assert(t, !strings.Contains(parts[1], "\\"), "Username should not contain backslash") + }) +} diff --git a/internal/wrappers/telemetry.go b/internal/wrappers/telemetry.go index 58e8a5b73..b3e58781f 100644 --- a/internal/wrappers/telemetry.go +++ b/internal/wrappers/telemetry.go @@ -10,6 +10,7 @@ type DataForAITelemetry struct { ScanType string `json:"scanType"` Status string `json:"status"` TotalCount int `json:"totalCount"` + UniqueID string `json:"uniqueId"` } type TelemetryWrapper interface {