Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions internal/commands/asca/asca-engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/spf13/viper"
)

func RunScanASCACommand(jwtWrapper wrappers.JWTWrapper, featureFlagsWrapper wrappers.FeatureFlagsWrapper) func(cmd *cobra.Command, args []string) error {
func RunScanASCACommand(jwtWrapper wrappers.JWTWrapper) func(cmd *cobra.Command, args []string) error {
return func(cmd *cobra.Command, args []string) error {
ASCALatestVersion, _ := cmd.Flags().GetBool(commonParams.ASCALatestVersion)
fileSourceFlag, _ := cmd.Flags().GetString(commonParams.SourcesFlag)
Expand All @@ -23,9 +23,8 @@ func RunScanASCACommand(jwtWrapper wrappers.JWTWrapper, featureFlagsWrapper wrap
IsDefaultAgent: agent == commonParams.DefaultAgent,
}
wrapperParams := services.AscaWrappersParam{
JwtWrapper: jwtWrapper,
FeatureFlagsWrapper: featureFlagsWrapper,
ASCAWrapper: ASCAWrapper,
JwtWrapper: jwtWrapper,
ASCAWrapper: ASCAWrapper,
}
scanResult, err := services.CreateASCAScanRequest(ASCAParams, wrapperParams)
if err != nil {
Expand Down
7 changes: 3 additions & 4 deletions internal/commands/asca/asca-engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ func Test_ExecuteAscaScan(t *testing.T) {
IsDefaultAgent: true,
}
wrapperParams := services.AscaWrappersParam{
JwtWrapper: &mock.JWTMockWrapper{},
FeatureFlagsWrapper: &mock.FeatureFlagsMockWrapper{},
ASCAWrapper: &mock.ASCAMockWrapper{},
JwtWrapper: &mock.JWTMockWrapper{},
ASCAWrapper: &mock.ASCAMockWrapper{},
}
got, err := services.CreateASCAScanRequest(ASCAParams, wrapperParams)
if (err != nil) != ttt.wantErr {
Expand Down Expand Up @@ -129,7 +128,7 @@ func Test_runScanASCACommand(t *testing.T) {
cmd.Flags().String(commonParams.SourcesFlag, ttt.sourceFlag, "")
cmd.Flags().Bool(commonParams.ASCALatestVersion, ttt.engineFlag, "")
cmd.Flags().String(commonParams.FormatFlag, printer.FormatJSON, "")
runFunc := RunScanASCACommand(&mock.JWTMockWrapper{}, &mock.FeatureFlagsMockWrapper{})
runFunc := RunScanASCACommand(&mock.JWTMockWrapper{})
err := runFunc(cmd, []string{})
if (err != nil) != ttt.wantErr {
t.Errorf("RunScanASCACommand() error = %v, wantErr %v", err, ttt.wantErr)
Expand Down
2 changes: 1 addition & 1 deletion internal/commands/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ func scanASCASubCommand(jwtWrapper wrappers.JWTWrapper, featureFlagsWrapper wrap
`,
),
},
RunE: asca.RunScanASCACommand(jwtWrapper, featureFlagsWrapper),
RunE: asca.RunScanASCACommand(jwtWrapper),
}

scanASCACmd.PersistentFlags().Bool(commonParams.ASCALatestVersion, false,
Expand Down
7 changes: 3 additions & 4 deletions internal/services/asca.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ type AscaScanParams struct {
}

type AscaWrappersParam struct {
JwtWrapper wrappers.JWTWrapper
FeatureFlagsWrapper wrappers.FeatureFlagsWrapper
ASCAWrapper grpcs.AscaWrapper
JwtWrapper wrappers.JWTWrapper
ASCAWrapper grpcs.AscaWrapper
}

func CreateASCAScanRequest(ascaParams AscaScanParams, wrapperParams AscaWrappersParam) (*grpcs.ScanResult, error) {
Expand Down Expand Up @@ -165,7 +164,7 @@ func ensureASCAServiceRunning(wrappersParam AscaWrappersParam, ascaParams AscaSc

func checkLicense(isDefaultAgent bool, wrapperParams AscaWrappersParam) error {
if !isDefaultAgent {
allowed, err := wrapperParams.JwtWrapper.IsAllowedEngine(params.AIProtectionType, wrapperParams.FeatureFlagsWrapper)
allowed, err := wrapperParams.JwtWrapper.IsAllowedEngine(params.AIProtectionType)
if err != nil {
return err
}
Expand Down
43 changes: 10 additions & 33 deletions internal/services/asca_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"testing"

errorconstants "github.com/checkmarx/ast-cli/internal/constants/errors"
"github.com/checkmarx/ast-cli/internal/wrappers"
"github.com/checkmarx/ast-cli/internal/wrappers/grpcs"
"github.com/checkmarx/ast-cli/internal/wrappers/mock"
"github.com/stretchr/testify/assert"
Expand All @@ -18,9 +17,8 @@ func TestCreateASCAScanRequest_DefaultAgent_Success(t *testing.T) {
IsDefaultAgent: true,
}
wrapperParams := AscaWrappersParam{
JwtWrapper: &mock.JWTMockWrapper{},
FeatureFlagsWrapper: &mock.FeatureFlagsMockWrapper{},
ASCAWrapper: mock.NewASCAMockWrapper(1234),
JwtWrapper: &mock.JWTMockWrapper{},
ASCAWrapper: mock.NewASCAMockWrapper(1234),
}
sr, err := CreateASCAScanRequest(ASCAParams, wrapperParams)
if err != nil {
Expand All @@ -39,9 +37,8 @@ func TestCreateASCAScanRequest_DefaultAgentAndLatestVersionFlag_Success(t *testi
IsDefaultAgent: true,
}
wrapperParams := AscaWrappersParam{
JwtWrapper: &mock.JWTMockWrapper{},
FeatureFlagsWrapper: &mock.FeatureFlagsMockWrapper{},
ASCAWrapper: mock.NewASCAMockWrapper(1234),
JwtWrapper: &mock.JWTMockWrapper{},
ASCAWrapper: mock.NewASCAMockWrapper(1234),
}
sr, err := CreateASCAScanRequest(ASCAParams, wrapperParams)
if err != nil {
Expand All @@ -61,9 +58,8 @@ func TestCreateASCAScanRequest_SpecialAgentAndNoLicense_Fail(t *testing.T) {
IsDefaultAgent: false,
}
wrapperParams := AscaWrappersParam{
JwtWrapper: &mock.JWTMockWrapper{AIEnabled: mock.AIProtectionDisabled},
FeatureFlagsWrapper: &mock.FeatureFlagsMockWrapper{},
ASCAWrapper: &mock.ASCAMockWrapper{Port: specialErrorPort},
JwtWrapper: &mock.JWTMockWrapper{AIEnabled: mock.AIProtectionDisabled},
ASCAWrapper: &mock.ASCAMockWrapper{Port: specialErrorPort},
}
_, err := CreateASCAScanRequest(ASCAParams, wrapperParams)
assert.ErrorContains(t, err, errorconstants.NoASCALicense)
Expand All @@ -82,9 +78,8 @@ func TestCreateASCAScanRequest_EngineRunningAndSpecialAgentAndNoLicense_Fail(t *
}

wrapperParams := AscaWrappersParam{
JwtWrapper: &mock.JWTMockWrapper{},
FeatureFlagsWrapper: &mock.FeatureFlagsMockWrapper{},
ASCAWrapper: grpcs.NewASCAGrpcWrapper(port),
JwtWrapper: &mock.JWTMockWrapper{},
ASCAWrapper: grpcs.NewASCAGrpcWrapper(port),
}
err = manageASCAInstallation(ASCAParams, wrapperParams)
assert.Nil(t, err)
Expand Down Expand Up @@ -113,9 +108,8 @@ func TestCreateASCAScanRequest_EngineRunningAndDefaultAgentAndNoLicense_Success(
}

wrapperParams := AscaWrappersParam{
JwtWrapper: &mock.JWTMockWrapper{},
FeatureFlagsWrapper: &mock.FeatureFlagsMockWrapper{},
ASCAWrapper: grpcs.NewASCAGrpcWrapper(port),
JwtWrapper: &mock.JWTMockWrapper{},
ASCAWrapper: grpcs.NewASCAGrpcWrapper(port),
}
err = manageASCAInstallation(ASCAParams, wrapperParams)
assert.Nil(t, err)
Expand All @@ -131,20 +125,3 @@ func TestCreateASCAScanRequest_EngineRunningAndDefaultAgentAndNoLicense_Success(
assert.Nil(t, wrapperParams.ASCAWrapper.HealthCheck())
_ = wrapperParams.ASCAWrapper.ShutDown()
}

func TestCreateASCAScanRequest_whenCheckLicenseWithPackageEnforcementFFOff_shouldSuccess(t *testing.T) {
port, err := getAvailablePort()
if err != nil {
t.Fatalf("Failed to get available port: %v", err)
}

mock.Flag = wrappers.FeatureFlagResponseModel{Name: wrappers.PackageEnforcementEnabled, Status: false}

wrapperParams := AscaWrappersParam{
JwtWrapper: wrappers.NewJwtWrapper(),
FeatureFlagsWrapper: &mock.FeatureFlagsMockWrapper{},
ASCAWrapper: grpcs.NewASCAGrpcWrapper(port),
}
err = checkLicense(false, wrapperParams)
assert.Nil(t, err)
}
9 changes: 2 additions & 7 deletions internal/wrappers/jwt-helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var defaultEngines = map[string]bool{

type JWTWrapper interface {
GetAllowedEngines(featureFlagsWrapper FeatureFlagsWrapper) (allowedEngines map[string]bool, err error)
IsAllowedEngine(engine string, featureFlagsWrapper FeatureFlagsWrapper) (bool, error)
IsAllowedEngine(engine string) (bool, error)
ExtractTenantFromToken() (tenant string, err error)
}

Expand Down Expand Up @@ -65,12 +65,7 @@ func getJwtStruct() (*JWTStruct, error) {
}

// IsAllowedEngine will return if the engine is allowed in the user license
func (*JWTStruct) IsAllowedEngine(engine string, featureFlagsWrapper FeatureFlagsWrapper) (bool, error) {
flagResponse, _ := GetSpecificFeatureFlag(featureFlagsWrapper, PackageEnforcementEnabled)
if !flagResponse.Status {
return true, nil
}

func (*JWTStruct) IsAllowedEngine(engine string) (bool, error) {
jwtStruct, err := getJwtStruct()
if err != nil {
return false, err
Expand Down
2 changes: 1 addition & 1 deletion internal/wrappers/mock/jwt-helper-mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (*JWTMockWrapper) ExtractTenantFromToken() (tenant string, err error) {
}

// IsAllowedEngine mock for tests
func (j *JWTMockWrapper) IsAllowedEngine(engine string, featureFlagWrapper wrappers.FeatureFlagsWrapper) (bool, error) {
func (j *JWTMockWrapper) IsAllowedEngine(engine string) (bool, error) {
if j.AIEnabled == AIProtectionDisabled {
return false, nil
}
Expand Down
Loading