Skip to content
15 changes: 8 additions & 7 deletions internal/commands/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -1064,13 +1064,6 @@ func validateScanTypes(cmd *cobra.Command, jwtWrapper wrappers.JWTWrapper, featu
userScanTypes = strings.Replace(strings.ToLower(userScanTypes), commonParams.ContainersTypeFlag, commonParams.ContainersType, 1)
userSCSScanTypes = strings.Replace(strings.ToLower(userSCSScanTypes), commonParams.SCSEnginesFlag, commonParams.ScsType, 1)

SCSScanTypes = strings.Split(userSCSScanTypes, ",")
if slices.Contains(SCSScanTypes, ScsSecretDetectionType) && !allowedEngines[commonParams.EnterpriseSecretsType] {
keys := reflect.ValueOf(allowedEngines).MapKeys()
err = errors.Errorf(engineNotAllowed, ScsSecretDetectionType, ScsSecretDetectionType, keys)
return err
}

scanTypes = strings.Split(userScanTypes, ",")
for _, scanType := range scanTypes {
if !allowedEngines[scanType] || (scanType == commonParams.ContainersType && !(containerEngineCLIEnabled.Status)) {
Expand All @@ -1079,6 +1072,14 @@ func validateScanTypes(cmd *cobra.Command, jwtWrapper wrappers.JWTWrapper, featu
return err
}
}

SCSScanTypes = strings.Split(userSCSScanTypes, ",")
if slices.Contains(SCSScanTypes, ScsSecretDetectionType) && !allowedEngines[commonParams.EnterpriseSecretsType] {
keys := reflect.ValueOf(allowedEngines).MapKeys()
err = errors.Errorf(engineNotAllowed, ScsSecretDetectionType, ScsSecretDetectionType, keys)
return err
}

} else {
for k := range allowedEngines {
if k == commonParams.ContainersType && !(containerEngineCLIEnabled.Status) {
Expand Down
57 changes: 57 additions & 0 deletions internal/commands/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1765,3 +1765,60 @@ func TestUploadZip_whenUserNotProvideZip_shouldReturnZipFilePathInFailureCase(t
assert.Assert(t, strings.Contains(err.Error(), "error from UploadFile"), err.Error())
assert.Equal(t, zipPath, "failureCase.zip")
}

func TestValidateScanTypes(t *testing.T) {
tests := []struct {
name string
userScanTypes string
userSCSScanTypes string
allowedEngines map[string]bool
containerEngineCLIEnabled bool
expectedError string
}{
{
name: "No licenses available",
userScanTypes: "scs",
userSCSScanTypes: "sast,secret-detection",
allowedEngines: map[string]bool{"scs": false, "enterprise-secrets": false},
containerEngineCLIEnabled: true,
expectedError: "It looks like the \"scs\" scan type does",
},
{
name: "SCS license available, secret-detection not available",
userScanTypes: "scs",
userSCSScanTypes: "secret-detection",
allowedEngines: map[string]bool{"scs": true, "enterprise-secrets": false},
containerEngineCLIEnabled: true,
expectedError: "It looks like the \"secret-detection\" scan type does not exist",
},
{
name: "All licenses available",
userScanTypes: "scs",
userSCSScanTypes: "secret-detection",
allowedEngines: map[string]bool{"scs": true, "enterprise-secrets": true},
containerEngineCLIEnabled: true,
expectedError: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{}
cmd.Flags().String(commonParams.ScanTypes, tt.userScanTypes, "")
cmd.Flags().String(commonParams.SCSEnginesFlag, tt.userSCSScanTypes, "")

jwtWrapper := &mock.JWTMockWrapper{
CustomGetAllowedEngines: func(featureFlagsWrapper wrappers.FeatureFlagsWrapper) (map[string]bool, error) {
return tt.allowedEngines, nil
},
}
featureFlagsWrapper := &mock.FeatureFlagsMockWrapper{}
err := validateScanTypes(cmd, jwtWrapper, featureFlagsWrapper)
if tt.expectedError != "" {
assert.ErrorContains(t, err, tt.expectedError)
} else {
assert.NilError(t, err)
}
})
}
}
8 changes: 6 additions & 2 deletions internal/wrappers/mock/jwt-helper-mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ import (
)

type JWTMockWrapper struct {
AIEnabled int
AIEnabled int
CustomGetAllowedEngines func(wrappers.FeatureFlagsWrapper) (map[string]bool, error)
}

const AIProtectionDisabled = 1

// GetAllowedEngines mock for tests
func (*JWTMockWrapper) GetAllowedEngines(featureFlagsWrapper wrappers.FeatureFlagsWrapper) (allowedEngines map[string]bool, err error) {
func (j *JWTMockWrapper) GetAllowedEngines(featureFlagsWrapper wrappers.FeatureFlagsWrapper) (allowedEngines map[string]bool, err error) {
if j.CustomGetAllowedEngines != nil {
return j.CustomGetAllowedEngines(featureFlagsWrapper)
}
allowedEngines = make(map[string]bool)
engines := []string{"sast", "iac-security", "sca", "api-security", "containers", "scs"}
for _, value := range engines {
Expand Down