diff --git a/aws/resource_registry.go b/aws/resource_registry.go
index ff90682a..fd84bead 100644
--- a/aws/resource_registry.go
+++ b/aws/resource_registry.go
@@ -56,9 +56,9 @@ func getRegisteredRegionalResources() []AwsResource {
 		resources.NewACM(),
 		resources.NewACMPCA(),
 		resources.NewAMIs(),
-		&resources.ApiGateway{},
-		&resources.ApiGatewayV2{},
-		&resources.ASGroups{},
+		resources.NewApiGateway(),
+		resources.NewApiGatewayV2(),
+		resources.NewASGroups(),
 		resources.NewAppRunnerService(),
 		resources.NewBackupVault(),
 		resources.NewManagedPrometheus(),
@@ -75,57 +75,57 @@ func getRegisteredRegionalResources() []AwsResource {
 		resources.NewCloudWatchLogGroups(),
 		&resources.CloudMapServices{},
 		&resources.CloudMapNamespaces{},
-		&resources.CodeDeployApplications{},
+		resources.NewCodeDeployApplications(),
 		resources.NewConfigServiceRecorders(),
 		&resources.ConfigServiceRule{},
 		resources.NewDataSyncTask(),
 		resources.NewDataSyncLocation(),
 		resources.NewDynamoDB(),
-		&resources.EBSVolumes{},
+		resources.NewEBSVolumes(),
 		resources.NewEBApplications(),
-		&resources.EC2Instances{},
-		&resources.EC2DedicatedHosts{},
+		resources.NewEC2Instances(),
+		resources.NewEC2DedicatedHosts(),
 		resources.NewEC2KeyPairs(),
 		resources.NewEC2PlacementGroups(),
-		&resources.TransitGateways{},
+		resources.NewTransitGateways(),
 		resources.NewTransitGatewaysRouteTables(),
 		// Note: nuking transitgateway vpc attachement before nuking the vpc since vpc could be associated with it.
 		resources.NewTransitGatewayPeeringAttachment(),
-		&resources.TransitGatewaysVpcAttachment{},
-		&resources.EC2Endpoints{},
+		resources.NewTransitGatewaysVpcAttachment(),
+		resources.NewEC2Endpoints(),
 		resources.NewECR(),
-		&resources.ECSClusters{},
-		&resources.ECSServices{},
+		resources.NewECSClusters(),
+		resources.NewECSServices(),
 		&resources.EgressOnlyInternetGateway{},
-		&resources.ElasticFileSystem{},
+		resources.NewElasticFileSystem(),
 		resources.NewEIPAddresses(),
-		&resources.EKSClusters{},
+		resources.NewEKSClusters(),
 		&resources.ElasticCacheServerless{},
-		&resources.Elasticaches{},
-		&resources.ElasticacheParameterGroups{},
-		&resources.ElasticacheSubnetGroups{},
-		&resources.LoadBalancers{},
-		&resources.LoadBalancersV2{},
+		resources.NewElasticaches(),
+		resources.NewElasticacheParameterGroups(),
+		resources.NewElasticacheSubnetGroups(),
+		resources.NewLoadBalancers(),
+		resources.NewLoadBalancersV2(),
 		resources.NewGuardDuty(),
 		resources.NewKinesisFirehose(),
 		resources.NewKinesisStreams(),
-		&resources.KmsCustomerKeys{},
-		&resources.LambdaFunctions{},
-		&resources.LambdaLayers{},
+		resources.NewKmsCustomerKeys(),
+		resources.NewLambdaFunctions(),
+		resources.NewLambdaLayers(),
 		resources.NewLaunchConfigs(),
 		resources.NewLaunchTemplates(),
 		&resources.MacieMember{},
-		&resources.MSKCluster{},
+		resources.NewMSKCluster(),
 		resources.NewNatGateways(),
-		&resources.OpenSearchDomains{},
+		resources.NewOpenSearchDomains(),
 		&resources.DBGlobalClusterMemberships{},
-		&resources.DBInstances{},
-		&resources.DBSubnetGroups{},
-		&resources.DBClusters{},
+		resources.NewDBInstances(),
+		resources.NewDBSubnetGroups(),
+		resources.NewDBClusters(),
 		resources.NewRdsProxy(),
 		resources.NewRdsSnapshot(),
-		&resources.RdsParameterGroup{},
-		&resources.RedshiftClusters{},
+		resources.NewRdsParameterGroup(),
+		resources.NewRedshiftClusters(),
 		&resources.RedshiftSnapshotCopyGrants{},
 		&resources.S3Buckets{},
 		&resources.S3AccessPoint{},
@@ -135,7 +135,7 @@ func getRegisteredRegionalResources() []AwsResource {
 		&resources.SageMakerStudio{},
 		&resources.SageMakerEndpoint{},
 		resources.NewSecretsManagerSecrets(),
-		&resources.SecurityHub{},
+		resources.NewSecurityHub(),
 		resources.NewSesConfigurationSet(),
 		resources.NewSesEmailTemplates(),
 		resources.NewSesIdentities(),
@@ -160,9 +160,9 @@ func getRegisteredRegionalResources() []AwsResource {
 		&resources.NetworkFirewallRuleGroup{},
 		&resources.NetworkFirewallTLSConfig{},
 		&resources.NetworkFirewallResourcePolicy{},
-		&resources.VPCLatticeServiceNetwork{},
+		resources.NewVPCLatticeServiceNetwork(),
 		&resources.VPCLatticeService{},
-		&resources.VPCLatticeTargetGroup{},
+		resources.NewVPCLatticeTargetGroup(),
 		// Note: VPCs must be deleted last after all resources that create network interfaces (EKS, ECS, etc.)
 		&resources.EC2VPCs{},
 		// Note: nuking EC2 DHCP options after nuking EC2 VPC because DHCP options could be associated with VPCs.
diff --git a/aws/resources/apigateway.go b/aws/resources/apigateway.go
index 0fe8993b..805789da 100644
--- a/aws/resources/apigateway.go
+++ b/aws/resources/apigateway.go
@@ -9,19 +9,52 @@ import (
 	"github.com/gruntwork-io/cloud-nuke/config"
 	"github.com/gruntwork-io/cloud-nuke/logging"
 	"github.com/gruntwork-io/cloud-nuke/report"
+	"github.com/gruntwork-io/cloud-nuke/resource"
 	"github.com/gruntwork-io/go-commons/errors"
 	"github.com/hashicorp/go-multierror"
 )
 
-func (gateway *ApiGateway) getAll(c context.Context, configObj config.Config) ([]*string, error) {
-	result, err := gateway.Client.GetRestApis(c, &apigateway.GetRestApisInput{})
+// ApiGatewayAPI defines the interface for API Gateway (v1) operations.
+type ApiGatewayAPI interface {
+	GetRestApis(ctx context.Context, params *apigateway.GetRestApisInput, optFns ...func(*apigateway.Options)) (*apigateway.GetRestApisOutput, error)
+	GetDomainNames(ctx context.Context, params *apigateway.GetDomainNamesInput, optFns ...func(*apigateway.Options)) (*apigateway.GetDomainNamesOutput, error)
+	GetBasePathMappings(ctx context.Context, params *apigateway.GetBasePathMappingsInput, optFns ...func(*apigateway.Options)) (*apigateway.GetBasePathMappingsOutput, error)
+	DeleteBasePathMapping(ctx context.Context, params *apigateway.DeleteBasePathMappingInput, optFns ...func(*apigateway.Options)) (*apigateway.DeleteBasePathMappingOutput, error)
+	DeleteRestApi(ctx context.Context, params *apigateway.DeleteRestApiInput, optFns ...func(*apigateway.Options)) (*apigateway.DeleteRestApiOutput, error)
+}
+
+// NewApiGateway creates a new ApiGateway resource using the generic resource pattern.
+func NewApiGateway() AwsResource {
+	return NewAwsResource(&resource.Resource[ApiGatewayAPI]{
+		ResourceTypeName: "apigateway",
+		BatchSize:        10,
+		InitClient: func(r *resource.Resource[ApiGatewayAPI], cfg any) {
+			awsCfg, ok := cfg.(aws.Config)
+			if !ok {
+				logging.Debugf("Invalid config type for ApiGateway client: expected aws.Config")
+				return
+			}
+			r.Scope.Region = awsCfg.Region
+			r.Client = apigateway.NewFromConfig(awsCfg)
+		},
+		ConfigGetter: func(c config.Config) config.ResourceType {
+			return c.APIGateway
+		},
+		Lister: listApiGateways,
+		Nuker:  deleteApiGateways,
+	})
+}
+
+// listApiGateways retrieves all API Gateway (v1) REST APIs that match the config filters.
+func listApiGateways(ctx context.Context, client ApiGatewayAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) {
+	result, err := client.GetRestApis(ctx, &apigateway.GetRestApisInput{})
 	if err != nil {
 		return []*string{}, errors.WithStackTrace(err)
 	}
 
 	var IDs []*string
 	for _, api := range result.Items {
-		if configObj.APIGateway.ShouldInclude(config.ResourceValue{
+		if cfg.ShouldInclude(config.ResourceValue{
 			Name: api.Name,
 			Time: api.CreatedDate,
 		}) {
@@ -32,25 +65,26 @@ func (gateway *ApiGateway) getAll(c context.Context, configObj config.Config) ([
 	return IDs, nil
 }
 
-func (gateway *ApiGateway) nukeAll(identifiers []*string) error {
+// deleteApiGateways deletes the provided API Gateway (v1) REST APIs.
+func deleteApiGateways(ctx context.Context, client ApiGatewayAPI, scope resource.Scope, resourceType string, identifiers []*string) error {
 	if len(identifiers) == 0 {
-		logging.Debugf("No API Gateways (v1) to nuke in region %s", gateway.Region)
+		logging.Debugf("No API Gateways (v1) to nuke in region %s", scope.Region)
+		return nil
 	}
 
 	if len(identifiers) > 100 {
-		logging.Errorf("Nuking too many API Gateways (v1) at once (100): " +
-			"halting to avoid hitting AWS API rate limiting")
+		logging.Errorf("Nuking too many API Gateways (v1) at once (100): halting to avoid hitting AWS API rate limiting")
 		return TooManyApiGatewayErr{}
 	}
 
 	// There is no bulk delete Api Gateway API, so we delete the batch of gateways concurrently using goroutines
-	logging.Debugf("Deleting Api Gateways (v1) in region %s", gateway.Region)
+	logging.Debugf("Deleting Api Gateways (v1) in region %s", scope.Region)
 	wg := new(sync.WaitGroup)
 	wg.Add(len(identifiers))
 	errChans := make([]chan error, len(identifiers))
 	for i, apigwID := range identifiers {
 		errChans[i] = make(chan error, 1)
-		go gateway.nukeAsync(wg, errChans[i], apigwID)
+		go deleteApiGatewayAsync(ctx, client, scope.Region, wg, errChans[i], apigwID)
 	}
 	wg.Wait()
@@ -69,59 +103,16 @@ func (gateway *ApiGateway) nukeAll(identifiers []*string) error {
 	return nil
 }
 
-func (gateway *ApiGateway) getAttachedStageClientCerts(apigwID *string) ([]*string, error) {
-	var clientCerts []*string
-
-	// remove the client certificate attached with the stages
-	stages, err := gateway.Client.GetStages(gateway.Context, &apigateway.GetStagesInput{
-		RestApiId: apigwID,
-	})
-
-	if err != nil {
-		return nil, err
-	}
-	// get the stages attached client certificates
-	for _, stage := range stages.Item {
-		if stage.ClientCertificateId == nil {
-			logging.Debugf("Skipping certyficate for stage %s, certyficate ID is nil", *stage.StageName)
-			continue
-		}
-		clientCerts = append(clientCerts, stage.ClientCertificateId)
-	}
-	return clientCerts, nil
-}
-
-func (gateway *ApiGateway) removeAttachedClientCertificates(clientCerts []*string) error {
-
-	for _, cert := range clientCerts {
-		logging.Debugf("Deleting Client Certificate %s", *cert)
-		_, err := gateway.Client.DeleteClientCertificate(gateway.Context, &apigateway.DeleteClientCertificateInput{
-			ClientCertificateId: cert,
-		})
-		if err != nil {
-			logging.Errorf("[Failed] Error deleting Client Certificate %s", *cert)
-			return err
-		}
-	}
-	return nil
-}
-
-func (gateway *ApiGateway) nukeAsync(
+func deleteApiGatewayAsync(
+	ctx context.Context, client ApiGatewayAPI, region string,
 	wg *sync.WaitGroup,
 	errChan chan error,
 	apigwID *string,
 ) {
 	defer wg.Done()
 	var err error
 
-	// Why defer?
-	// Defer error reporting, channel sending, and logging to ensure they run
-	// after function execution completes, regardless of success or failure.
-	// This ensures consistent reporting, prevents missed logs, and avoids
-	// duplicated code paths for error/success handling.
-	//
-	// See: https://go.dev/ref/spec#Defer_statements
+	// Defer error reporting, channel sending, and logging
 	defer func() {
-		// send the error data to channel
 		errChan <- err
 
 		// Record status of this resource
@@ -133,45 +124,34 @@ func (gateway *ApiGateway) nukeAsync(
 		report.Record(e)
 
 		if err == nil {
-			logging.Debugf("[OK] API Gateway (v1) %s deleted in %s", aws.ToString(apigwID), gateway.Region)
+			logging.Debugf("[OK] API Gateway (v1) %s deleted in %s", aws.ToString(apigwID), region)
 		} else {
-			logging.Debugf("[Failed] Error deleting API Gateway (v1) %s in %s", aws.ToString(apigwID), gateway.Region)
+			logging.Debugf("[Failed] Error deleting API Gateway (v1) %s in %s", aws.ToString(apigwID), region)
 		}
 	}()
 
-	// get the attached client certificates
-	var clientCerts []*string
-	clientCerts, err = gateway.getAttachedStageClientCerts(apigwID)
-	if err != nil {
-		return
-	}
-
 	// Check if the API Gateway has any associated API mappings.
 	// If so, remove them before deleting the API Gateway.
-	err = gateway.deleteAssociatedApiMappings(context.Background(), []*string{apigwID})
+	err = deleteAssociatedApiMappingsV1(ctx, client, []*string{apigwID})
 	if err != nil {
 		return
 	}
 
 	// delete the API Gateway
 	input := &apigateway.DeleteRestApiInput{RestApiId: apigwID}
-	_, err = gateway.Client.DeleteRestApi(gateway.Context, input)
-	if err != nil {
-		return
-	}
-
-	// When the rest-api endpoint delete successfully, then remove attached client certs
-	err = gateway.removeAttachedClientCertificates(clientCerts)
+	_, err = client.DeleteRestApi(ctx, input)
 }
 
-func (gateway *ApiGateway) deleteAssociatedApiMappings(ctx context.Context, identifiers []*string) error {
+// deleteAssociatedApiMappingsV1 deletes API mappings for API Gateway v1.
+// Named with V1 suffix to avoid conflict with similar function in apigatewayv2.go.
+func deleteAssociatedApiMappingsV1(ctx context.Context, client ApiGatewayAPI, identifiers []*string) error {
 	// Convert identifiers to map to check if identifier is in list
 	identifierMap := make(map[string]struct{})
 	for _, identifier := range identifiers {
 		identifierMap[*identifier] = struct{}{}
 	}
 
-	domainNames, err := gateway.Client.GetDomainNames(ctx, &apigateway.GetDomainNamesInput{})
+	domainNames, err := client.GetDomainNames(ctx, &apigateway.GetDomainNamesInput{})
 	if err != nil {
 		logging.Debugf("Failed to get domain names: %s", err)
 		return errors.WithStackTrace(err)
@@ -180,7 +160,7 @@ func (gateway *ApiGateway) deleteAssociatedApiMappings(ctx context.Context, iden
 	logging.Debugf("Found %d domain name(s)", len(domainNames.Items))
 
 	for _, domain := range domainNames.Items {
-		apiMappings, err := gateway.Client.GetBasePathMappings(ctx, &apigateway.GetBasePathMappingsInput{
+		apiMappings, err := client.GetBasePathMappings(ctx, &apigateway.GetBasePathMappingsInput{
 			DomainName: domain.DomainName,
 		})
@@ -202,7 +182,7 @@ func (gateway *ApiGateway) deleteAssociatedApiMappings(ctx context.Context, iden
 			logging.Debugf("Deleting base path mapping for API %s on domain %s",
 				*mapping.RestApiId, *domain.DomainName)
 
-			_, err := gateway.Client.DeleteBasePathMapping(ctx, &apigateway.DeleteBasePathMappingInput{
+			_, err := client.DeleteBasePathMapping(ctx, &apigateway.DeleteBasePathMappingInput{
 				DomainName: domain.DomainName,
 				BasePath:   mapping.BasePath,
 			})
@@ -218,3 +198,10 @@ func (gateway *ApiGateway) deleteAssociatedApiMappings(ctx context.Context, iden
 	logging.Debug("Completed deletion of matching API mappings.")
 	return nil
 }
+
+// TooManyApiGatewayErr is returned when too many API Gateways are requested at once.
+type TooManyApiGatewayErr struct{}
+
+func (err TooManyApiGatewayErr) Error() string {
+	return "Too many Api Gateways requested at once."
+}
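The refactor above delegates to a new generic `resource` package (`resource.Resource[T]`, `NewAwsResource`, `resource.Scope`) that this diff calls into but does not include. As a rough orientation only, this is the shape those types appear to have when inferred from the call sites in this patch; the field names and signatures below are assumptions, not the package's actual source:

// Sketch only: inferred from how apigateway.go, asg.go, and
// codedeploy_application.go construct resource.Resource values above.
package resource

import (
	"context"

	"github.com/gruntwork-io/cloud-nuke/config"
)

// Scope carries per-invocation location data, such as the AWS region.
type Scope struct {
	Region string
}

// Resource is a generic description of a nukable resource type,
// parameterized by the service client interface it needs.
type Resource[ClientT any] struct {
	ResourceTypeName string
	BatchSize        int
	Scope            Scope
	Client           ClientT

	// InitClient builds the concrete SDK client from a provider config.
	InitClient func(r *Resource[ClientT], cfg any)
	// ConfigGetter selects this resource's filter block from the global config.
	ConfigGetter func(c config.Config) config.ResourceType
	// Lister returns the identifiers that match the config filters.
	Lister func(ctx context.Context, client ClientT, scope Scope, cfg config.ResourceType) ([]*string, error)
	// Nuker deletes the given identifiers.
	Nuker func(ctx context.Context, client ClientT, scope Scope, resourceType string, identifiers []*string) error
}

An adapter such as `NewAwsResource` would then wrap this struct so it satisfies the existing `AwsResource` interface used by the registry.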
diff --git a/aws/resources/apigateway_test.go b/aws/resources/apigateway_test.go
index 6f99b66e..1e0df673 100644
--- a/aws/resources/apigateway_test.go
+++ b/aws/resources/apigateway_test.go
@@ -9,18 +9,16 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/apigateway"
 	"github.com/aws/aws-sdk-go-v2/service/apigateway/types"
 	"github.com/gruntwork-io/cloud-nuke/config"
+	"github.com/gruntwork-io/cloud-nuke/resource"
 	"github.com/gruntwork-io/cloud-nuke/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
 type mockedApiGateway struct {
-	ApiGatewayServiceAPI
-	GetRestApisOutput             apigateway.GetRestApisOutput
-	GetStagesOutput               apigateway.GetStagesOutput
-	DeleteClientCertificateOutput apigateway.DeleteClientCertificateOutput
-	DeleteRestApiOutput           apigateway.DeleteRestApiOutput
-
+	ApiGatewayAPI
+	GetRestApisOutput   apigateway.GetRestApisOutput
+	DeleteRestApiOutput apigateway.DeleteRestApiOutput
 	GetDomainNamesOutput        apigateway.GetDomainNamesOutput
 	GetBasePathMappingsOutput   apigateway.GetBasePathMappingsOutput
 	DeleteBasePathMappingOutput apigateway.DeleteBasePathMappingOutput
@@ -29,13 +27,6 @@ type mockedApiGateway struct {
 func (m mockedApiGateway) GetRestApis(ctx context.Context, params *apigateway.GetRestApisInput, optFns ...func(*apigateway.Options)) (*apigateway.GetRestApisOutput, error) {
 	return &m.GetRestApisOutput, nil
 }
-func (m mockedApiGateway) GetStages(ctx context.Context, params *apigateway.GetStagesInput, optFns ...func(*apigateway.Options)) (*apigateway.GetStagesOutput, error) {
-	return &m.GetStagesOutput, nil
-}
-
-func (m mockedApiGateway) DeleteClientCertificate(ctx context.Context, params *apigateway.DeleteClientCertificateInput, optFns ...func(*apigateway.Options)) (*apigateway.DeleteClientCertificateOutput, error) {
-	return &m.DeleteClientCertificateOutput, nil
-}
 
 func (m mockedApiGateway) DeleteRestApi(ctx context.Context, params *apigateway.DeleteRestApiInput, optFns ...func(*apigateway.Options)) (*apigateway.DeleteRestApiOutput, error) {
 	return &m.DeleteRestApiOutput, nil
@@ -52,27 +43,25 @@ func (m mockedApiGateway) GetBasePathMappings(ctx context.Context, params *apiga
 func (m mockedApiGateway) DeleteBasePathMapping(ctx context.Context, params *apigateway.DeleteBasePathMappingInput, optFns ...func(*apigateway.Options)) (*apigateway.DeleteBasePathMappingOutput, error) {
 	return &m.DeleteBasePathMappingOutput, nil
 }
+
 func TestAPIGatewayGetAllAndNukeAll(t *testing.T) {
 	t.Parallel()
 
 	testApiID := "aws-nuke-test-" + util.UniqueID()
-	apiGateway := ApiGateway{
-		Client: mockedApiGateway{
-			ApiGatewayServiceAPI: nil,
-			GetRestApisOutput: apigateway.GetRestApisOutput{
-				Items: []types.RestApi{
-					{Id: aws.String(testApiID)},
-				},
+	mockClient := mockedApiGateway{
+		GetRestApisOutput: apigateway.GetRestApisOutput{
+			Items: []types.RestApi{
+				{Id: aws.String(testApiID)},
 			},
-			DeleteRestApiOutput: apigateway.DeleteRestApiOutput{},
 		},
+		DeleteRestApiOutput: apigateway.DeleteRestApiOutput{},
 	}
 
-	apis, err := apiGateway.getAll(context.Background(), config.Config{})
+	apis, err := listApiGateways(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, config.ResourceType{})
 	require.NoError(t, err)
 	require.Contains(t, aws.ToStringSlice(apis), testApiID)
 
-	err = apiGateway.nukeAll([]*string{aws.String(testApiID)})
+	err = deleteApiGateways(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, "apigateway", []*string{aws.String(testApiID)})
 	require.NoError(t, err)
 }
 
@@ -81,34 +70,28 @@ func TestAPIGatewayGetAllTimeFilter(t *testing.T) {
 	testApiID := "aws-nuke-test-" + util.UniqueID()
 	now := time.Now()
-	apiGateway := ApiGateway{
-		Client: mockedApiGateway{
-			GetRestApisOutput: apigateway.GetRestApisOutput{
-				Items: []types.RestApi{{
-					Id:          aws.String(testApiID),
-					CreatedDate: aws.Time(now),
-				}},
-			},
+	mockClient := mockedApiGateway{
+		GetRestApisOutput: apigateway.GetRestApisOutput{
+			Items: []types.RestApi{{
+				Id:          aws.String(testApiID),
+				CreatedDate: aws.Time(now),
+			}},
 		},
 	}
 
 	// test API is not excluded from the filter
-	IDs, err := apiGateway.getAll(context.Background(), config.Config{
-		APIGateway: config.ResourceType{
-			ExcludeRule: config.FilterRule{
-				TimeAfter: aws.Time(now.Add(1)),
-			},
+	IDs, err := listApiGateways(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, config.ResourceType{
+		ExcludeRule: config.FilterRule{
+			TimeAfter: aws.Time(now.Add(1)),
 		},
 	})
 	require.NoError(t, err)
 	assert.Contains(t, aws.ToStringSlice(IDs), testApiID)
 
 	// test API being excluded from the filter
-	apiGwIdsOlder, err := apiGateway.getAll(context.Background(), config.Config{
-		APIGateway: config.ResourceType{
-			ExcludeRule: config.FilterRule{
-				TimeAfter: aws.Time(now.Add(-1)),
-			},
+	apiGwIdsOlder, err := listApiGateways(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, config.ResourceType{
+		ExcludeRule: config.FilterRule{
+			TimeAfter: aws.Time(now.Add(-1)),
 		},
 	})
 	require.NoError(t, err)
@@ -120,64 +103,26 @@ func TestNukeAPIGatewayMoreThanOne(t *testing.T) {
 	testApiID1 := "aws-nuke-test-" + util.UniqueID()
 	testApiID2 := "aws-nuke-test-" + util.UniqueID()
-	apiGateway := ApiGateway{
-		Client: mockedApiGateway{
-			GetRestApisOutput: apigateway.GetRestApisOutput{
-				Items: []types.RestApi{
-					{Id: aws.String(testApiID1)},
-					{Id: aws.String(testApiID2)},
-				},
+	mockClient := mockedApiGateway{
+		GetRestApisOutput: apigateway.GetRestApisOutput{
+			Items: []types.RestApi{
+				{Id: aws.String(testApiID1)},
+				{Id: aws.String(testApiID2)},
 			},
-			DeleteRestApiOutput: apigateway.DeleteRestApiOutput{},
 		},
+		DeleteRestApiOutput: apigateway.DeleteRestApiOutput{},
 	}
 
-	apis, err := apiGateway.getAll(context.Background(), config.Config{})
+	apis, err := listApiGateways(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, config.ResourceType{})
 	require.NoError(t, err)
-
 	require.Contains(t, aws.ToStringSlice(apis), testApiID1)
-
 	require.Contains(t, aws.ToStringSlice(apis), testApiID2)
-
-	err = apiGateway.nukeAll([]*string{aws.String(testApiID1), aws.String(testApiID2)})
+	err = deleteApiGateways(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, "apigateway", []*string{aws.String(testApiID1), aws.String(testApiID2)})
 	require.NoError(t, err)
 }
 
-func TestNukeAPIGatewayWithCertificates(t *testing.T) {
-	t.Parallel()
-
-	testApiID1 := "aws-nuke-test-" + util.UniqueID()
-	testApiID2 := "aws-nuke-test-" + util.UniqueID()
-
-	clientCertID := "aws-client-cert" + util.UniqueID()
-	apiGateway := ApiGateway{
-		Client: mockedApiGateway{
-			GetRestApisOutput: apigateway.GetRestApisOutput{
-				Items: []types.RestApi{
-					{Id: aws.String(testApiID1)},
-					{Id: aws.String(testApiID2)},
-				},
-			},
-			GetStagesOutput: apigateway.GetStagesOutput{
-				Item: []types.Stage{
-					{
-						ClientCertificateId: aws.String(clientCertID),
-					},
-				},
-			},
-			DeleteClientCertificateOutput: apigateway.DeleteClientCertificateOutput{},
-			DeleteRestApiOutput:           apigateway.DeleteRestApiOutput{},
-		},
-	}
-
-	apis, err := apiGateway.getAll(context.Background(), config.Config{})
-	require.NoError(t, err)
-
-	require.Contains(t, aws.ToStringSlice(apis), testApiID1)
-	require.Contains(t, aws.ToStringSlice(apis), testApiID2)
-
-	err = apiGateway.nukeAll([]*string{aws.String(testApiID1), aws.String(testApiID2)})
-	require.NoError(t, err)
-}
-
-func TestDeleteAssociatedApiMappings(t *testing.T) {
+func TestDeleteAssociatedApiMappingsV1(t *testing.T) {
 	t.Parallel()
 
 	apiIDToDelete := "test-api-id"
@@ -205,10 +150,6 @@ func TestDeleteAssociatedApiMappings(t *testing.T) {
 		},
 	}
 
-	apiGateway := ApiGateway{
-		Client: mockClient,
-	}
-
-	err := apiGateway.deleteAssociatedApiMappings(context.Background(), []*string{aws.String(apiIDToDelete)})
+	err := deleteAssociatedApiMappingsV1(context.Background(), mockClient, []*string{aws.String(apiIDToDelete)})
 	require.NoError(t, err)
 }
diff --git a/aws/resources/apigateway_types.go b/aws/resources/apigateway_types.go
deleted file mode 100644
index 306da14d..00000000
--- a/aws/resources/apigateway_types.go
+++ /dev/null
@@ -1,71 +0,0 @@
-package resources
-
-import (
-	"context"
-
-	"github.com/aws/aws-sdk-go-v2/aws"
-	"github.com/aws/aws-sdk-go-v2/service/apigateway"
-	"github.com/gruntwork-io/cloud-nuke/config"
-	"github.com/gruntwork-io/go-commons/errors"
-)
-
-type ApiGatewayServiceAPI interface {
-	GetRestApis(ctx context.Context, params *apigateway.GetRestApisInput, optFns ...func(*apigateway.Options)) (*apigateway.GetRestApisOutput, error)
-	GetStages(ctx context.Context, params *apigateway.GetStagesInput, optFns ...func(*apigateway.Options)) (*apigateway.GetStagesOutput, error)
-	GetDomainNames(ctx context.Context, params *apigateway.GetDomainNamesInput, optFns ...func(*apigateway.Options)) (*apigateway.GetDomainNamesOutput, error)
-	GetBasePathMappings(ctx context.Context, params *apigateway.GetBasePathMappingsInput, optFns ...func(*apigateway.Options)) (*apigateway.GetBasePathMappingsOutput, error)
-	DeleteBasePathMapping(ctx context.Context, params *apigateway.DeleteBasePathMappingInput, optFns ...func(*apigateway.Options)) (*apigateway.DeleteBasePathMappingOutput, error)
-	DeleteClientCertificate(ctx context.Context, params *apigateway.DeleteClientCertificateInput, optFns ...func(*apigateway.Options)) (*apigateway.DeleteClientCertificateOutput, error)
-	DeleteRestApi(ctx context.Context, params *apigateway.DeleteRestApiInput, optFns ...func(*apigateway.Options)) (*apigateway.DeleteRestApiOutput, error)
-}
-
-type ApiGateway struct {
-	BaseAwsResource
-	Client ApiGatewayServiceAPI
-	Region string
-	Ids    []string
-}
-
-func (gateway *ApiGateway) Init(cfg aws.Config) {
-	gateway.Client = apigateway.NewFromConfig(cfg)
-}
-
-func (gateway *ApiGateway) ResourceName() string {
-	return "apigateway"
-}
-
-func (gateway *ApiGateway) ResourceIdentifiers() []string {
-	return gateway.Ids
-}
-
-func (gateway *ApiGateway) MaxBatchSize() int {
-	return 10
-}
-
-func (gateway *ApiGateway) GetAndSetResourceConfig(configObj config.Config) config.ResourceType {
-	return configObj.APIGateway
-}
-
-func (gateway *ApiGateway) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
-	identifiers, err := gateway.getAll(c, configObj)
-	if err != nil {
-		return nil, err
-	}
-
-	gateway.Ids = aws.ToStringSlice(identifiers)
-	return gateway.Ids, nil
-}
-
-func (gateway *ApiGateway) Nuke(ctx context.Context, identifiers []string) error {
-	if err := gateway.nukeAll(aws.StringSlice(identifiers)); err != nil {
-		return errors.WithStackTrace(err)
-	}
-
-	return nil
-}
-
-type TooManyApiGatewayErr struct{}
-
-func (err TooManyApiGatewayErr) Error() string {
-	return "Too many Api Gateways requested at once."
-}
diff --git a/aws/resources/apigatewayv2.go b/aws/resources/apigatewayv2.go
index 22a9ae61..612d7faa 100644
--- a/aws/resources/apigatewayv2.go
+++ b/aws/resources/apigatewayv2.go
@@ -2,7 +2,6 @@ package resources
 
 import (
 	"context"
-	"fmt"
 	"sync"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
@@ -10,54 +9,89 @@ import (
 	"github.com/gruntwork-io/cloud-nuke/config"
 	"github.com/gruntwork-io/cloud-nuke/logging"
 	"github.com/gruntwork-io/cloud-nuke/report"
+	"github.com/gruntwork-io/cloud-nuke/resource"
 	"github.com/gruntwork-io/go-commons/errors"
 	"github.com/hashicorp/go-multierror"
 )
 
-func (gw *ApiGatewayV2) getAll(ctx context.Context, configObj config.Config) ([]*string, error) {
-	output, err := gw.Client.GetApis(gw.Context, &apigatewayv2.GetApisInput{})
+// ApiGatewayV2API defines the interface for API Gateway V2 operations.
+type ApiGatewayV2API interface {
+	GetApis(ctx context.Context, params *apigatewayv2.GetApisInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.GetApisOutput, error)
+	GetDomainNames(ctx context.Context, params *apigatewayv2.GetDomainNamesInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.GetDomainNamesOutput, error)
+	GetApiMappings(ctx context.Context, params *apigatewayv2.GetApiMappingsInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.GetApiMappingsOutput, error)
+	DeleteApi(ctx context.Context, params *apigatewayv2.DeleteApiInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.DeleteApiOutput, error)
+	DeleteApiMapping(ctx context.Context, params *apigatewayv2.DeleteApiMappingInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.DeleteApiMappingOutput, error)
+}
+
+// NewApiGatewayV2 creates a new ApiGatewayV2 resource using the generic resource pattern.
+func NewApiGatewayV2() AwsResource {
+	return NewAwsResource(&resource.Resource[ApiGatewayV2API]{
+		ResourceTypeName: "apigatewayv2",
+		BatchSize:        10,
+		InitClient: func(r *resource.Resource[ApiGatewayV2API], cfg any) {
+			awsCfg, ok := cfg.(aws.Config)
+			if !ok {
+				logging.Debugf("Invalid config type for ApiGatewayV2 client: expected aws.Config")
+				return
+			}
+			r.Scope.Region = awsCfg.Region
+			r.Client = apigatewayv2.NewFromConfig(awsCfg)
+		},
+		ConfigGetter: func(c config.Config) config.ResourceType {
+			return c.APIGatewayV2
+		},
+		Lister: listApiGatewaysV2,
+		Nuker:  deleteApiGatewaysV2,
+	})
+}
+
+// listApiGatewaysV2 retrieves all API Gateways V2 that match the config filters.
+func listApiGatewaysV2(ctx context.Context, client ApiGatewayV2API, scope resource.Scope, cfg config.ResourceType) ([]*string, error) {
+	output, err := client.GetApis(ctx, &apigatewayv2.GetApisInput{})
 	if err != nil {
 		return []*string{}, errors.WithStackTrace(err)
 	}
 
-	var Ids []*string
+	var ids []*string
 	for _, restapi := range output.Items {
-		if configObj.APIGatewayV2.ShouldInclude(config.ResourceValue{
+		if cfg.ShouldInclude(config.ResourceValue{
 			Time: restapi.CreatedDate,
 			Name: restapi.Name,
 			Tags: restapi.Tags,
 		}) {
-			Ids = append(Ids, restapi.ApiId)
+			ids = append(ids, restapi.ApiId)
 		}
 	}
 
-	return Ids, nil
+	return ids, nil
 }
 
-func (gw *ApiGatewayV2) nukeAll(identifiers []*string) error {
+// deleteApiGatewaysV2 is a custom nuker for API Gateway V2 resources.
+// It first deletes associated API mappings, then deletes the APIs themselves.
+func deleteApiGatewaysV2(ctx context.Context, client ApiGatewayV2API, scope resource.Scope, resourceType string, identifiers []*string) error {
 	if len(identifiers) == 0 {
-		logging.Debug(fmt.Sprintf("No API Gateways (v2) to nuke in region %s", gw.Region))
+		logging.Debugf("No API Gateways (v2) to nuke in %s", scope)
+		return nil
 	}
 
 	if len(identifiers) > 100 {
-		logging.Debug(fmt.Sprintf(
-			"Nuking too many API Gateways (v2) at once (100): halting to avoid hitting AWS API rate limiting"))
+		logging.Debugf("Nuking too many API Gateways (v2) at once (100): halting to avoid hitting AWS API rate limiting")
 		return TooManyApiGatewayV2Err{}
 	}
 
-	err := deleteAssociatedApiMappings(gw.Context, gw.Client, identifiers)
+	err := deleteAssociatedApiMappings(ctx, client, identifiers)
 	if err != nil {
 		return errors.WithStackTrace(err)
 	}
 
 	// There is no bulk delete Api Gateway API, so we delete the batch of gateways concurrently using goroutines
-	logging.Debug(fmt.Sprintf("Deleting Api Gateways (v2) in region %s", gw.Region))
+	logging.Debugf("Deleting Api Gateways (v2) in %s", scope)
 	wg := new(sync.WaitGroup)
 	wg.Add(len(identifiers))
 	errChans := make([]chan error, len(identifiers))
 	for i, apigwID := range identifiers {
 		errChans[i] = make(chan error, 1)
-		go gw.deleteAsync(wg, errChans[i], apigwID)
+		go deleteApiGatewayV2Async(ctx, client, scope, resourceType, wg, errChans[i], apigwID)
 	}
 	wg.Wait()
@@ -74,25 +108,25 @@ func (gw *ApiGatewayV2) nukeAll(identifiers []*string) error {
 	return nil
 }
 
-func (gw *ApiGatewayV2) deleteAsync(wg *sync.WaitGroup, errChan chan error, apiId *string) {
+func deleteApiGatewayV2Async(ctx context.Context, client ApiGatewayV2API, scope resource.Scope, resourceType string, wg *sync.WaitGroup, errChan chan error, apiId *string) {
 	defer wg.Done()
 
 	input := &apigatewayv2.DeleteApiInput{ApiId: apiId}
-	_, err := gw.Client.DeleteApi(gw.Context, input)
+	_, err := client.DeleteApi(ctx, input)
 	errChan <- err
 
 	// Record status of this resource
 	e := report.Entry{
 		Identifier:   *apiId,
-		ResourceType: "APIGateway (v2)",
+		ResourceType: resourceType,
 		Error:        err,
 	}
 	report.Record(e)
 
 	if err == nil {
-		logging.Debug(fmt.Sprintf("Successfully deleted API Gateway (v2) %s in %s", aws.ToString(apiId), gw.Region))
+		logging.Debugf("Successfully deleted API Gateway (v2) %s in %s", aws.ToString(apiId), scope)
 	} else {
-		logging.Debug(fmt.Sprintf("Failed to delete API Gateway (v2) %s in %s", aws.ToString(apiId), gw.Region))
+		logging.Debugf("Failed to delete API Gateway (v2) %s in %s", aws.ToString(apiId), scope)
 	}
 }
 
@@ -105,17 +139,17 @@ func deleteAssociatedApiMappings(ctx context.Context, client ApiGatewayV2API, id
 	domainNames, err := client.GetDomainNames(ctx, &apigatewayv2.GetDomainNamesInput{})
 	if err != nil {
-		logging.Debug(fmt.Sprintf("Failed to get domain names: %s", err))
+		logging.Debugf("Failed to get domain names: %s", err)
 		return errors.WithStackTrace(err)
 	}
 
-	logging.Debug(fmt.Sprintf("Found %d domain names", len(domainNames.Items)))
+	logging.Debugf("Found %d domain names", len(domainNames.Items))
 	for _, domainName := range domainNames.Items {
 		apiMappings, err := client.GetApiMappings(ctx, &apigatewayv2.GetApiMappingsInput{
 			DomainName: domainName.DomainName,
 		})
 		if err != nil {
-			logging.Debug(fmt.Sprintf("Failed to get api mappings: %s", err))
+			logging.Debugf("Failed to get api mappings: %s", err)
 			return errors.WithStackTrace(err)
 		}
@@ -129,13 +163,19 @@ func deleteAssociatedApiMappings(ctx context.Context, client ApiGatewayV2API, id
 				DomainName:   domainName.DomainName,
 			})
 			if err != nil {
-				logging.Debug(fmt.Sprintf("Failed to delete api mapping: %s", err))
+				logging.Debugf("Failed to delete api mapping: %s", err)
 				return errors.WithStackTrace(err)
 			}
 
-			logging.Debug(fmt.Sprintf("Deleted api mapping: %s", *apiMapping.ApiMappingId))
+			logging.Debugf("Deleted api mapping: %s", *apiMapping.ApiMappingId)
 		}
 	}
 
 	return nil
 }
+
+type TooManyApiGatewayV2Err struct{}
+
+func (err TooManyApiGatewayV2Err) Error() string {
+	return "Too many Api Gateways requested at once."
+}
diff --git a/aws/resources/apigatewayv2_test.go b/aws/resources/apigatewayv2_test.go
index 53e86363..4be1b14b 100644
--- a/aws/resources/apigatewayv2_test.go
+++ b/aws/resources/apigatewayv2_test.go
@@ -10,6 +10,7 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/apigatewayv2"
 	"github.com/aws/aws-sdk-go-v2/service/apigatewayv2/types"
 	"github.com/gruntwork-io/cloud-nuke/config"
+	"github.com/gruntwork-io/cloud-nuke/resource"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -48,71 +49,63 @@ func TestApiGatewayV2GetAll(t *testing.T) {
 	testApiID := "test-api-id"
 	testApiName := "test-api-name"
 	now := time.Now()
-	gw := ApiGatewayV2{
-		Client: mockedApiGatewayV2{
-			GetApisOutput: apigatewayv2.GetApisOutput{
-				Items: []types.Api{
-					{
-						ApiId:       aws.String(testApiID),
-						Name:        aws.String(testApiName),
-						CreatedDate: aws.Time(now),
-					},
+	client := mockedApiGatewayV2{
+		GetApisOutput: apigatewayv2.GetApisOutput{
+			Items: []types.Api{
+				{
+					ApiId:       aws.String(testApiID),
+					Name:        aws.String(testApiName),
+					CreatedDate: aws.Time(now),
 				},
 			},
 		},
 	}
 
 	// empty filter
-	apis, err := gw.getAll(context.Background(), config.Config{})
+	apis, err := listApiGatewaysV2(context.Background(), client, resource.Scope{Region: "us-east-1"}, config.ResourceType{})
 	assert.NoError(t, err)
 	assert.Contains(t, aws.ToStringSlice(apis), testApiID)
 
 	// filter by name
-	apis, err = gw.getAll(context.Background(), config.Config{
-		APIGatewayV2: config.ResourceType{
-			ExcludeRule: config.FilterRule{
-				NamesRegExp: []config.Expression{{
-					RE: *regexp.MustCompile("test-api-name"),
-				}}}}})
+	apis, err = listApiGatewaysV2(context.Background(), client, resource.Scope{Region: "us-east-1"}, config.ResourceType{
+		ExcludeRule: config.FilterRule{
+			NamesRegExp: []config.Expression{{
+				RE: *regexp.MustCompile("test-api-name"),
+			}}}})
 	assert.NoError(t, err)
 	assert.NotContains(t, aws.ToStringSlice(apis), testApiID)
 
 	// filter by date
-	apis, err = gw.getAll(context.Background(), config.Config{
-		APIGatewayV2: config.ResourceType{
-			ExcludeRule: config.FilterRule{
-				TimeAfter: aws.Time(now.Add(-1))}}})
+	apis, err = listApiGatewaysV2(context.Background(), client, resource.Scope{Region: "us-east-1"}, config.ResourceType{
+		ExcludeRule: config.FilterRule{
+			TimeAfter: aws.Time(now.Add(-1))}})
 	assert.NoError(t, err)
 	assert.NotContains(t, aws.ToStringSlice(apis), testApiID)
 
 	// filter by tags
-	gwWithTags := ApiGatewayV2{
-		Client: mockedApiGatewayV2{
-			GetApisOutput: apigatewayv2.GetApisOutput{
-				Items: []types.Api{{
-					ApiId:       aws.String(testApiID),
-					Name:        aws.String(testApiName),
-					CreatedDate: aws.Time(now),
-					Tags:        map[string]string{"Environment": "production"},
-				}},
-			},
+	clientWithTags := mockedApiGatewayV2{
+		GetApisOutput: apigatewayv2.GetApisOutput{
+			Items: []types.Api{{
+				ApiId:       aws.String(testApiID),
+				Name:        aws.String(testApiName),
+				CreatedDate: aws.Time(now),
+				Tags:        map[string]string{"Environment": "production"},
+			}},
 		},
 	}
 
-	apis, err = gwWithTags.getAll(context.Background(), config.Config{
-		APIGatewayV2: config.ResourceType{
-			IncludeRule: config.FilterRule{
-				Tags: map[string]config.Expression{
-					"Environment": {RE: *regexp.MustCompile("production")},
-				}}}})
+	apis, err = listApiGatewaysV2(context.Background(), clientWithTags, resource.Scope{Region: "us-east-1"}, config.ResourceType{
+		IncludeRule: config.FilterRule{
+			Tags: map[string]config.Expression{
+				"Environment": {RE: *regexp.MustCompile("production")},
+			}}})
 	assert.NoError(t, err)
 	assert.Contains(t, aws.ToStringSlice(apis), testApiID)
 
-	apis, err = gwWithTags.getAll(context.Background(), config.Config{
-		APIGatewayV2: config.ResourceType{
-			ExcludeRule: config.FilterRule{
-				Tags: map[string]config.Expression{
-					"Environment": {RE: *regexp.MustCompile("production")},
-				}}}})
+	apis, err = listApiGatewaysV2(context.Background(), clientWithTags, resource.Scope{Region: "us-east-1"}, config.ResourceType{
+		ExcludeRule: config.FilterRule{
+			Tags: map[string]config.Expression{
+				"Environment": {RE: *regexp.MustCompile("production")},
+			}}})
 	assert.NoError(t, err)
 	assert.NotContains(t, aws.ToStringSlice(apis), testApiID)
 }
@@ -120,26 +113,24 @@ func TestApiGatewayV2GetAll(t *testing.T) {
 func TestApiGatewayV2NukeAll(t *testing.T) {
 	t.Parallel()
 
-	gw := ApiGatewayV2{
-		Client: mockedApiGatewayV2{
-			DeleteApiOutput: apigatewayv2.DeleteApiOutput{},
-			GetDomainNamesOutput: apigatewayv2.GetDomainNamesOutput{
-				Items: []types.DomainName{
-					{
-						DomainName: aws.String("test-domain-name"),
-					},
+	client := mockedApiGatewayV2{
+		DeleteApiOutput: apigatewayv2.DeleteApiOutput{},
+		GetDomainNamesOutput: apigatewayv2.GetDomainNamesOutput{
+			Items: []types.DomainName{
+				{
+					DomainName: aws.String("test-domain-name"),
 				},
 			},
-			GetApisOutput: apigatewayv2.GetApisOutput{
-				Items: []types.Api{
-					{
-						ApiId: aws.String("test-api-id"),
-					},
+		},
+		GetApisOutput: apigatewayv2.GetApisOutput{
+			Items: []types.Api{
+				{
+					ApiId: aws.String("test-api-id"),
 				},
 			},
-			DeleteApiMappingOutput: apigatewayv2.DeleteApiMappingOutput{},
 		},
+		DeleteApiMappingOutput: apigatewayv2.DeleteApiMappingOutput{},
 	}
 
-	err := gw.nukeAll([]*string{aws.String("test-api-id")})
+	err := deleteApiGatewaysV2(context.Background(), client, resource.Scope{Region: "us-east-1"}, "apigatewayv2", []*string{aws.String("test-api-id")})
 	assert.NoError(t, err)
 }
diff --git a/aws/resources/apigatewayv2_types.go b/aws/resources/apigatewayv2_types.go
deleted file mode 100644
index 3adb54d8..00000000
--- a/aws/resources/apigatewayv2_types.go
+++ /dev/null
@@ -1,69 +0,0 @@
-package resources
-
-import (
-	"context"
-
-	"github.com/aws/aws-sdk-go-v2/aws"
-	"github.com/aws/aws-sdk-go-v2/service/apigatewayv2"
-	"github.com/gruntwork-io/cloud-nuke/config"
-	"github.com/gruntwork-io/go-commons/errors"
-)
-
-type ApiGatewayV2API interface {
-	GetApis(ctx context.Context, params *apigatewayv2.GetApisInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.GetApisOutput, error)
-	GetDomainNames(ctx context.Context, params *apigatewayv2.GetDomainNamesInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.GetDomainNamesOutput, error)
-	GetApiMappings(ctx context.Context, params *apigatewayv2.GetApiMappingsInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.GetApiMappingsOutput, error)
-	DeleteApi(ctx context.Context, params *apigatewayv2.DeleteApiInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.DeleteApiOutput, error)
-	DeleteApiMapping(ctx context.Context, params *apigatewayv2.DeleteApiMappingInput, optFns ...func(*apigatewayv2.Options)) (*apigatewayv2.DeleteApiMappingOutput, error)
-}
-
-type ApiGatewayV2 struct {
-	BaseAwsResource
-	Client ApiGatewayV2API
-	Region string
-	Ids    []string
-}
-
-func (gw *ApiGatewayV2) Init(cfg aws.Config) {
-	gw.Client = apigatewayv2.NewFromConfig(cfg)
-}
-
-func (gw *ApiGatewayV2) ResourceName() string {
-	return "apigatewayv2"
-}
-
-func (gw *ApiGatewayV2) ResourceIdentifiers() []string {
-	return gw.Ids
-}
-
-func (gw *ApiGatewayV2) MaxBatchSize() int {
-	return 10
-}
-
-func (gw *ApiGatewayV2) GetAndSetResourceConfig(configObj config.Config) config.ResourceType {
-	return configObj.APIGatewayV2
-}
-
-func (gw *ApiGatewayV2) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
-	identifiers, err := gw.getAll(c, configObj)
-	if err != nil {
-		return nil, err
-	}
-
-	gw.Ids = aws.ToStringSlice(identifiers)
-	return gw.Ids, nil
-}
-
-func (gw *ApiGatewayV2) Nuke(ctx context.Context, identifiers []string) error {
-	if err := gw.nukeAll(aws.StringSlice(identifiers)); err != nil {
-		return errors.WithStackTrace(err)
-	}
-
-	return nil
-}
-
-type TooManyApiGatewayV2Err struct{}
-
-func (err TooManyApiGatewayV2Err) Error() string {
-	return "Too many Api Gateways requested at once."
-}
diff --git a/aws/resources/asg.go b/aws/resources/asg.go
index ce65408c..58797829 100644
--- a/aws/resources/asg.go
+++ b/aws/resources/asg.go
@@ -2,83 +2,76 @@ package resources
 
 import (
 	"context"
+	"time"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/autoscaling"
 	"github.com/gruntwork-io/cloud-nuke/config"
-	"github.com/gruntwork-io/cloud-nuke/logging"
-	"github.com/gruntwork-io/cloud-nuke/report"
+	"github.com/gruntwork-io/cloud-nuke/resource"
 	"github.com/gruntwork-io/cloud-nuke/util"
 	"github.com/gruntwork-io/go-commons/errors"
 )
 
-// Returns a formatted string of ASG Names
-func (ag *ASGroups) getAll(c context.Context, configObj config.Config) ([]*string, error) {
-	result, err := ag.Client.DescribeAutoScalingGroups(ag.Context, &autoscaling.DescribeAutoScalingGroupsInput{})
-	if err != nil {
-		return nil, errors.WithStackTrace(err)
-	}
-
-	var groupNames []*string
-	for _, group := range result.AutoScalingGroups {
-		if configObj.AutoScalingGroup.ShouldInclude(config.ResourceValue{
-			Time: group.CreatedTime,
-			Name: group.AutoScalingGroupName,
-			Tags: util.ConvertAutoScalingTagsToMap(group.Tags),
-		}) {
-			groupNames = append(groupNames, group.AutoScalingGroupName)
-		}
-	}
-
-	return groupNames, nil
+// ASGroupsAPI defines the interface for Auto Scaling Group operations.
+type ASGroupsAPI interface {
+	DescribeAutoScalingGroups(ctx context.Context, params *autoscaling.DescribeAutoScalingGroupsInput, optFns ...func(*autoscaling.Options)) (*autoscaling.DescribeAutoScalingGroupsOutput, error)
+	DeleteAutoScalingGroup(ctx context.Context, params *autoscaling.DeleteAutoScalingGroupInput, optFns ...func(*autoscaling.Options)) (*autoscaling.DeleteAutoScalingGroupOutput, error)
 }
 
-// Deletes all Auto Scaling Groups
-func (ag *ASGroups) nukeAll(groupNames []*string) error {
-	if len(groupNames) == 0 {
-		logging.Debugf("No Auto Scaling Groups to nuke in region %s", ag.Region)
-		return nil
-	}
-
-	logging.Debugf("Deleting all Auto Scaling Groups in region %s", ag.Region)
-	var deletedGroupNames []string
-
-	for _, groupName := range groupNames {
-		params := &autoscaling.DeleteAutoScalingGroupInput{
-			AutoScalingGroupName: groupName,
-			ForceDelete:          aws.Bool(true),
-		}
+// NewASGroups creates a new ASGroups resource using the generic resource pattern.
+func NewASGroups() AwsResource {
+	return NewAwsResource(&resource.Resource[ASGroupsAPI]{
+		ResourceTypeName: "asg",
+		BatchSize:        49,
+		InitClient: WrapAwsInitClient(func(r *resource.Resource[ASGroupsAPI], cfg aws.Config) {
+			r.Client = autoscaling.NewFromConfig(cfg)
+		}),
+		ConfigGetter: func(c config.Config) config.ResourceType {
+			return c.AutoScalingGroup
+		},
+		Lister: listASGroups,
+		Nuker:  resource.SequentialDeleteThenWaitAll(deleteASG, waitForASGsDeleted),
+	})
+}
 
-		_, err := ag.Client.DeleteAutoScalingGroup(ag.Context, params)
+// listASGroups retrieves all Auto Scaling Groups that match the config filters.
+func listASGroups(ctx context.Context, client ASGroupsAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) {
+	var groupNames []*string
+	paginator := autoscaling.NewDescribeAutoScalingGroupsPaginator(client, &autoscaling.DescribeAutoScalingGroupsInput{})
 
-		// Record status of this resource
-		e := report.Entry{
-			Identifier:   *groupName,
-			ResourceType: "Auto-Scaling Group",
-			Error:        err,
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
+		if err != nil {
+			return nil, errors.WithStackTrace(err)
 		}
-		report.Record(e)
 
-		if err != nil {
-			logging.Debugf("[Failed] %s", err)
-		} else {
-			deletedGroupNames = append(deletedGroupNames, *groupName)
-			logging.Debugf("Deleted Auto Scaling Group: %s", *groupName)
+		for _, group := range page.AutoScalingGroups {
+			if cfg.ShouldInclude(config.ResourceValue{
+				Time: group.CreatedTime,
+				Name: group.AutoScalingGroupName,
+				Tags: util.ConvertAutoScalingTagsToMap(group.Tags),
+			}) {
+				groupNames = append(groupNames, group.AutoScalingGroupName)
+			}
 		}
 	}
 
-	if len(deletedGroupNames) > 0 {
-		waiter := autoscaling.NewGroupNotExistsWaiter(ag.Client)
-		err := waiter.Wait(ag.Context, &autoscaling.DescribeAutoScalingGroupsInput{
-			AutoScalingGroupNames: deletedGroupNames,
-		}, ag.Timeout)
+	return groupNames, nil
+}
 
-		if err != nil {
-			logging.Errorf("[Failed] %s", err)
-			return errors.WithStackTrace(err)
-		}
-	}
+// deleteASG deletes a single Auto Scaling Group by name.
+func deleteASG(ctx context.Context, client ASGroupsAPI, name *string) error {
+	_, err := client.DeleteAutoScalingGroup(ctx, &autoscaling.DeleteAutoScalingGroupInput{
+		AutoScalingGroupName: name,
+		ForceDelete:          aws.Bool(true),
+	})
+	return errors.WithStackTrace(err)
+}
 
-	logging.Debugf("[OK] %d Auto Scaling Group(s) deleted in %s", len(deletedGroupNames), ag.Region)
-	return nil
+// waitForASGsDeleted waits for all specified Auto Scaling Groups to be deleted.
+func waitForASGsDeleted(ctx context.Context, client ASGroupsAPI, names []string) error {
+	waiter := autoscaling.NewGroupNotExistsWaiter(client)
+	return waiter.Wait(ctx, &autoscaling.DescribeAutoScalingGroupsInput{
+		AutoScalingGroupNames: names,
+	}, 5*time.Minute)
 }
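The ASG nuker above is composed from two small functions via `resource.SequentialDeleteThenWaitAll`, which is not part of this diff. A minimal sketch of the behavior that combinator is assumed to provide (per-item reporting and error aggregation, which the real helper likely also does, are omitted here):

// Sketch only: assumed behavior of resource.SequentialDeleteThenWaitAll,
// inferred from how asg.go wires deleteASG and waitForASGsDeleted together.
// Uses the Scope type sketched earlier.
func SequentialDeleteThenWaitAll[ClientT any](
	deleteOne func(ctx context.Context, client ClientT, id *string) error,
	waitAll func(ctx context.Context, client ClientT, ids []string) error,
) func(ctx context.Context, client ClientT, scope Scope, resourceType string, identifiers []*string) error {
	return func(ctx context.Context, client ClientT, scope Scope, resourceType string, identifiers []*string) error {
		var deleted []string
		for _, id := range identifiers {
			// Delete each resource in turn; stop on the first hard failure.
			if err := deleteOne(ctx, client, id); err != nil {
				return err
			}
			deleted = append(deleted, *id)
		}
		if len(deleted) == 0 {
			return nil
		}
		// Block until every successfully deleted resource is actually gone.
		return waitAll(ctx, client, deleted)
	}
}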
diff --git a/aws/resources/asg_test.go b/aws/resources/asg_test.go
index 6721b962..5326ceda 100644
--- a/aws/resources/asg_test.go
+++ b/aws/resources/asg_test.go
@@ -10,7 +10,8 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/autoscaling"
 	"github.com/aws/aws-sdk-go-v2/service/autoscaling/types"
 	"github.com/gruntwork-io/cloud-nuke/config"
-	"github.com/stretchr/testify/assert"
+	"github.com/gruntwork-io/cloud-nuke/resource"
+	"github.com/stretchr/testify/require"
 )
 
 type mockedASGroups struct {
@@ -27,57 +28,88 @@ func (m mockedASGroups) DeleteAutoScalingGroup(ctx context.Context, params *auto
 	return &m.DeleteAutoScalingGroupOutput, nil
 }
 
-func TestAutoScalingGroupGetAll(t *testing.T) {
+func TestASGroups_GetAll(t *testing.T) {
 	t.Parallel()
 
-	testName := "cloud-nuke-test"
+	testName1 := "test-asg-1"
+	testName2 := "test-asg-2"
 	now := time.Now()
-	ag := ASGroups{
-		Client: mockedASGroups{
-			DescribeAutoScalingGroupsOutput: autoscaling.DescribeAutoScalingGroupsOutput{
-				AutoScalingGroups: []types.AutoScalingGroup{{
-					AutoScalingGroupName: aws.String(testName),
-					CreatedTime:          aws.Time(now),
-				}}}}}
-	// empty filter
-	groups, err := ag.getAll(context.Background(), config.Config{})
-	assert.NoError(t, err)
-	assert.Contains(t, aws.ToStringSlice(groups), testName)
+	mock := mockedASGroups{
+		DescribeAutoScalingGroupsOutput: autoscaling.DescribeAutoScalingGroupsOutput{
+			AutoScalingGroups: []types.AutoScalingGroup{
+				{
+					AutoScalingGroupName: aws.String(testName1),
+					CreatedTime:          aws.Time(now),
+					Tags: []types.TagDescription{
+						{Key: aws.String("env"), Value: aws.String("dev")},
+					},
+				},
+				{
+					AutoScalingGroupName: aws.String(testName2),
+					CreatedTime:          aws.Time(now.Add(1 * time.Hour)),
+					Tags: []types.TagDescription{
+						{Key: aws.String("env"), Value: aws.String("prod")},
+					},
+				},
+			},
+		},
+	}
 
-	// name filter
-	groups, err = ag.getAll(context.Background(), config.Config{
-		AutoScalingGroup: config.ResourceType{
-			ExcludeRule: config.FilterRule{
-				NamesRegExp: []config.Expression{{
-					RE: *regexp.MustCompile("^cloud-nuke-*"),
-				}}}}})
-	assert.NoError(t, err)
-	assert.NotContains(t, aws.ToStringSlice(groups), testName)
+	tests := map[string]struct {
+		configObj config.ResourceType
+		expected  []string
+	}{
+		"emptyFilter": {
+			configObj: config.ResourceType{},
+			expected:  []string{testName1, testName2},
+		},
+		"nameExclusionFilter": {
+			configObj: config.ResourceType{
+				ExcludeRule: config.FilterRule{
+					NamesRegExp: []config.Expression{{
+						RE: *regexp.MustCompile("test-asg-1"),
+					}},
+				},
+			},
+			expected: []string{testName2},
+		},
+		"timeAfterExclusionFilter": {
+			configObj: config.ResourceType{
+				ExcludeRule: config.FilterRule{
+					TimeAfter: aws.Time(now.Add(30 * time.Minute)),
+				},
+			},
+			expected: []string{testName1},
+		},
+		"tagExclusionFilter": {
+			configObj: config.ResourceType{
+				ExcludeRule: config.FilterRule{
+					Tags: map[string]config.Expression{
+						"env": {RE: *regexp.MustCompile("prod")},
+					},
+				},
+			},
+			expected: []string{testName1},
+		},
+	}
 
-	// time filter
-	groups, err = ag.getAll(context.Background(), config.Config{
-		AutoScalingGroup: config.ResourceType{
-			ExcludeRule: config.FilterRule{
-				TimeAfter: aws.Time(now.Add(-1)),
-			}}})
-	assert.NoError(t, err)
-	assert.NotContains(t, aws.ToStringSlice(groups), testName)
+	for name, tc := range tests {
+		t.Run(name, func(t *testing.T) {
+			names, err := listASGroups(context.Background(), mock, resource.Scope{}, tc.configObj)
+			require.NoError(t, err)
+			require.Equal(t, tc.expected, aws.ToStringSlice(names))
+		})
+	}
 }
 
-func TestAutoScalingGroupNukeAll(t *testing.T) {
+func TestASGroups_NukeAll(t *testing.T) {
 	t.Parallel()
 
-	ag := ASGroups{
-		BaseAwsResource: BaseAwsResource{
-			Context: context.Background(),
-			Timeout: DefaultWaitTimeout,
-		},
-		Client: mockedASGroups{
-			DeleteAutoScalingGroupOutput: autoscaling.DeleteAutoScalingGroupOutput{},
-		},
+	mock := mockedASGroups{
+		DeleteAutoScalingGroupOutput: autoscaling.DeleteAutoScalingGroupOutput{},
 	}
 
-	err := ag.nukeAll([]*string{aws.String("cloud-nuke-test")})
-	assert.NoError(t, err)
+	err := deleteASG(context.Background(), mock, aws.String("test-asg"))
+	require.NoError(t, err)
 }
diff --git a/aws/resources/asg_types.go b/aws/resources/asg_types.go
deleted file mode 100644
index 2782a599..00000000
--- a/aws/resources/asg_types.go
+++ /dev/null
@@ -1,64 +0,0 @@
-package resources
-
-import (
-	"context"
-
-	"github.com/aws/aws-sdk-go-v2/aws"
-	"github.com/aws/aws-sdk-go-v2/service/autoscaling"
-	"github.com/gruntwork-io/cloud-nuke/config"
-	"github.com/gruntwork-io/go-commons/errors"
-)
-
-type ASGroupsAPI interface {
-	DescribeAutoScalingGroups(ctx context.Context, params *autoscaling.DescribeAutoScalingGroupsInput, optFns ...func(*autoscaling.Options)) (*autoscaling.DescribeAutoScalingGroupsOutput, error)
-	DeleteAutoScalingGroup(ctx context.Context, params *autoscaling.DeleteAutoScalingGroupInput, optFns ...func(*autoscaling.Options)) (*autoscaling.DeleteAutoScalingGroupOutput, error)
-}
-
-// ASGroups - represents all auto-scaling groups
-type ASGroups struct {
-	BaseAwsResource
-	Client     ASGroupsAPI
-	Region     string
-	GroupNames []string
-}
-
-func (ag *ASGroups) Init(cfg aws.Config) {
-	ag.Client = autoscaling.NewFromConfig(cfg)
-}
-
-// ResourceName - the simple name of the aws resource
-func (ag *ASGroups) ResourceName() string {
-	return "asg"
-}
-
-func (ag *ASGroups) MaxBatchSize() int {
-	// Tentative batch size to ensure AWS doesn't throttle
-	return 49
-}
-
-// ResourceIdentifiers - The group names of the auto-scaling groups
-func (ag *ASGroups) ResourceIdentifiers() []string {
-	return ag.GroupNames
-}
-
-func (ag *ASGroups) GetAndSetResourceConfig(configObj config.Config) config.ResourceType {
-	return configObj.AutoScalingGroup
-}
-func (ag *ASGroups) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
-	identifiers, err := ag.getAll(c, configObj)
-	if err != nil {
-		return nil, err
-	}
-
-	ag.GroupNames = aws.ToStringSlice(identifiers)
-	return ag.GroupNames, nil
-}
-
-// Nuke - nuke 'em all!!!
-func (ag *ASGroups) Nuke(ctx context.Context, identifiers []string) error {
-	if err := ag.nukeAll(aws.StringSlice(identifiers)); err != nil {
-		return errors.WithStackTrace(err)
-	}
-
-	return nil
-}
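The next file's nuker is built with `resource.SimpleBatchDeleter(deleteCodeDeployApplication)`, another helper from the `resource` package that this diff does not show. A rough sketch of the shape it is assumed to have, based only on that call site (the real helper presumably also records report entries and may delete concurrently):

// Sketch only: assumed behavior of resource.SimpleBatchDeleter, inferred
// from its use in codedeploy_application.go. Uses the Scope type sketched earlier.
func SimpleBatchDeleter[ClientT any](
	deleteOne func(ctx context.Context, client ClientT, id *string) error,
) func(ctx context.Context, client ClientT, scope Scope, resourceType string, identifiers []*string) error {
	return func(ctx context.Context, client ClientT, scope Scope, resourceType string, identifiers []*string) error {
		var errs []error
		for _, id := range identifiers {
			// Attempt every deletion and aggregate failures rather than
			// stopping at the first error.
			if err := deleteOne(ctx, client, id); err != nil {
				errs = append(errs, err)
			}
		}
		return errors.Join(errs...)
	}
}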
diff --git a/aws/resources/codedeploy_application.go b/aws/resources/codedeploy_application.go
index 5a086d60..7bcafec3 100644
--- a/aws/resources/codedeploy_application.go
+++ b/aws/resources/codedeploy_application.go
@@ -2,23 +2,51 @@ package resources
 
 import (
 	"context"
-	"sync"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/codedeploy"
 	"github.com/gruntwork-io/cloud-nuke/config"
 	"github.com/gruntwork-io/cloud-nuke/logging"
-	"github.com/gruntwork-io/cloud-nuke/report"
+	"github.com/gruntwork-io/cloud-nuke/resource"
 	"github.com/gruntwork-io/go-commons/errors"
-	"github.com/hashicorp/go-multierror"
 )
 
-func (cda *CodeDeployApplications) getAll(c context.Context, configObj config.Config) ([]*string, error) {
+// CodeDeployApplicationsAPI defines the interface for CodeDeploy operations.
+type CodeDeployApplicationsAPI interface {
+	ListApplications(ctx context.Context, params *codedeploy.ListApplicationsInput, optFns ...func(*codedeploy.Options)) (*codedeploy.ListApplicationsOutput, error)
+	BatchGetApplications(ctx context.Context, params *codedeploy.BatchGetApplicationsInput, optFns ...func(*codedeploy.Options)) (*codedeploy.BatchGetApplicationsOutput, error)
+	DeleteApplication(ctx context.Context, params *codedeploy.DeleteApplicationInput, optFns ...func(*codedeploy.Options)) (*codedeploy.DeleteApplicationOutput, error)
+}
+
+// NewCodeDeployApplications creates a new CodeDeployApplications resource using the generic resource pattern.
+func NewCodeDeployApplications() AwsResource {
+	return NewAwsResource(&resource.Resource[CodeDeployApplicationsAPI]{
+		ResourceTypeName: "codedeploy-application",
+		BatchSize:        100,
+		InitClient: func(r *resource.Resource[CodeDeployApplicationsAPI], cfg any) {
+			awsCfg, ok := cfg.(aws.Config)
+			if !ok {
+				logging.Debugf("Invalid config type for CodeDeployApplications client: expected aws.Config")
+				return
+			}
+			r.Scope.Region = awsCfg.Region
+			r.Client = codedeploy.NewFromConfig(awsCfg)
+		},
+		ConfigGetter: func(c config.Config) config.ResourceType {
+			return c.CodeDeployApplications
+		},
+		Lister: listCodeDeployApplications,
+		Nuker:  resource.SimpleBatchDeleter(deleteCodeDeployApplication),
+	})
+}
+
+// listCodeDeployApplications retrieves all CodeDeploy applications that match the config filters.
+func listCodeDeployApplications(ctx context.Context, client CodeDeployApplicationsAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) {
 	var codeDeployApplicationsFilteredByName []string
 
-	paginator := codedeploy.NewListApplicationsPaginator(cda.Client, &codedeploy.ListApplicationsInput{})
+	paginator := codedeploy.NewListApplicationsPaginator(client, &codedeploy.ListApplicationsInput{})
 	for paginator.HasMorePages() {
-		page, err := paginator.NextPage(c)
+		page, err := paginator.NextPage(ctx)
 		if err != nil {
 			return nil, errors.WithStackTrace(err)
 		}
@@ -27,7 +55,7 @@ func (cda *CodeDeployApplications) getAll(c context.Context, configObj config.Co
 			// Check if the CodeDeploy Application should be excluded by name as that information is available to us here.
 			// CreationDate is not available in the ListApplications API call, so we can't filter by that here, but we do filter by it later.
 			// By filtering the name here, we can reduce the number of BatchGetApplication API calls we have to make.
-			if configObj.CodeDeployApplications.ShouldInclude(config.ResourceValue{Name: aws.String(application)}) {
+			if cfg.ShouldInclude(config.ResourceValue{Name: aws.String(application)}) {
 				codeDeployApplicationsFilteredByName = append(codeDeployApplicationsFilteredByName, application)
 			}
 		}
@@ -35,11 +63,11 @@ func (cda *CodeDeployApplications) getAll(c context.Context, configObj config.Co
 
 	// Check if the CodeDeploy Application should be excluded by CreationDate and return.
 	// We have to do this after the ListApplicationsPages API call because CreationDate is not available in that call.
-	return cda.batchDescribeAndFilter(codeDeployApplicationsFilteredByName, configObj)
+	return batchDescribeAndFilterCodeDeployApplications(ctx, client, codeDeployApplicationsFilteredByName, cfg)
 }
 
 // batchDescribeAndFilterCodeDeployApplications - Describe the CodeDeploy Applications and filter out the ones that should be excluded by CreationDate.
-func (cda *CodeDeployApplications) batchDescribeAndFilter(identifiers []string, configObj config.Config) ([]*string, error) {
+func batchDescribeAndFilterCodeDeployApplications(ctx context.Context, client CodeDeployApplicationsAPI, identifiers []string, cfg config.ResourceType) ([]*string, error) {
 	// BatchGetApplications can only take 100 identifiers at a time, so we have to break up the identifiers into chunks of 100.
 	batchSize := 100
 	var applicationNames []*string
@@ -58,8 +86,8 @@ func (cda *CodeDeployApplications) batchDescribeAndFilter(identifiers []string,
 		// get the next batch of identifiers
 		batch := identifiers[:batchSize]
 		// then using that batch of identifiers, get the applicationsinfo
-		resp, err := cda.Client.BatchGetApplications(
-			cda.Context,
+		resp, err := client.BatchGetApplications(
+			ctx,
 			&codedeploy.BatchGetApplicationsInput{ApplicationNames: batch},
 		)
 		if err != nil {
@@ -68,7 +96,7 @@ func (cda *CodeDeployApplications) batchDescribeAndFilter(identifiers []string,
 
 		// for each applicationsinfo, check if it should be excluded by creation date
 		for j := range resp.ApplicationsInfo {
-			if configObj.CodeDeployApplications.ShouldInclude(config.ResourceValue{
+			if cfg.ShouldInclude(config.ResourceValue{
 				Time: resp.ApplicationsInfo[j].CreateTime,
 			}) {
 				applicationNames = append(applicationNames, resp.ApplicationsInfo[j].ApplicationName)
@@ -82,59 +110,8 @@ func (cda *CodeDeployApplications) batchDescribeAndFilter(identifiers []string,
 	return applicationNames, nil
 }
 
-func (cda *CodeDeployApplications) nukeAll(identifiers []string) error {
-	if len(identifiers) == 0 {
-		logging.Debugf("No CodeDeploy Applications to nuke in region %s", cda.Region)
-		return nil
-	}
-
-	logging.Infof("Deleting CodeDeploy Applications in region %s", cda.Region)
-
-	var wg sync.WaitGroup
-	errChan := make(chan error, len(identifiers))
-
-	for _, identifier := range identifiers {
-		wg.Add(1)
-		go cda.deleteAsync(&wg, errChan, identifier)
-	}
-
-	wg.Wait()
-	close(errChan)
-
-	var allErrors *multierror.Error
-	for err := range errChan {
-		allErrors = multierror.Append(allErrors, err)
-
-		logging.Errorf("[Failed] Error deleting CodeDeploy Application: %s", err)
-	}
-
-	finalErr := allErrors.ErrorOrNil()
-	if finalErr != nil {
-		return errors.WithStackTrace(finalErr)
-	}
-
-	return nil
-}
-
-func (cda *CodeDeployApplications) deleteAsync(wg *sync.WaitGroup, errChan chan<- error, identifier string) {
-	defer wg.Done()
-
-	_, err := cda.Client.DeleteApplication(cda.Context, &codedeploy.DeleteApplicationInput{ApplicationName: &identifier})
-	if err != nil {
-		errChan <- err
-	}
-
-	// record the status of the nuke attempt
-	e := report.Entry{
-		Identifier:   identifier,
-		ResourceType: "CodeDeploy Application",
-		Error:        err,
-	}
-	report.Record(e)
-
-	if err == nil {
-		logging.Debugf("[OK] Deleted CodeDeploy Application: %s", identifier)
-	} else {
-		logging.Debugf("[Failed] Error deleting CodeDeploy Application %s: %s", identifier, err)
-	}
+// deleteCodeDeployApplication deletes a single CodeDeploy Application.
+func deleteCodeDeployApplication(ctx context.Context, client CodeDeployApplicationsAPI, identifier *string) error {
+	_, err := client.DeleteApplication(ctx, &codedeploy.DeleteApplicationInput{ApplicationName: identifier})
+	return err
 }
diff --git a/aws/resources/codedeploy_application_test.go b/aws/resources/codedeploy_application_test.go
index ad4c4212..836bcd43 100644
--- a/aws/resources/codedeploy_application_test.go
+++ b/aws/resources/codedeploy_application_test.go
@@ -10,6 +10,7 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/codedeploy"
 	"github.com/aws/aws-sdk-go-v2/service/codedeploy/types"
 	"github.com/gruntwork-io/cloud-nuke/config"
+	"github.com/gruntwork-io/cloud-nuke/resource"
 	"github.com/stretchr/testify/require"
 )
 
@@ -53,24 +54,22 @@ func TestCodeDeployApplication_GetAll(t *testing.T) {
 	testName1 := "cloud-nuke-test-1"
 	testName2 := "cloud-nuke-test-2"
 	now := time.Now()
-	c := CodeDeployApplications{
-		Client: mockedCodeDeployApplications{
-			ListApplicationsOutput: codedeploy.ListApplicationsOutput{
-				Applications: []string{
-					testName1,
-					testName2,
-				},
+	client := mockedCodeDeployApplications{
+		ListApplicationsOutput: codedeploy.ListApplicationsOutput{
+			Applications: []string{
+				testName1,
+				testName2,
 			},
-			BatchGetApplicationsOutput: codedeploy.BatchGetApplicationsOutput{
-				ApplicationsInfo: []types.ApplicationInfo{
-					{
-						ApplicationName: aws.String(testName1),
-						CreateTime:      aws.Time(now),
-					},
-					{
-						ApplicationName: aws.String(testName2),
-						CreateTime:      aws.Time(now.Add(1)),
-					},
+		},
+		BatchGetApplicationsOutput: codedeploy.BatchGetApplicationsOutput{
+			ApplicationsInfo: []types.ApplicationInfo{
+				{
+					ApplicationName: aws.String(testName1),
+					CreateTime:      aws.Time(now),
+				},
+				{
+					ApplicationName: aws.String(testName2),
+					CreateTime:      aws.Time(now.Add(1)),
 				},
 			},
 		},
@@ -103,9 +102,7 @@ func TestCodeDeployApplication_GetAll(t *testing.T) {
 	}
 	for name, tc := range tests {
 		t.Run(name, func(t *testing.T) {
-			names, err := c.getAll(context.Background(), config.Config{
-				CodeDeployApplications: tc.configObj,
-			})
+			names, err := listCodeDeployApplications(context.Background(), client, resource.Scope{Region: "us-east-1"}, tc.configObj)
 			require.NoError(t, err)
 			require.Equal(t, tc.expected, aws.ToStringSlice(names))
 		})
@@ -114,12 +111,10 @@ func TestCodeDeployApplication_GetAll(t *testing.T) {
 }
 
 func TestCodeDeployApplication_NukeAll(t *testing.T) {
-	c := CodeDeployApplications{
-		Client: mockedCodeDeployApplications{
-			DeleteApplicationOutput: codedeploy.DeleteApplicationOutput{},
-		},
+	client := mockedCodeDeployApplications{
+		DeleteApplicationOutput: codedeploy.DeleteApplicationOutput{},
 	}
 
-	err := c.nukeAll([]string{"test"})
+	err := deleteCodeDeployApplication(context.Background(), client, aws.String("test"))
 	require.NoError(t, err)
 }
diff --git a/aws/resources/codedeploy_application_types.go b/aws/resources/codedeploy_application_types.go
deleted file mode 100644
index 57265744..00000000
--- a/aws/resources/codedeploy_application_types.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package resources
-
-import (
-	"context"
-
-	"github.com/aws/aws-sdk-go-v2/aws"
-	"github.com/aws/aws-sdk-go-v2/service/codedeploy"
-	"github.com/gruntwork-io/cloud-nuke/config"
-	"github.com/gruntwork-io/go-commons/errors"
-)
-
-type CodeDeployApplicationsAPI interface {
-	ListApplications(ctx context.Context, params *codedeploy.ListApplicationsInput, optFns ...func(*codedeploy.Options)) (*codedeploy.ListApplicationsOutput, error)
-	BatchGetApplications(ctx context.Context, params *codedeploy.BatchGetApplicationsInput, optFns ...func(*codedeploy.Options)) (*codedeploy.BatchGetApplicationsOutput, error)
-	DeleteApplication(ctx context.Context, params *codedeploy.DeleteApplicationInput, optFns ...func(*codedeploy.Options)) (*codedeploy.DeleteApplicationOutput, error)
-}
-
-// CodeDeployApplications - represents all codedeploy applications
-type CodeDeployApplications struct {
-	BaseAwsResource
-	Client   CodeDeployApplicationsAPI
-	Region   string
-	AppNames []string
-}
-
-func (cda *CodeDeployApplications) Init(cfg aws.Config) {
-	cda.Client = codedeploy.NewFromConfig(cfg)
-}
-
-// ResourceName - the simple name of the aws resource
-func (cda *CodeDeployApplications) ResourceName() string {
-	return "codedeploy-application"
-}
-
-// ResourceIdentifiers - The instance ids of the code deploy applications
-func (cda *CodeDeployApplications) ResourceIdentifiers() []string {
-	return cda.AppNames
-}
-
-func (cda *CodeDeployApplications) MaxBatchSize() int {
-	// Tentative batch size to ensure AWS doesn't throttle.
-	return 100
-}
-
-func (cda *CodeDeployApplications) GetAndSetResourceConfig(configObj config.Config) config.ResourceType {
-	return configObj.CodeDeployApplications
-}
-
-func (cda *CodeDeployApplications) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
-	identifiers, err := cda.getAll(c, configObj)
-	if err != nil {
-		return nil, err
-	}
-
-	cda.AppNames = aws.ToStringSlice(identifiers)
-	return cda.AppNames, nil
-}
-
-// Nuke - nuke 'em all!!!
-func (cda *CodeDeployApplications) Nuke(ctx context.Context, identifiers []string) error {
-	if err := cda.nukeAll(identifiers); err != nil {
-		return errors.WithStackTrace(err)
-	}
-
-	return nil
-}
diff --git a/aws/resources/ebs.go b/aws/resources/ebs.go
index ea52ecd8..0db796ff 100644
--- a/aws/resources/ebs.go
+++ b/aws/resources/ebs.go
@@ -2,54 +2,72 @@ package resources
 
 import (
 	"context"
-	goerr "errors"
+	"time"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/ec2"
 	"github.com/aws/aws-sdk-go-v2/service/ec2/types"
-	"github.com/aws/smithy-go"
 	"github.com/gruntwork-io/cloud-nuke/config"
 	"github.com/gruntwork-io/cloud-nuke/logging"
-	"github.com/gruntwork-io/cloud-nuke/report"
+	"github.com/gruntwork-io/cloud-nuke/resource"
 	"github.com/gruntwork-io/cloud-nuke/util"
-	"github.com/gruntwork-io/go-commons/errors"
 )
 
-// Returns a formatted string of EBS volume ids
-func (ev *EBSVolumes) getAll(c context.Context, configObj config.Config) ([]*string, error) {
+// EBSVolumesAPI defines the interface for EBS Volume operations.
+type EBSVolumesAPI interface {
+	DescribeVolumes(ctx context.Context, params *ec2.DescribeVolumesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVolumesOutput, error)
+	DeleteVolume(ctx context.Context, params *ec2.DeleteVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVolumeOutput, error)
+}
+
+// NewEBSVolumes creates a new EBS Volumes resource using the generic resource pattern.
+func NewEBSVolumes() AwsResource { + return NewAwsResource(&resource.Resource[EBSVolumesAPI]{ + ResourceTypeName: "ebs", + BatchSize: 49, + InitClient: func(r *resource.Resource[EBSVolumesAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for EC2 client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = ec2.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.EBSVolume + }, + Lister: listEBSVolumes, + Nuker: resource.SequentialDeleteThenWaitAll(deleteEBSVolume, waitForEBSVolumesDeleted), + PermissionVerifier: verifyEBSVolumePermission, + }) +} + +// listEBSVolumes retrieves all EBS volumes that match the config filters. +func listEBSVolumes(ctx context.Context, client EBSVolumesAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { // Available statuses: (creating | available | in-use | deleting | deleted | error). // Since the output of this function is used to delete the returned volumes // We want to only list EBS volumes with a status of "available" or "creating" // Since those are the only statuses that are eligible for deletion statusFilter := types.Filter{Name: aws.String("status"), Values: []string{"available", "creating", "error"}} - result, err := ev.Client.DescribeVolumes(ev.Context, &ec2.DescribeVolumesInput{ + result, err := client.DescribeVolumes(ctx, &ec2.DescribeVolumesInput{ Filters: []types.Filter{statusFilter}, }) if err != nil { - return nil, errors.WithStackTrace(err) + return nil, err } var volumeIds []*string for _, volume := range result.Volumes { - if shouldIncludeEBSVolume(volume, configObj) { + if shouldIncludeEBSVolume(volume, cfg) { volumeIds = append(volumeIds, volume.VolumeId) } } - // checking the nukable permissions - ev.VerifyNukablePermissions(volumeIds, func(id *string) error { - _, err := ev.Client.DeleteVolume(ev.Context, &ec2.DeleteVolumeInput{ - VolumeId: id, - DryRun: aws.Bool(true), - }) - return err - }) - return volumeIds, nil } -func shouldIncludeEBSVolume(volume types.Volume, configObj config.Config) bool { +func shouldIncludeEBSVolume(volume types.Volume, cfg config.ResourceType) bool { name := "" for _, tag := range volume.Tags { if aws.ToString(tag.Key) == "Name" { @@ -57,73 +75,34 @@ func shouldIncludeEBSVolume(volume types.Volume, configObj config.Config) bool { } } - return configObj.EBSVolume.ShouldInclude(config.ResourceValue{ + return cfg.ShouldInclude(config.ResourceValue{ Name: &name, Time: volume.CreateTime, Tags: util.ConvertTypesTagsToMap(volume.Tags), }) } -// Deletes all EBS Volumes -func (ev *EBSVolumes) nukeAll(volumeIds []*string) error { - if len(volumeIds) == 0 { - logging.Debugf("No EBS volumes to nuke in region %s", ev.Region) - return nil - } - - logging.Debugf("Deleting all EBS volumes in region %s", ev.Region) - var deletedVolumeIDs []*string - - for _, volumeID := range volumeIds { - - if nukable, reason := ev.IsNukable(aws.ToString(volumeID)); !nukable { - logging.Debugf("[Skipping] %s nuke because %v", aws.ToString(volumeID), reason) - continue - } - - params := &ec2.DeleteVolumeInput{ - VolumeId: volumeID, - } - - _, err := ev.Client.DeleteVolume(ev.Context, params) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(volumeID), - ResourceType: "EBS Volume", - Error: err, - } - report.Record(e) - - if err != nil { - var apiErr smithy.APIError - if goerr.As(err, &apiErr) { - switch apiErr.ErrorCode() { - case "VolumeInUse": - 
logging.Debugf("EBS volume %s can't be deleted, it is still attached to an active resource", *volumeID) - case "InvalidVolume.NotFound": - logging.Debugf("EBS volume %s has already been deleted", *volumeID) - default: - logging.Debugf("[Failed] %s", err) - } - } - } else { - deletedVolumeIDs = append(deletedVolumeIDs, volumeID) - logging.Debugf("Deleted EBS Volume: %s", *volumeID) - } - } +// verifyEBSVolumePermission performs a dry-run delete to check permissions. +func verifyEBSVolumePermission(ctx context.Context, client EBSVolumesAPI, id *string) error { + _, err := client.DeleteVolume(ctx, &ec2.DeleteVolumeInput{ + VolumeId: id, + DryRun: aws.Bool(true), + }) + return err +} - if len(deletedVolumeIDs) > 0 { - waiter := ec2.NewVolumeDeletedWaiter(ev.Client) - err := waiter.Wait(ev.Context, &ec2.DescribeVolumesInput{ - VolumeIds: aws.ToStringSlice(deletedVolumeIDs), - }, ev.Timeout) - if err != nil { - logging.Debugf("[Failed] %s", err) - return errors.WithStackTrace(err) - } - } +// deleteEBSVolume deletes a single EBS volume. +func deleteEBSVolume(ctx context.Context, client EBSVolumesAPI, volumeID *string) error { + _, err := client.DeleteVolume(ctx, &ec2.DeleteVolumeInput{ + VolumeId: volumeID, + }) + return err +} - logging.Debugf("[OK] %d EBS volumes(s) terminated in %s", len(deletedVolumeIDs), ev.Region) - return nil +// waitForEBSVolumesDeleted waits for all specified EBS volumes to be fully deleted. +func waitForEBSVolumesDeleted(ctx context.Context, client EBSVolumesAPI, ids []string) error { + waiter := ec2.NewVolumeDeletedWaiter(client) + return waiter.Wait(ctx, &ec2.DescribeVolumesInput{ + VolumeIds: ids, + }, 5*time.Minute) } diff --git a/aws/resources/ebs_test.go b/aws/resources/ebs_test.go index 6e8adf30..75402220 100644 --- a/aws/resources/ebs_test.go +++ b/aws/resources/ebs_test.go @@ -10,24 +10,35 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type mockedEBS struct { - EBSVolumesAPI +type mockEBSVolumesClient struct { DescribeVolumesOutput ec2.DescribeVolumesOutput DeleteVolumeOutput ec2.DeleteVolumeOutput } -func (m mockedEBS) DescribeVolumes(ctx context.Context, params *ec2.DescribeVolumesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVolumesOutput, error) { +func (m *mockEBSVolumesClient) DescribeVolumes(ctx context.Context, params *ec2.DescribeVolumesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVolumesOutput, error) { return &m.DescribeVolumesOutput, nil } -func (m mockedEBS) DeleteVolume(ctx context.Context, params *ec2.DeleteVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVolumeOutput, error) { +func (m *mockEBSVolumesClient) DeleteVolume(ctx context.Context, params *ec2.DeleteVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVolumeOutput, error) { return &m.DeleteVolumeOutput, nil } -func TestEBSVolume_GetAll(t *testing.T) { +func TestEBSVolumes_ResourceName(t *testing.T) { + r := NewEBSVolumes() + assert.Equal(t, "ebs", r.ResourceName()) +} + +func TestEBSVolumes_MaxBatchSize(t *testing.T) { + r := NewEBSVolumes() + assert.Equal(t, 49, r.MaxBatchSize()) +} + +func TestListEBSVolumes(t *testing.T) { t.Parallel() testName1 := "test-name1" @@ -35,74 +46,128 @@ func TestEBSVolume_GetAll(t *testing.T) { testVolume1 := "test-volume1" testVolume2 := "test-volume2" now := time.Now() - ev := EBSVolumes{ - Client: 
mockedEBS{ - DescribeVolumesOutput: ec2.DescribeVolumesOutput{ - Volumes: []types.Volume{ - { - VolumeId: aws.String(testVolume1), - CreateTime: aws.Time(now), - Tags: []types.Tag{{ - Key: aws.String("Name"), - Value: aws.String(testName1), - }}, - }, - { - VolumeId: aws.String(testVolume2), - CreateTime: aws.Time(now.Add(1)), - Tags: []types.Tag{{ - Key: aws.String("Name"), - Value: aws.String(testName2), - }}, - }, - }}}} - - tests := map[string]struct { - configObj config.ResourceType - expected []string - }{ - "emptyFilter": { - configObj: config.ResourceType{}, - expected: []string{testVolume1, testVolume2}, - }, - "nameExclusionFilter": { - configObj: config.ResourceType{ - ExcludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{{ - RE: *regexp.MustCompile(testName1), - }}}, + + mock := &mockEBSVolumesClient{ + DescribeVolumesOutput: ec2.DescribeVolumesOutput{ + Volumes: []types.Volume{ + { + VolumeId: aws.String(testVolume1), + CreateTime: aws.Time(now), + Tags: []types.Tag{{ + Key: aws.String("Name"), + Value: aws.String(testName1), + }}, + }, + { + VolumeId: aws.String(testVolume2), + CreateTime: aws.Time(now.Add(1)), + Tags: []types.Tag{{ + Key: aws.String("Name"), + Value: aws.String(testName2), + }}, + }, }, - expected: []string{testVolume2}, }, - "timeAfterExclusionFilter": { - configObj: config.ResourceType{ - ExcludeRule: config.FilterRule{ - TimeAfter: aws.Time(now.Add(-2 * time.Hour)), - }}, - expected: []string{}, + } + + ids, err := listEBSVolumes(context.Background(), mock, resource.Scope{}, config.ResourceType{}) + require.NoError(t, err) + require.ElementsMatch(t, []string{testVolume1, testVolume2}, aws.ToStringSlice(ids)) +} + +func TestListEBSVolumes_WithNameExclusionFilter(t *testing.T) { + t.Parallel() + + testName1 := "test-name1" + testName2 := "test-name2" + testVolume1 := "test-volume1" + testVolume2 := "test-volume2" + now := time.Now() + + mock := &mockEBSVolumesClient{ + DescribeVolumesOutput: ec2.DescribeVolumesOutput{ + Volumes: []types.Volume{ + { + VolumeId: aws.String(testVolume1), + CreateTime: aws.Time(now), + Tags: []types.Tag{{ + Key: aws.String("Name"), + Value: aws.String(testName1), + }}, + }, + { + VolumeId: aws.String(testVolume2), + CreateTime: aws.Time(now.Add(1)), + Tags: []types.Tag{{ + Key: aws.String("Name"), + Value: aws.String(testName2), + }}, + }, + }, }, } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - names, err := ev.getAll(context.Background(), config.Config{ - EBSVolume: tc.configObj, - }) - require.NoError(t, err) - require.Equal(t, tc.expected, aws.ToStringSlice(names)) - }) + cfg := config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{RE: *regexp.MustCompile(testName1)}}, + }, } + + ids, err := listEBSVolumes(context.Background(), mock, resource.Scope{}, cfg) + require.NoError(t, err) + require.Equal(t, []string{testVolume2}, aws.ToStringSlice(ids)) } -func TestEBSVolume_NukeAll(t *testing.T) { +func TestListEBSVolumes_WithTimeAfterExclusionFilter(t *testing.T) { t.Parallel() - ev := EBSVolumes{ - Client: mockedEBS{ - DeleteVolumeOutput: ec2.DeleteVolumeOutput{}, + testName1 := "test-name1" + testName2 := "test-name2" + testVolume1 := "test-volume1" + testVolume2 := "test-volume2" + now := time.Now() + + mock := &mockEBSVolumesClient{ + DescribeVolumesOutput: ec2.DescribeVolumesOutput{ + Volumes: []types.Volume{ + { + VolumeId: aws.String(testVolume1), + CreateTime: aws.Time(now), + Tags: []types.Tag{{ + Key: aws.String("Name"), + Value: 
aws.String(testName1), + }}, + }, + { + VolumeId: aws.String(testVolume2), + CreateTime: aws.Time(now.Add(1)), + Tags: []types.Tag{{ + Key: aws.String("Name"), + Value: aws.String(testName2), + }}, + }, + }, + }, + } + + cfg := config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now.Add(-2 * time.Hour)), }, } - err := ev.nukeAll([]*string{aws.String("test-volume")}) + ids, err := listEBSVolumes(context.Background(), mock, resource.Scope{}, cfg) + require.NoError(t, err) + require.Empty(t, ids) +} + +func TestDeleteEBSVolume(t *testing.T) { + t.Parallel() + + mock := &mockEBSVolumesClient{ + DeleteVolumeOutput: ec2.DeleteVolumeOutput{}, + } + + err := deleteEBSVolume(context.Background(), mock, aws.String("test-volume")) require.NoError(t, err) } diff --git a/aws/resources/ebs_types.go b/aws/resources/ebs_types.go deleted file mode 100644 index 3538346e..00000000 --- a/aws/resources/ebs_types.go +++ /dev/null @@ -1,65 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type EBSVolumesAPI interface { - DescribeVolumes(ctx context.Context, params *ec2.DescribeVolumesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVolumesOutput, error) - DeleteVolume(ctx context.Context, params *ec2.DeleteVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVolumeOutput, error) -} - -// EBSVolumes - represents all ebs volumes -type EBSVolumes struct { - BaseAwsResource - Client EBSVolumesAPI - Region string - VolumeIds []string -} - -func (ev *EBSVolumes) Init(cfg aws.Config) { - ev.Client = ec2.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (ev *EBSVolumes) ResourceName() string { - return "ebs" -} - -// ResourceIdentifiers - The volume ids of the ebs volumes -func (ev *EBSVolumes) ResourceIdentifiers() []string { - return ev.VolumeIds -} - -func (ev *EBSVolumes) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (ev *EBSVolumes) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.EBSVolume -} - -func (ev *EBSVolumes) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := ev.getAll(c, configObj) - if err != nil { - return nil, err - } - - ev.VolumeIds = aws.ToStringSlice(identifiers) - return ev.VolumeIds, nil -} - -// Nuke - nuke 'em all!!! 
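A note on the SequentialDeleteThenWaitAll adapter used in NewEBSVolumes above: its definition is also outside this excerpt. Continuing the hook-signature sketch shown earlier (same Scope type and imports), its behaviour implied by the call site, delete each volume and then wait once for the whole batch, might look roughly like this; the real adapter likely also honours the permission dry-run results and records report entries.

// Sketch only: inferred from resource.SequentialDeleteThenWaitAll(deleteEBSVolume, waitForEBSVolumesDeleted).
func SequentialDeleteThenWaitAll[ClientT any](
	del func(ctx context.Context, client ClientT, id *string) error,
	waitAll func(ctx context.Context, client ClientT, ids []string) error,
) func(context.Context, ClientT, Scope, string, []*string) error {
	return func(ctx context.Context, client ClientT, scope Scope, resourceType string, identifiers []*string) error {
		var deleted []string
		for _, id := range identifiers {
			if err := del(ctx, client, id); err != nil {
				continue // a real implementation would log and aggregate this error
			}
			deleted = append(deleted, *id)
		}
		if len(deleted) == 0 {
			return nil
		}
		// Wait once for the whole batch of successfully deleted identifiers.
		return waitAll(ctx, client, deleted)
	}
}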
-func (ev *EBSVolumes) Nuke(ctx context.Context, identifiers []string) error { - if err := ev.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/ec2.go b/aws/resources/ec2.go index 3efc294a..5d22a775 100644 --- a/aws/resources/ec2.go +++ b/aws/resources/ec2.go @@ -8,36 +8,45 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/gruntwork-io/go-commons/errors" ) -// returns only instance Ids of unprotected ec2 instances -func (ei *EC2Instances) filterOutProtectedInstances(output *ec2.DescribeInstancesOutput, configObj config.Config) ([]*string, error) { - var filteredIds []*string - for _, reservation := range output.Reservations { - for _, instance := range reservation.Instances { - instanceID := *instance.InstanceId - - attr, err := ei.Client.DescribeInstanceAttribute(ei.Context, &ec2.DescribeInstanceAttributeInput{ - Attribute: types.InstanceAttributeNameDisableApiTermination, - InstanceId: aws.String(instanceID), - }) - if err != nil { - return nil, errors.WithStackTrace(err) - } +// EC2InstancesAPI defines the interface for EC2 Instances operations. +type EC2InstancesAPI interface { + DescribeInstanceAttribute(ctx context.Context, params *ec2.DescribeInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceAttributeOutput, error) + DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) + DescribeAddresses(ctx context.Context, params *ec2.DescribeAddressesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeAddressesOutput, error) + ReleaseAddress(ctx context.Context, params *ec2.ReleaseAddressInput, optFns ...func(*ec2.Options)) (*ec2.ReleaseAddressOutput, error) + TerminateInstances(ctx context.Context, params *ec2.TerminateInstancesInput, optFns ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) +} - if shouldIncludeInstanceId(instance, *attr.DisableApiTermination.Value, configObj) { - filteredIds = append(filteredIds, &instanceID) +// NewEC2Instances creates a new EC2 Instances resource using the generic resource pattern. +func NewEC2Instances() AwsResource { + return NewAwsResource(&resource.Resource[EC2InstancesAPI]{ + ResourceTypeName: "ec2", + // Tentative batch size to ensure AWS doesn't throttle + BatchSize: 49, + InitClient: func(r *resource.Resource[EC2InstancesAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for EC2 client: expected aws.Config") + return } - } - } - - return filteredIds, nil + r.Scope.Region = awsCfg.Region + r.Client = ec2.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.EC2 + }, + Lister: listEC2Instances, + Nuker: deleteEC2Instances, + }) } -// Returns a formatted string of EC2 instance ids -func (ei *EC2Instances) getAll(ctx context.Context, configObj config.Config) ([]*string, error) { +// listEC2Instances retrieves all EC2 instances that match the config filters. 
+func listEC2Instances(ctx context.Context, client EC2InstancesAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { params := &ec2.DescribeInstancesInput{ Filters: []types.Filter{ { @@ -50,12 +59,12 @@ func (ei *EC2Instances) getAll(ctx context.Context, configObj config.Config) ([] }, } - output, err := ei.Client.DescribeInstances(ctx, params) + output, err := client.DescribeInstances(ctx, params) if err != nil { return nil, errors.WithStackTrace(err) } - instanceIds, err := ei.filterOutProtectedInstances(output, configObj) + instanceIds, err := filterOutProtectedInstances(ctx, client, output, cfg) if err != nil { return nil, errors.WithStackTrace(err) } @@ -63,7 +72,31 @@ func (ei *EC2Instances) getAll(ctx context.Context, configObj config.Config) ([] return instanceIds, nil } -func shouldIncludeInstanceId(instance types.Instance, protected bool, configObj config.Config) bool { +// filterOutProtectedInstances returns only instance IDs of unprotected EC2 instances +func filterOutProtectedInstances(ctx context.Context, client EC2InstancesAPI, output *ec2.DescribeInstancesOutput, cfg config.ResourceType) ([]*string, error) { + var filteredIds []*string + for _, reservation := range output.Reservations { + for _, instance := range reservation.Instances { + instanceID := *instance.InstanceId + + attr, err := client.DescribeInstanceAttribute(ctx, &ec2.DescribeInstanceAttributeInput{ + Attribute: types.InstanceAttributeNameDisableApiTermination, + InstanceId: aws.String(instanceID), + }) + if err != nil { + return nil, errors.WithStackTrace(err) + } + + if shouldIncludeInstanceId(instance, *attr.DisableApiTermination.Value, cfg) { + filteredIds = append(filteredIds, &instanceID) + } + } + } + + return filteredIds, nil +} + +func shouldIncludeInstanceId(instance types.Instance, protected bool, cfg config.ResourceType) bool { if protected { return false } @@ -71,19 +104,19 @@ func shouldIncludeInstanceId(instance types.Instance, protected bool, configObj // If Name is unset, GetEC2ResourceNameTagValue returns error and zero value string // Ignore this error and pass empty string to config.ShouldInclude instanceName := util.GetEC2ResourceNameTagValue(instance.Tags) - return configObj.EC2.ShouldInclude(config.ResourceValue{ + return cfg.ShouldInclude(config.ResourceValue{ Name: instanceName, Time: instance.LaunchTime, Tags: util.ConvertTypesTagsToMap(instance.Tags), }) } -func (ei *EC2Instances) releaseEIPs(instanceIds []*string) error { +func releaseEIPs(ctx context.Context, client EC2InstancesAPI, instanceIds []*string) error { logging.Debugf("Releasing Elastic IP address(s) associated with instances") for _, instanceID := range instanceIds { // Get the Elastic IPs associated with the EC2 instances - output, err := ei.Client.DescribeAddresses(ei.Context, &ec2.DescribeAddressesInput{ + output, err := client.DescribeAddresses(ctx, &ec2.DescribeAddressesInput{ Filters: []types.Filter{ { Name: aws.String("instance-id"), @@ -109,7 +142,7 @@ func (ei *EC2Instances) releaseEIPs(instanceIds []*string) error { continue } - _, err := ei.Client.ReleaseAddress(ei.Context, &ec2.ReleaseAddressInput{ + _, err := client.ReleaseAddress(ctx, &ec2.ReleaseAddressInput{ AllocationId: address.AllocationId, }) @@ -125,48 +158,48 @@ func (ei *EC2Instances) releaseEIPs(instanceIds []*string) error { return nil } -// Deletes all non protected EC2 instances -func (ei *EC2Instances) nukeAll(instanceIds []*string) error { - if len(instanceIds) == 0 { - logging.Debugf("No EC2 instances to nuke in region 
%s", ei.Region) +// deleteEC2Instances deletes all non protected EC2 instances. +func deleteEC2Instances(ctx context.Context, client EC2InstancesAPI, scope resource.Scope, resourceType string, identifiers []*string) error { + if len(identifiers) == 0 { + logging.Debugf("No EC2 instances to nuke in region %s", scope.Region) return nil } // release the attached elastic ip's // Note: This should be done before terminating the EC2 instances - err := ei.releaseEIPs(instanceIds) + err := releaseEIPs(ctx, client, identifiers) if err != nil { logging.Debugf("[Failed EIP release] %s", err) return errors.WithStackTrace(err) } - logging.Debugf("Terminating all EC2 instances in region %s", ei.Region) + logging.Debugf("Terminating all EC2 instances in region %s", scope.Region) params := &ec2.TerminateInstancesInput{ - InstanceIds: aws.ToStringSlice(instanceIds), + InstanceIds: aws.ToStringSlice(identifiers), } - _, err = ei.Client.TerminateInstances(ei.Context, params) + _, err = client.TerminateInstances(ctx, params) if err != nil { logging.Debugf("[Failed] %s", err) return errors.WithStackTrace(err) } - waiter := ec2.NewInstanceTerminatedWaiter(ei.Client) - err = waiter.Wait(ei.Context, &ec2.DescribeInstancesInput{ + waiter := ec2.NewInstanceTerminatedWaiter(client) + err = waiter.Wait(ctx, &ec2.DescribeInstancesInput{ Filters: []types.Filter{ { Name: aws.String("instance-id"), - Values: aws.ToStringSlice(instanceIds), + Values: aws.ToStringSlice(identifiers), }, }, - }, ei.Timeout) + }, DefaultWaitTimeout) if err != nil { logging.Debugf("[Failed] %s", err) return errors.WithStackTrace(err) } - logging.Debugf("[OK] %d instance(s) terminated in %s", len(instanceIds), ei.Region) + logging.Debugf("[OK] %d instance(s) terminated in %s", len(identifiers), scope.Region) return nil } diff --git a/aws/resources/ec2_dedicated_host.go b/aws/resources/ec2_dedicated_host.go index eee390ed..88d00c63 100644 --- a/aws/resources/ec2_dedicated_host.go +++ b/aws/resources/ec2_dedicated_host.go @@ -10,11 +10,41 @@ import ( "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/gruntwork-io/go-commons/errors" ) -func (h *EC2DedicatedHosts) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// EC2DedicatedHostsAPI defines the interface for EC2 Dedicated Hosts operations. +type EC2DedicatedHostsAPI interface { + DescribeHosts(ctx context.Context, params *ec2.DescribeHostsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeHostsOutput, error) + ReleaseHosts(ctx context.Context, params *ec2.ReleaseHostsInput, optFns ...func(*ec2.Options)) (*ec2.ReleaseHostsOutput, error) +} + +// NewEC2DedicatedHosts creates a new EC2DedicatedHosts resource using the generic resource pattern. 
+func NewEC2DedicatedHosts() AwsResource { + return NewAwsResource(&resource.Resource[EC2DedicatedHostsAPI]{ + ResourceTypeName: "ec2-dedicated-hosts", + BatchSize: 49, + InitClient: func(r *resource.Resource[EC2DedicatedHostsAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for EC2 client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = ec2.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.EC2DedicatedHosts + }, + Lister: listEC2DedicatedHosts, + Nuker: deleteEC2DedicatedHosts, + }) +} + +// listEC2DedicatedHosts retrieves all EC2 dedicated hosts that match the config filters. +func listEC2DedicatedHosts(ctx context.Context, client EC2DedicatedHostsAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { var hostIds []*string describeHostsInput := &ec2.DescribeHostsInput{ Filter: []types.Filter{ @@ -29,15 +59,15 @@ func (h *EC2DedicatedHosts) getAll(c context.Context, configObj config.Config) ( }, } - paginator := ec2.NewDescribeHostsPaginator(h.Client, describeHostsInput) + paginator := ec2.NewDescribeHostsPaginator(client, describeHostsInput) for paginator.HasMorePages() { - page, err := paginator.NextPage(c) + page, err := paginator.NextPage(ctx) if err != nil { return nil, errors.WithStackTrace(err) } for _, host := range page.Hosts { - if shouldIncludeHostId(&host, configObj) { + if shouldIncludeHostId(&host, cfg) { hostIds = append(hostIds, host.HostId) } } @@ -46,7 +76,8 @@ func (h *EC2DedicatedHosts) getAll(c context.Context, configObj config.Config) ( return hostIds, nil } -func shouldIncludeHostId(host *types.Host, configObj config.Config) bool { +// shouldIncludeHostId determines if an EC2 dedicated host should be included for deletion. +func shouldIncludeHostId(host *types.Host, cfg config.ResourceType) bool { if host == nil { return false } @@ -61,23 +92,24 @@ func shouldIncludeHostId(host *types.Host, configObj config.Config) bool { // Ignore this error and pass empty string to config.ShouldInclude hostNameTagValue := util.GetEC2ResourceNameTagValue(host.Tags) - return configObj.EC2DedicatedHosts.ShouldInclude(config.ResourceValue{ + return cfg.ShouldInclude(config.ResourceValue{ Name: hostNameTagValue, Time: host.AllocationTime, }) } -func (h *EC2DedicatedHosts) nukeAll(hostIds []*string) error { +// deleteEC2DedicatedHosts releases all EC2 dedicated hosts. 
+func deleteEC2DedicatedHosts(ctx context.Context, client EC2DedicatedHostsAPI, scope resource.Scope, resourceType string, hostIds []*string) error { if len(hostIds) == 0 { - logging.Debugf("No EC2 dedicated hosts to nuke in region %s", h.Region) + logging.Debugf("No EC2 dedicated hosts to nuke in region %s", scope.Region) return nil } - logging.Debugf("Releasing all EC2 dedicated host allocations in region %s", h.Region) + logging.Debugf("Releasing all EC2 dedicated host allocations in region %s", scope.Region) input := &ec2.ReleaseHostsInput{HostIds: aws.ToStringSlice(hostIds)} - releaseResult, err := h.Client.ReleaseHosts(h.Context, input) + releaseResult, err := client.ReleaseHosts(ctx, input) if err != nil { logging.Debugf("[Failed] %s", err) @@ -86,7 +118,7 @@ func (h *EC2DedicatedHosts) nukeAll(hostIds []*string) error { // Report successes and failures from release host request for _, hostSuccess := range releaseResult.Successful { - logging.Debugf("[OK] Dedicated host %s was released in %s", hostSuccess, h.Region) + logging.Debugf("[OK] Dedicated host %s was released in %s", hostSuccess, scope.Region) e := report.Entry{ Identifier: hostSuccess, ResourceType: "EC2 Dedicated Host", @@ -95,7 +127,7 @@ func (h *EC2DedicatedHosts) nukeAll(hostIds []*string) error { } for _, hostFailed := range releaseResult.Unsuccessful { - logging.Debugf("[ERROR] Unable to release dedicated host %s in %s: %s", aws.ToString(hostFailed.ResourceId), h.Region, aws.ToString(hostFailed.Error.Message)) + logging.Debugf("[ERROR] Unable to release dedicated host %s in %s: %s", aws.ToString(hostFailed.ResourceId), scope.Region, aws.ToString(hostFailed.Error.Message)) e := report.Entry{ Identifier: aws.ToString(hostFailed.ResourceId), ResourceType: "EC2 Dedicated Host", diff --git a/aws/resources/ec2_dedicated_host_test.go b/aws/resources/ec2_dedicated_host_test.go index a8608117..5170bdec 100644 --- a/aws/resources/ec2_dedicated_host_test.go +++ b/aws/resources/ec2_dedicated_host_test.go @@ -10,54 +10,54 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) -type mockedEC2DedicatedHosts struct { - EC2DedicatedHostsAPI +type mockEC2DedicatedHostsClient struct { DescribeHostsOutput ec2.DescribeHostsOutput ReleaseHostsOutput ec2.ReleaseHostsOutput } -func (m mockedEC2DedicatedHosts) DescribeHosts(ctx context.Context, params *ec2.DescribeHostsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeHostsOutput, error) { +func (m *mockEC2DedicatedHostsClient) DescribeHosts(ctx context.Context, params *ec2.DescribeHostsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeHostsOutput, error) { return &m.DescribeHostsOutput, nil } -func (m mockedEC2DedicatedHosts) ReleaseHosts(ctx context.Context, params *ec2.ReleaseHostsInput, optFns ...func(*ec2.Options)) (*ec2.ReleaseHostsOutput, error) { +func (m *mockEC2DedicatedHostsClient) ReleaseHosts(ctx context.Context, params *ec2.ReleaseHostsInput, optFns ...func(*ec2.Options)) (*ec2.ReleaseHostsOutput, error) { return &m.ReleaseHostsOutput, nil } -func TestEC2DedicatedHosts_GetAll(t *testing.T) { +func TestListEC2DedicatedHosts(t *testing.T) { t.Parallel() + testId1 := "test-host-id-1" testId2 := "test-host-id-2" testName1 := "test-host-name-1" testName2 := "test-host-name-2" now := time.Now() - h := EC2DedicatedHosts{ - Client: mockedEC2DedicatedHosts{ - DescribeHostsOutput: ec2.DescribeHostsOutput{ - 
Hosts: []types.Host{ - { - HostId: aws.String(testId1), - Tags: []types.Tag{ - { - Key: aws.String("Name"), - Value: aws.String(testName1), - }, + + mock := &mockEC2DedicatedHostsClient{ + DescribeHostsOutput: ec2.DescribeHostsOutput{ + Hosts: []types.Host{ + { + HostId: aws.String(testId1), + Tags: []types.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(testName1), }, - AllocationTime: aws.Time(now), }, - { - HostId: aws.String(testId2), - Tags: []types.Tag{ - { - Key: aws.String("Name"), - Value: aws.String(testName2), - }, + AllocationTime: aws.Time(now), + }, + { + HostId: aws.String(testId2), + Tags: []types.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(testName2), }, - AllocationTime: aws.Time(now.Add(1)), }, + AllocationTime: aws.Time(now.Add(1)), }, }, }, @@ -90,24 +90,20 @@ func TestEC2DedicatedHosts_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := h.getAll(context.Background(), config.Config{ - EC2DedicatedHosts: tc.configObj, - }) + names, err := listEC2DedicatedHosts(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } - } -func TestEC2DedicatedHosts_NukeAll(t *testing.T) { +func TestDeleteEC2DedicatedHosts(t *testing.T) { t.Parallel() - h := EC2DedicatedHosts{ - Client: mockedEC2DedicatedHosts{ - ReleaseHostsOutput: ec2.ReleaseHostsOutput{}, - }, + + mock := &mockEC2DedicatedHostsClient{ + ReleaseHostsOutput: ec2.ReleaseHostsOutput{}, } - err := h.nukeAll([]*string{aws.String("test-host-id-1"), aws.String("test-host-id-2")}) + err := deleteEC2DedicatedHosts(context.Background(), mock, resource.Scope{Region: "us-east-1"}, "ec2-dedicated-hosts", []*string{aws.String("test-host-id-1"), aws.String("test-host-id-2")}) require.NoError(t, err) } diff --git a/aws/resources/ec2_dedicated_host_types.go b/aws/resources/ec2_dedicated_host_types.go deleted file mode 100644 index 84279c79..00000000 --- a/aws/resources/ec2_dedicated_host_types.go +++ /dev/null @@ -1,65 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type EC2DedicatedHostsAPI interface { - DescribeHosts(ctx context.Context, params *ec2.DescribeHostsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeHostsOutput, error) - ReleaseHosts(ctx context.Context, params *ec2.ReleaseHostsInput, optFns ...func(*ec2.Options)) (*ec2.ReleaseHostsOutput, error) -} - -// EC2DedicatedHosts - represents all host allocation IDs -type EC2DedicatedHosts struct { - BaseAwsResource - Client EC2DedicatedHostsAPI - Region string - HostIds []string -} - -func (h *EC2DedicatedHosts) Init(cfg aws.Config) { - h.Client = ec2.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (h *EC2DedicatedHosts) ResourceName() string { - return "ec2-dedicated-hosts" -} - -// ResourceIdentifiers - The instance ids of the ec2 instances -func (h *EC2DedicatedHosts) ResourceIdentifiers() []string { - return h.HostIds -} - -func (h *EC2DedicatedHosts) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (h *EC2DedicatedHosts) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.EC2DedicatedHosts -} - -func (h *EC2DedicatedHosts) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - 
identifiers, err := h.getAll(c, configObj) - if err != nil { - return nil, err - } - - h.HostIds = aws.ToStringSlice(identifiers) - return h.HostIds, nil -} - -// Nuke - nuke 'em all!!! -func (h *EC2DedicatedHosts) Nuke(ctx context.Context, identifiers []string) error { - if err := h.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/ec2_endpoints.go b/aws/resources/ec2_endpoints.go index 1b83b084..78effd38 100644 --- a/aws/resources/ec2_endpoints.go +++ b/aws/resources/ec2_endpoints.go @@ -2,7 +2,6 @@ package resources import ( "context" - "fmt" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -10,11 +9,40 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" - "github.com/gruntwork-io/go-commons/errors" ) +// EC2EndpointsAPI defines the interface for EC2 VPC Endpoints operations. +type EC2EndpointsAPI interface { + DescribeVpcEndpoints(ctx context.Context, params *ec2.DescribeVpcEndpointsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcEndpointsOutput, error) + DeleteVpcEndpoints(ctx context.Context, params *ec2.DeleteVpcEndpointsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVpcEndpointsOutput, error) +} + +// NewEC2Endpoints creates a new EC2 VPC Endpoints resource using the generic resource pattern. +func NewEC2Endpoints() AwsResource { + return NewAwsResource(&resource.Resource[EC2EndpointsAPI]{ + ResourceTypeName: "ec2-endpoint", + BatchSize: 49, + InitClient: func(r *resource.Resource[EC2EndpointsAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for EC2 client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = ec2.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.EC2Endpoint + }, + Lister: listEC2Endpoints, + Nuker: resource.SimpleBatchDeleter(deleteEC2Endpoint), + PermissionVerifier: verifyEC2EndpointPermission, + }) +} + +// ShouldIncludeVpcEndpoint determines if a VPC endpoint should be included based on config filters. func ShouldIncludeVpcEndpoint(endpoint *types.VpcEndpoint, firstSeenTime *time.Time, configObj config.Config) bool { var endpointName string // get the tags as map @@ -30,78 +58,61 @@ func ShouldIncludeVpcEndpoint(endpoint *types.VpcEndpoint, firstSeenTime *time.T }) } -func (e *EC2Endpoints) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// listEC2Endpoints retrieves all VPC endpoints that match the config filters. 
+func listEC2Endpoints(ctx context.Context, client EC2EndpointsAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { var result []*string - var firstSeenTime *time.Time - var err error - endpoints, err := e.Client.DescribeVpcEndpoints(e.Context, &ec2.DescribeVpcEndpointsInput{}) - + endpoints, err := client.DescribeVpcEndpoints(ctx, &ec2.DescribeVpcEndpointsInput{}) if err != nil { - return nil, errors.WithStackTrace(err) + return nil, err } for _, endpoint := range endpoints.VpcEndpoints { - firstSeenTime, err = util.GetOrCreateFirstSeen(c, e.Client, endpoint.VpcEndpointId, util.ConvertTypesTagsToMap(endpoint.Tags)) + firstSeenTime, err := util.GetOrCreateFirstSeen(ctx, client, endpoint.VpcEndpointId, util.ConvertTypesTagsToMap(endpoint.Tags)) if err != nil { logging.Error("Unable to retrieve tags") - return nil, errors.WithStackTrace(err) + return nil, err } - if ShouldIncludeVpcEndpoint(&endpoint, firstSeenTime, configObj) { - result = append(result, endpoint.VpcEndpointId) + tagMap := util.ConvertTypesTagsToMap(endpoint.Tags) + var endpointName string + if name, ok := tagMap["Name"]; ok { + endpointName = name } + if cfg.ShouldInclude(config.ResourceValue{ + Name: &endpointName, + Time: firstSeenTime, + Tags: tagMap, + }) { + result = append(result, endpoint.VpcEndpointId) + } } - e.VerifyNukablePermissions(result, func(id *string) error { - _, err := e.Client.DeleteVpcEndpoints(e.Context, &ec2.DeleteVpcEndpointsInput{ - VpcEndpointIds: []string{aws.ToString(id)}, - DryRun: aws.Bool(true), - }) - return err - }) - return result, nil } -func (e *EC2Endpoints) nukeAll(identifiers []*string) error { - if len(identifiers) == 0 { - logging.Debugf("No Vpc Endpoints to nuke in region %s", e.Region) - return nil - } - - logging.Debugf("Deleting all Vpc Endpoints in region %s", e.Region) - var deletedAddresses []*string +// deleteEC2Endpoint deletes a single VPC endpoint. +func deleteEC2Endpoint(ctx context.Context, client EC2EndpointsAPI, id *string) error { + logging.Debugf("Deleting VPC endpoint %s", aws.ToString(id)) - for _, id := range identifiers { - if nukable, reason := e.IsNukable(*id); !nukable { - logging.Debugf("[Skipping] %s nuke because %v", *id, reason) - continue - } - - err := nukeVpcEndpoint(e.Client, []*string{id}) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(id), - ResourceType: "Vpc Endpoint", - Error: err, - } - report.Record(e) - - if err != nil { - logging.Debugf("[Failed] %s", err) - } else { - deletedAddresses = append(deletedAddresses, id) - } - } - - logging.Debugf("[OK] %d Vpc Endpoint(s) deleted in %s", len(deletedAddresses), e.Region) + _, err := client.DeleteVpcEndpoints(ctx, &ec2.DeleteVpcEndpointsInput{ + VpcEndpointIds: []string{aws.ToString(id)}, + }) + return err +} - return nil +// verifyEC2EndpointPermission performs a dry-run delete to check permissions. +func verifyEC2EndpointPermission(ctx context.Context, client EC2EndpointsAPI, id *string) error { + _, err := client.DeleteVpcEndpoints(ctx, &ec2.DeleteVpcEndpointsInput{ + VpcEndpointIds: []string{aws.ToString(id)}, + DryRun: aws.Bool(true), + }) + return err } +// nukeVpcEndpoint deletes VPC endpoints by their IDs. +// This is exported for use by ec2_vpc.go when nuking VPCs. 
func nukeVpcEndpoint(client EC2EndpointsAPI, endpointIds []*string) error { logging.Debugf("Deleting VPC endpoints %s", aws.ToStringSlice(endpointIds)) @@ -109,11 +120,10 @@ func nukeVpcEndpoint(client EC2EndpointsAPI, endpointIds []*string) error { VpcEndpointIds: aws.ToStringSlice(endpointIds), }) if err != nil { - logging.Debug(fmt.Sprintf("Failed to delete VPC endpoints: %s", err.Error())) - return errors.WithStackTrace(err) + logging.Debugf("Failed to delete VPC endpoints: %s", err.Error()) + return err } - logging.Debug(fmt.Sprintf("Successfully deleted VPC endpoints %s", aws.ToStringSlice(endpointIds))) - + logging.Debugf("Successfully deleted VPC endpoints %s", aws.ToStringSlice(endpointIds)) return nil } diff --git a/aws/resources/ec2_endpoints_test.go b/aws/resources/ec2_endpoints_test.go index 0e736aa4..01d8937e 100644 --- a/aws/resources/ec2_endpoints_test.go +++ b/aws/resources/ec2_endpoints_test.go @@ -10,151 +10,163 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type mockedEc2VpcEndpoints struct { - EC2EndpointsAPI +type mockEC2EndpointsClient struct { DescribeVpcEndpointsOutput ec2.DescribeVpcEndpointsOutput DeleteVpcEndpointsOutput ec2.DeleteVpcEndpointsOutput } -func (m mockedEc2VpcEndpoints) DescribeVpcEndpoints(ctx context.Context, params *ec2.DescribeVpcEndpointsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcEndpointsOutput, error) { +func (m *mockEC2EndpointsClient) DescribeVpcEndpoints(ctx context.Context, params *ec2.DescribeVpcEndpointsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcEndpointsOutput, error) { return &m.DescribeVpcEndpointsOutput, nil } -func (m mockedEc2VpcEndpoints) DeleteVpcEndpoints(ctx context.Context, params *ec2.DeleteVpcEndpointsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVpcEndpointsOutput, error) { +func (m *mockEC2EndpointsClient) DeleteVpcEndpoints(ctx context.Context, params *ec2.DeleteVpcEndpointsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVpcEndpointsOutput, error) { return &m.DeleteVpcEndpointsOutput, nil } -func TestVcpEndpoint_GetAll(t *testing.T) { +func TestEC2Endpoints_ResourceName(t *testing.T) { + r := NewEC2Endpoints() + assert.Equal(t, "ec2-endpoint", r.ResourceName()) +} + +func TestEC2Endpoints_MaxBatchSize(t *testing.T) { + r := NewEC2Endpoints() + assert.Equal(t, 49, r.MaxBatchSize()) +} + +func TestListEC2Endpoints(t *testing.T) { t.Parallel() // Set excludeFirstSeenTag to false for testing ctx := context.WithValue(context.Background(), util.ExcludeFirstSeenTagKey, false) - var ( - now = time.Now() - endpoint1 = "vpce-0b201b2dcd4f77a2f001" - endpoint2 = "vpce-0b201b2dcd4f77a2f002" - - testName1 = "cloud-nuke-igw-001" - testName2 = "cloud-nuke-igw-002" - ) - vpcEndpoint := EC2Endpoints{ - Client: mockedEc2VpcEndpoints{ - DescribeVpcEndpointsOutput: ec2.DescribeVpcEndpointsOutput{ - VpcEndpoints: []types.VpcEndpoint{ - { - VpcEndpointId: aws.String(endpoint1), - Tags: []types.Tag{ - { - Key: aws.String("Name"), - Value: aws.String(testName1), - }, { - Key: aws.String(util.FirstSeenTagKey), - Value: aws.String(util.FormatTimestamp(now)), - }, - }, + now := time.Now() + endpoint1 := "vpce-0b201b2dcd4f77a2f001" + endpoint2 := "vpce-0b201b2dcd4f77a2f002" + testName1 := "cloud-nuke-endpoint-001" + testName2 := "cloud-nuke-endpoint-002" + + 
mock := &mockEC2EndpointsClient{ + DescribeVpcEndpointsOutput: ec2.DescribeVpcEndpointsOutput{ + VpcEndpoints: []types.VpcEndpoint{ + { + VpcEndpointId: aws.String(endpoint1), + Tags: []types.Tag{ + {Key: aws.String("Name"), Value: aws.String(testName1)}, + {Key: aws.String(util.FirstSeenTagKey), Value: aws.String(util.FormatTimestamp(now))}, }, - { - VpcEndpointId: aws.String(endpoint2), - Tags: []types.Tag{ - { - Key: aws.String("Name"), - Value: aws.String(testName2), - }, { - Key: aws.String(util.FirstSeenTagKey), - Value: aws.String(util.FormatTimestamp(now.Add(1 * time.Hour))), - }, - }, + }, + { + VpcEndpointId: aws.String(endpoint2), + Tags: []types.Tag{ + {Key: aws.String("Name"), Value: aws.String(testName2)}, + {Key: aws.String(util.FirstSeenTagKey), Value: aws.String(util.FormatTimestamp(now.Add(1 * time.Hour)))}, }, }, }, }, } - tests := map[string]struct { - ctx context.Context - configObj config.ResourceType - expected []string - }{ - "emptyFilter": { - ctx: ctx, - configObj: config.ResourceType{}, - expected: []string{endpoint1, endpoint2}, - }, - "nameExclusionFilter": { - ctx: ctx, - configObj: config.ResourceType{ - ExcludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{{ - RE: *regexp.MustCompile(testName1), - }}}, + ids, err := listEC2Endpoints(ctx, mock, resource.Scope{}, config.ResourceType{}) + require.NoError(t, err) + require.ElementsMatch(t, []string{endpoint1, endpoint2}, aws.ToStringSlice(ids)) +} + +func TestListEC2Endpoints_WithNameExclusionFilter(t *testing.T) { + t.Parallel() + + ctx := context.WithValue(context.Background(), util.ExcludeFirstSeenTagKey, false) + + now := time.Now() + endpoint1 := "vpce-0b201b2dcd4f77a2f001" + endpoint2 := "vpce-0b201b2dcd4f77a2f002" + testName1 := "cloud-nuke-endpoint-001" + testName2 := "cloud-nuke-endpoint-002" + + mock := &mockEC2EndpointsClient{ + DescribeVpcEndpointsOutput: ec2.DescribeVpcEndpointsOutput{ + VpcEndpoints: []types.VpcEndpoint{ + { + VpcEndpointId: aws.String(endpoint1), + Tags: []types.Tag{ + {Key: aws.String("Name"), Value: aws.String(testName1)}, + {Key: aws.String(util.FirstSeenTagKey), Value: aws.String(util.FormatTimestamp(now))}, + }, + }, + { + VpcEndpointId: aws.String(endpoint2), + Tags: []types.Tag{ + {Key: aws.String("Name"), Value: aws.String(testName2)}, + {Key: aws.String(util.FirstSeenTagKey), Value: aws.String(util.FormatTimestamp(now.Add(1 * time.Hour)))}, + }, + }, }, - expected: []string{endpoint2}, - }, - "timeAfterExclusionFilter": { - ctx: ctx, - configObj: config.ResourceType{ - ExcludeRule: config.FilterRule{ - TimeAfter: aws.Time(now), - }}, - expected: []string{endpoint1}, - }, - "timeBeforeExclusionFilter": { - ctx: ctx, - configObj: config.ResourceType{ - ExcludeRule: config.FilterRule{ - TimeBefore: aws.Time(now.Add(1)), - }}, - expected: []string{endpoint2}, }, } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - names, err := vpcEndpoint.getAll(tc.ctx, config.Config{ - EC2Endpoint: tc.configObj, - }) - require.NoError(t, err) - require.Equal(t, tc.expected, aws.ToStringSlice(names)) - }) + + cfg := config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{RE: *regexp.MustCompile(testName1)}}, + }, } + + ids, err := listEC2Endpoints(ctx, mock, resource.Scope{}, cfg) + require.NoError(t, err) + require.Equal(t, []string{endpoint2}, aws.ToStringSlice(ids)) } -func TestEc2Endpoints_NukeAll(t *testing.T) { +func TestListEC2Endpoints_WithTimeAfterExclusionFilter(t *testing.T) { t.Parallel() - var ( - endpoint1 = 
"vpce-0b201b2dcd4f77a2f001" - endpoint2 = "vpce-0b201b2dcd4f77a2f002" - ) - - igw := EC2Endpoints{ - BaseAwsResource: BaseAwsResource{ - Nukables: map[string]error{ - endpoint1: nil, - endpoint2: nil, - }, - }, - Client: mockedEc2VpcEndpoints{ - DescribeVpcEndpointsOutput: ec2.DescribeVpcEndpointsOutput{ - VpcEndpoints: []types.VpcEndpoint{ - { - VpcEndpointId: aws.String(endpoint1), + + ctx := context.WithValue(context.Background(), util.ExcludeFirstSeenTagKey, false) + + now := time.Now() + endpoint1 := "vpce-0b201b2dcd4f77a2f001" + endpoint2 := "vpce-0b201b2dcd4f77a2f002" + testName1 := "cloud-nuke-endpoint-001" + testName2 := "cloud-nuke-endpoint-002" + + mock := &mockEC2EndpointsClient{ + DescribeVpcEndpointsOutput: ec2.DescribeVpcEndpointsOutput{ + VpcEndpoints: []types.VpcEndpoint{ + { + VpcEndpointId: aws.String(endpoint1), + Tags: []types.Tag{ + {Key: aws.String("Name"), Value: aws.String(testName1)}, + {Key: aws.String(util.FirstSeenTagKey), Value: aws.String(util.FormatTimestamp(now))}, }, - { - VpcEndpointId: aws.String(endpoint2), + }, + { + VpcEndpointId: aws.String(endpoint2), + Tags: []types.Tag{ + {Key: aws.String("Name"), Value: aws.String(testName2)}, + {Key: aws.String(util.FirstSeenTagKey), Value: aws.String(util.FormatTimestamp(now.Add(1 * time.Hour)))}, }, }, }, - DeleteVpcEndpointsOutput: ec2.DeleteVpcEndpointsOutput{}, }, } - err := igw.nukeAll([]*string{ - aws.String(endpoint1), - aws.String(endpoint2), - }) + cfg := config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now), + }, + } + + ids, err := listEC2Endpoints(ctx, mock, resource.Scope{}, cfg) + require.NoError(t, err) + require.Equal(t, []string{endpoint1}, aws.ToStringSlice(ids)) +} + +func TestDeleteEC2Endpoint(t *testing.T) { + t.Parallel() + + mock := &mockEC2EndpointsClient{} + err := deleteEC2Endpoint(context.Background(), mock, aws.String("vpce-12345")) require.NoError(t, err) } diff --git a/aws/resources/ec2_endpoints_types.go b/aws/resources/ec2_endpoints_types.go deleted file mode 100644 index 25a937fd..00000000 --- a/aws/resources/ec2_endpoints_types.go +++ /dev/null @@ -1,64 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type EC2EndpointsAPI interface { - DescribeVpcEndpoints(ctx context.Context, params *ec2.DescribeVpcEndpointsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcEndpointsOutput, error) - DeleteVpcEndpoints(ctx context.Context, params *ec2.DeleteVpcEndpointsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVpcEndpointsOutput, error) -} - -// EC2Endpoints - represents all ec2 endpoints -type EC2Endpoints struct { - BaseAwsResource - Client EC2EndpointsAPI - Region string - Endpoints []string -} - -func (e *EC2Endpoints) Init(cfg aws.Config) { - e.Client = ec2.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (e *EC2Endpoints) ResourceName() string { - return "ec2-endpoint" -} - -func (e *EC2Endpoints) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (e *EC2Endpoints) ResourceIdentifiers() []string { - return e.Endpoints -} - -func (e *EC2Endpoints) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.EC2Endpoint -} - -func (e *EC2Endpoints) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := 
e.getAll(c, configObj) - if err != nil { - return nil, err - } - - e.Endpoints = aws.ToStringSlice(identifiers) - return e.Endpoints, nil -} - -// Nuke - nuke 'em all!!! -func (e *EC2Endpoints) Nuke(ctx context.Context, identifiers []string) error { - if err := e.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/ec2_test.go b/aws/resources/ec2_test.go index 34ce7328..55388d3b 100644 --- a/aws/resources/ec2_test.go +++ b/aws/resources/ec2_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) @@ -45,53 +46,62 @@ func (m mockedEC2Instances) ReleaseAddress(ctx context.Context, params *ec2.Rele return &m.ReleaseAddressOutput, nil } -func TestEc2Instances_GetAll(t *testing.T) { +func TestEC2Instances_ResourceName(t *testing.T) { + r := NewEC2Instances() + require.Equal(t, "ec2", r.ResourceName()) +} + +func TestEC2Instances_MaxBatchSize(t *testing.T) { + r := NewEC2Instances() + require.Equal(t, 49, r.MaxBatchSize()) +} + +func TestListEC2Instances(t *testing.T) { t.Parallel() testId1 := "testId1" testId2 := "testId2" testName1 := "testName1" testName2 := "testName2" now := time.Now() - ei := EC2Instances{ - Client: mockedEC2Instances{ - DescribeInstancesOutput: ec2.DescribeInstancesOutput{ - Reservations: []types.Reservation{ - { - Instances: []types.Instance{ - { - InstanceId: aws.String(testId1), - Tags: []types.Tag{ - { - Key: aws.String("Name"), - Value: aws.String(testName1), - }, + + mock := mockedEC2Instances{ + DescribeInstancesOutput: ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + { + Instances: []types.Instance{ + { + InstanceId: aws.String(testId1), + Tags: []types.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(testName1), }, - LaunchTime: aws.Time(now), }, - { - InstanceId: aws.String(testId2), - Tags: []types.Tag{ - { - Key: aws.String("Name"), - Value: aws.String(testName2), - }, + LaunchTime: aws.Time(now), + }, + { + InstanceId: aws.String(testId2), + Tags: []types.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(testName2), }, - LaunchTime: aws.Time(now.Add(1)), }, + LaunchTime: aws.Time(now.Add(1)), }, }, }, }, - DescribeInstanceAttributeOutput: map[string]ec2.DescribeInstanceAttributeOutput{ - testId1: { - DisableApiTermination: &types.AttributeBooleanValue{ - Value: aws.Bool(false), - }, + }, + DescribeInstanceAttributeOutput: map[string]ec2.DescribeInstanceAttributeOutput{ + testId1: { + DisableApiTermination: &types.AttributeBooleanValue{ + Value: aws.Bool(false), }, - testId2: { - DisableApiTermination: &types.AttributeBooleanValue{ - Value: aws.Bool(false), - }, + }, + testId2: { + DisableApiTermination: &types.AttributeBooleanValue{ + Value: aws.Bool(false), }, }, }, @@ -124,32 +134,25 @@ func TestEc2Instances_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := ei.getAll(context.Background(), config.Config{ - EC2: tc.configObj, - }) + names, err := listEC2Instances(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } } -func TestEc2Instances_NukeAll(t *testing.T) { +func TestDeleteEC2Instances(t *testing.T) { t.Parallel() - ei := EC2Instances{ - BaseAwsResource: BaseAwsResource{ - Context: 
context.Background(), - Timeout: DefaultWaitTimeout, - }, - Client: mockedEC2Instances{ - DescribeInstancesOutput: ec2.DescribeInstancesOutput{ - Reservations: []types.Reservation{ - { - Instances: []types.Instance{ - { - InstanceId: aws.String("testId1"), - State: &types.InstanceState{ - Name: types.InstanceStateNameTerminated, - }, + + mock := mockedEC2Instances{ + DescribeInstancesOutput: ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + { + Instances: []types.Instance{ + { + InstanceId: aws.String("testId1"), + State: &types.InstanceState{ + Name: types.InstanceStateNameTerminated, }, }, }, @@ -158,44 +161,39 @@ func TestEc2Instances_NukeAll(t *testing.T) { }, } - err := ei.nukeAll([]*string{aws.String("testId1")}) + err := deleteEC2Instances(context.Background(), mock, resource.Scope{Region: "us-east-1"}, "ec2", []*string{aws.String("testId1")}) require.NoError(t, err) } -func TestEc2InstancesWithEIP_NukeAll(t *testing.T) { +func TestDeleteEC2InstancesWithEIP(t *testing.T) { t.Parallel() - ei := EC2Instances{ - BaseAwsResource: BaseAwsResource{ - Context: context.Background(), - Timeout: DefaultWaitTimeout, - }, - Client: mockedEC2Instances{ - DescribeInstancesOutput: ec2.DescribeInstancesOutput{ - Reservations: []types.Reservation{ - { - Instances: []types.Instance{ - { - InstanceId: aws.String("testId1"), - State: &types.InstanceState{ - Name: types.InstanceStateNameTerminated, - }, + + mock := mockedEC2Instances{ + DescribeInstancesOutput: ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + { + Instances: []types.Instance{ + { + InstanceId: aws.String("testId1"), + State: &types.InstanceState{ + Name: types.InstanceStateNameTerminated, }, }, }, }, }, - TerminateInstancesOutput: ec2.TerminateInstancesOutput{}, - DescribeAddressesOutput: ec2.DescribeAddressesOutput{ - Addresses: []types.Address{ - { - AllocationId: aws.String("alloc-test-id1"), - InstanceId: aws.String("testId1"), - }, + }, + TerminateInstancesOutput: ec2.TerminateInstancesOutput{}, + DescribeAddressesOutput: ec2.DescribeAddressesOutput{ + Addresses: []types.Address{ + { + AllocationId: aws.String("alloc-test-id1"), + InstanceId: aws.String("testId1"), }, }, }, } - err := ei.nukeAll([]*string{aws.String("testId1")}) + err := deleteEC2Instances(context.Background(), mock, resource.Scope{Region: "us-east-1"}, "ec2", []*string{aws.String("testId1")}) require.NoError(t, err) } diff --git a/aws/resources/ec2_types.go b/aws/resources/ec2_types.go deleted file mode 100644 index 14ed5bc8..00000000 --- a/aws/resources/ec2_types.go +++ /dev/null @@ -1,68 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type EC2InstancesAPI interface { - DescribeInstanceAttribute(ctx context.Context, params *ec2.DescribeInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceAttributeOutput, error) - DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) - DescribeAddresses(ctx context.Context, params *ec2.DescribeAddressesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeAddressesOutput, error) - ReleaseAddress(ctx context.Context, params *ec2.ReleaseAddressInput, optFns ...func(*ec2.Options)) (*ec2.ReleaseAddressOutput, error) - TerminateInstances(ctx context.Context, params *ec2.TerminateInstancesInput, optFns 
...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) -} - -// EC2Instances - represents all ec2 instances -type EC2Instances struct { - BaseAwsResource - Client EC2InstancesAPI - Region string - InstanceIds []string -} - -func (ei *EC2Instances) Init(cfg aws.Config) { - ei.Client = ec2.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (ei *EC2Instances) ResourceName() string { - return "ec2" -} - -// ResourceIdentifiers - The instance ids of the ec2 instances -func (ei *EC2Instances) ResourceIdentifiers() []string { - return ei.InstanceIds -} - -func (ei *EC2Instances) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (ei *EC2Instances) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.EC2 -} - -func (ei *EC2Instances) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := ei.getAll(c, configObj) - if err != nil { - return nil, err - } - - ei.InstanceIds = aws.ToStringSlice(identifiers) - return ei.InstanceIds, nil -} - -// Nuke - nuke 'em all!!! -func (ei *EC2Instances) Nuke(ctx context.Context, identifiers []string) error { - if err := ei.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/ecs_cluster.go b/aws/resources/ecs_cluster.go index 2280ce09..b977af78 100644 --- a/aws/resources/ecs_cluster.go +++ b/aws/resources/ecs_cluster.go @@ -9,11 +9,22 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/gruntwork-io/go-commons/errors" ) +// ECSClustersAPI defines the interface for ECS Clusters operations. +type ECSClustersAPI interface { + DescribeClusters(ctx context.Context, params *ecs.DescribeClustersInput, optFns ...func(*ecs.Options)) (*ecs.DescribeClustersOutput, error) + DeleteCluster(ctx context.Context, params *ecs.DeleteClusterInput, optFns ...func(*ecs.Options)) (*ecs.DeleteClusterOutput, error) + ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) + ListTagsForResource(ctx context.Context, params *ecs.ListTagsForResourceInput, optFns ...func(*ecs.Options)) (*ecs.ListTagsForResourceOutput, error) + ListTasks(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) + StopTask(ctx context.Context, params *ecs.StopTaskInput, optFns ...func(*ecs.Options)) (*ecs.StopTaskOutput, error) + TagResource(ctx context.Context, params *ecs.TagResourceInput, optFns ...func(*ecs.Options)) (*ecs.TagResourceOutput, error) +} + // Used in this context to determine if the ECS Cluster is ready to be used & tagged // For more details on other valid status values: https://docs.aws.amazon.com/sdk-for-go/api/service/ecs/#Cluster const activeEcsClusterStatus string = "ACTIVE" @@ -22,37 +33,37 @@ const activeEcsClusterStatus string = "ACTIVE" // For more details on this, please read here: https://docs.aws.amazon.com/cli/latest/reference/ecs/describe-clusters.html#options const describeClustersRequestBatchSize = 100 -// getAllEcsClusters returns all ECS Cluster ARNs. -// Handles pagination until all pages are retrieved. 
-func (clusters *ECSClusters) getAllEcsClusters() ([]*string, error) { - var clusterArns []string - nextToken := (*string)(nil) - - for { - resp, err := clusters.Client.ListClusters(clusters.Context, &ecs.ListClustersInput{ - NextToken: nextToken, - }) - if err != nil { - return nil, errors.WithStackTrace(err) - } - - clusterArns = append(clusterArns, resp.ClusterArns...) - if resp.NextToken == nil || *resp.NextToken == "" { - break - } - nextToken = resp.NextToken - } - - return aws.StringSlice(clusterArns), nil +// NewECSClusters creates a new ECS Clusters resource using the generic resource pattern. +func NewECSClusters() AwsResource { + return NewAwsResource(&resource.Resource[ECSClustersAPI]{ + ResourceTypeName: "ecscluster", + BatchSize: maxBatchSize, + InitClient: func(r *resource.Resource[ECSClustersAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for ECS client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = ecs.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.ECSCluster + }, + Lister: listECSClusters, + Nuker: resource.MultiStepDeleter(stopClusterRunningTasks, deleteECSCluster), + }) } -func (clusters *ECSClusters) getAll(c context.Context, configObj config.Config) ([]*string, error) { - allClusters, err := clusters.getAllEcsClusters() +// listECSClusters retrieves all ECS clusters that match the config filters. +func listECSClusters(ctx context.Context, client ECSClustersAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + // Get all cluster ARNs + allClusters, err := getAllEcsClusters(ctx, client) if err != nil { return nil, errors.WithStackTrace(err) } - excludeFirstSeenTag, err := util.GetBoolFromContext(c, util.ExcludeFirstSeenTagKey) + excludeFirstSeenTag, err := util.GetBoolFromContext(ctx, util.ExcludeFirstSeenTagKey) if err != nil { return nil, errors.WithStackTrace(err) } @@ -62,7 +73,7 @@ func (clusters *ECSClusters) getAll(c context.Context, configObj config.Config) batches := util.Split(clusterList, describeClustersRequestBatchSize) for _, batch := range batches { - resp, err := clusters.Client.DescribeClusters(clusters.Context, &ecs.DescribeClustersInput{ + resp, err := client.DescribeClusters(ctx, &ecs.DescribeClustersInput{ Clusters: batch, }) if err != nil { @@ -75,12 +86,12 @@ func (clusters *ECSClusters) getAll(c context.Context, configObj config.Config) } // Get all tags for the cluster for filtering purposes - tags, err := clusters.getAllTags(cluster.ClusterArn) + tags, err := getAllEcsClusterTags(ctx, client, cluster.ClusterArn) if err != nil { return nil, errors.WithStackTrace(err) } - if !configObj.ECSCluster.ShouldInclude(config.ResourceValue{ + if !cfg.ShouldInclude(config.ResourceValue{ Name: cluster.ClusterName, Tags: tags, }) { @@ -92,19 +103,19 @@ func (clusters *ECSClusters) getAll(c context.Context, configObj config.Config) continue } - firstSeenTime, err := clusters.getFirstSeenTag(cluster.ClusterArn) + firstSeenTime, err := getEcsClusterFirstSeenTag(ctx, client, cluster.ClusterArn) if err != nil { return nil, errors.WithStackTrace(err) } if firstSeenTime == nil { - if err := clusters.setFirstSeenTag(cluster.ClusterArn, time.Now().UTC()); err != nil { + if err := setEcsClusterFirstSeenTag(ctx, client, cluster.ClusterArn, time.Now().UTC()); err != nil { return nil, errors.WithStackTrace(err) } continue } - if configObj.ECSCluster.ShouldInclude(config.ResourceValue{ + if 
cfg.ShouldInclude(config.ResourceValue{ Time: firstSeenTime, Name: cluster.ClusterName, Tags: tags, @@ -117,111 +128,37 @@ func (clusters *ECSClusters) getAll(c context.Context, configObj config.Config) return result, nil } -func (clusters *ECSClusters) stopClusterRunningTasks(clusterArn *string) error { - logging.Debugf("[TASK] stopping tasks running on cluster %v", *clusterArn) - // before deleting the cluster, remove the active tasks on that cluster - runningTasks, err := clusters.Client.ListTasks(clusters.Context, &ecs.ListTasksInput{ - Cluster: clusterArn, - DesiredStatus: types.DesiredStatusRunning, - }) - - if err != nil { - return errors.WithStackTrace(err) - } +// getAllEcsClusters returns all ECS Cluster ARNs. +// Handles pagination until all pages are retrieved. +func getAllEcsClusters(ctx context.Context, client ECSClustersAPI) ([]*string, error) { + var clusterArns []string + nextToken := (*string)(nil) - // stop the listed tasks - for _, task := range runningTasks.TaskArns { - _, err := clusters.Client.StopTask(clusters.Context, &ecs.StopTaskInput{ - Cluster: clusterArn, - Task: aws.String(task), - Reason: aws.String("Terminating task due to cluster deletion"), + for { + resp, err := client.ListClusters(ctx, &ecs.ListClustersInput{ + NextToken: nextToken, }) if err != nil { - logging.Debugf("[TASK] Unable to stop the task %s on cluster %s. Reason: %v", task, *clusterArn, err) - return errors.WithStackTrace(err) - } - logging.Debugf("[TASK] Success, stopped task %v", task) - } - return nil -} - -func (clusters *ECSClusters) nukeAll(ecsClusterArns []*string) error { - numNuking := len(ecsClusterArns) - - if numNuking == 0 { - logging.Debugf("No ECS clusters to nuke in region %s", clusters.Region) - return nil - } - - logging.Debugf("Deleting %d ECS clusters in region %s", numNuking, clusters.Region) - - var nukedEcsClusters []*string - for _, clusterArn := range ecsClusterArns { - - // before nuking the clusters, do check active tasks on the cluster and stop all of them - err := clusters.stopClusterRunningTasks(clusterArn) - if err != nil { - logging.Debugf("Error, unable to stop the running stasks on the cluster %s %s", aws.ToString(clusterArn), err) - return errors.WithStackTrace(err) - } - - params := &ecs.DeleteClusterInput{ - Cluster: clusterArn, - } - _, err = clusters.Client.DeleteCluster(clusters.Context, params) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(clusterArn), - ResourceType: "ECS Cluster", - Error: err, + return nil, errors.WithStackTrace(err) } - report.Record(e) - if err != nil { - logging.Debugf("Error, failed to delete cluster with ARN %s %s", aws.ToString(clusterArn), err) - return errors.WithStackTrace(err) + clusterArns = append(clusterArns, resp.ClusterArns...) 
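+		// An empty or nil NextToken indicates ListClusters has returned the final page.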
+ if resp.NextToken == nil || *resp.NextToken == "" { + break } - - logging.Debugf("Success, deleted cluster: %s", aws.ToString(clusterArn)) - nukedEcsClusters = append(nukedEcsClusters, clusterArn) - } - - numNuked := len(nukedEcsClusters) - logging.Debugf("[OK] %d of %d ECS cluster(s) deleted in %s", numNuked, numNuking, clusters.Region) - - return nil -} - -// Tag an ECS cluster identified by the given cluster ARN when it's first seen by cloud-nuke -func (clusters *ECSClusters) setFirstSeenTag(clusterArn *string, timestamp time.Time) error { - firstSeenTime := util.FormatTimestamp(timestamp) - - input := &ecs.TagResourceInput{ - ResourceArn: clusterArn, - Tags: []types.Tag{ - { - Key: aws.String(firstSeenTagKey), - Value: aws.String(firstSeenTime), - }, - }, - } - - _, err := clusters.Client.TagResource(clusters.Context, input) - if err != nil { - return errors.WithStackTrace(err) + nextToken = resp.NextToken } - return nil + return aws.StringSlice(clusterArns), nil } -// getAllTags retrieves all tags for a given ECS cluster and returns them as a map -func (clusters *ECSClusters) getAllTags(clusterArn *string) (map[string]string, error) { +// getAllEcsClusterTags retrieves all tags for a given ECS cluster and returns them as a map +func getAllEcsClusterTags(ctx context.Context, client ECSClustersAPI, clusterArn *string) (map[string]string, error) { input := &ecs.ListTagsForResourceInput{ ResourceArn: clusterArn, } - clusterTags, err := clusters.Client.ListTagsForResource(clusters.Context, input) + clusterTags, err := client.ListTagsForResource(ctx, input) if err != nil { logging.Debugf("Error getting the tags for ECS cluster with ARN %s", aws.ToString(clusterArn)) return nil, errors.WithStackTrace(err) @@ -237,15 +174,15 @@ func (clusters *ECSClusters) getAllTags(clusterArn *string) (map[string]string, return tags, nil } -// Get the `cloud-nuke-first-seen` tag value for a given ECS cluster -func (clusters *ECSClusters) getFirstSeenTag(clusterArn *string) (*time.Time, error) { +// getEcsClusterFirstSeenTag gets the `cloud-nuke-first-seen` tag value for a given ECS cluster +func getEcsClusterFirstSeenTag(ctx context.Context, client ECSClustersAPI, clusterArn *string) (*time.Time, error) { var firstSeenTime *time.Time input := &ecs.ListTagsForResourceInput{ ResourceArn: clusterArn, } - clusterTags, err := clusters.Client.ListTagsForResource(clusters.Context, input) + clusterTags, err := client.ListTagsForResource(ctx, input) if err != nil { logging.Debugf("Error getting the tags for ECS cluster with ARN %s", aws.ToString(clusterArn)) return firstSeenTime, errors.WithStackTrace(err) @@ -253,8 +190,7 @@ func (clusters *ECSClusters) getFirstSeenTag(clusterArn *string) (*time.Time, er for _, tag := range clusterTags.Tags { if util.IsFirstSeenTag(tag.Key) { - - firstSeenTime, err := util.ParseTimestamp(tag.Value) + firstSeenTime, err = util.ParseTimestamp(tag.Value) if err != nil { logging.Debugf("Error parsing the `cloud-nuke-first-seen` tag for ECS cluster with ARN %s", aws.ToString(clusterArn)) return firstSeenTime, errors.WithStackTrace(err) @@ -266,3 +202,62 @@ func (clusters *ECSClusters) getFirstSeenTag(clusterArn *string) (*time.Time, er return firstSeenTime, nil } + +// setEcsClusterFirstSeenTag tags an ECS cluster with the first seen timestamp +func setEcsClusterFirstSeenTag(ctx context.Context, client ECSClustersAPI, clusterArn *string, timestamp time.Time) error { + firstSeenTime := util.FormatTimestamp(timestamp) + + input := &ecs.TagResourceInput{ + ResourceArn: clusterArn, + 
Tags: []types.Tag{ + { + Key: aws.String(firstSeenTagKey), + Value: aws.String(firstSeenTime), + }, + }, + } + + _, err := client.TagResource(ctx, input) + if err != nil { + return errors.WithStackTrace(err) + } + + return nil +} + +// stopClusterRunningTasks stops all running tasks on a cluster +func stopClusterRunningTasks(ctx context.Context, client ECSClustersAPI, clusterArn *string) error { + logging.Debugf("[TASK] stopping tasks running on cluster %v", *clusterArn) + // before deleting the cluster, remove the active tasks on that cluster + runningTasks, err := client.ListTasks(ctx, &ecs.ListTasksInput{ + Cluster: clusterArn, + DesiredStatus: types.DesiredStatusRunning, + }) + + if err != nil { + return errors.WithStackTrace(err) + } + + // stop the listed tasks + for _, task := range runningTasks.TaskArns { + _, err := client.StopTask(ctx, &ecs.StopTaskInput{ + Cluster: clusterArn, + Task: aws.String(task), + Reason: aws.String("Terminating task due to cluster deletion"), + }) + if err != nil { + logging.Debugf("[TASK] Unable to stop the task %s on cluster %s. Reason: %v", task, *clusterArn, err) + return errors.WithStackTrace(err) + } + logging.Debugf("[TASK] Success, stopped task %v", task) + } + return nil +} + +// deleteECSCluster deletes a single ECS cluster +func deleteECSCluster(ctx context.Context, client ECSClustersAPI, clusterArn *string) error { + _, err := client.DeleteCluster(ctx, &ecs.DeleteClusterInput{ + Cluster: clusterArn, + }) + return err +} diff --git a/aws/resources/ecs_cluster_test.go b/aws/resources/ecs_cluster_test.go index 1b98ab49..b20b397d 100644 --- a/aws/resources/ecs_cluster_test.go +++ b/aws/resources/ecs_cluster_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/stretchr/testify/require" ) @@ -92,7 +93,17 @@ func (m mockedEC2Cluster) TagResource(ctx context.Context, params *ecs.TagResour return &m.TagResourceOutput, nil } -func TestEC2Cluster_GetAll(t *testing.T) { +func TestECSClusters_ResourceName(t *testing.T) { + r := NewECSClusters() + require.Equal(t, "ecscluster", r.ResourceName()) +} + +func TestECSClusters_MaxBatchSize(t *testing.T) { + r := NewECSClusters() + require.Equal(t, maxBatchSize, r.MaxBatchSize()) +} + +func TestListECSClusters(t *testing.T) { t.Parallel() // Set excludeFirstSeenTag to false for testing ctx := context.WithValue(context.Background(), util.ExcludeFirstSeenTagKey, false) @@ -102,43 +113,42 @@ func TestEC2Cluster_GetAll(t *testing.T) { testName1 := "cluster1" testName2 := "cluster2" now := time.Now() - ec := ECSClusters{ - Client: mockedEC2Cluster{ - ListClustersOutput: ecs.ListClustersOutput{ - ClusterArns: []string{ - testArn1, - testArn2, - }, + + mock := mockedEC2Cluster{ + ListClustersOutput: ecs.ListClustersOutput{ + ClusterArns: []string{ + testArn1, + testArn2, }, + }, - DescribeClustersOutput: ecs.DescribeClustersOutput{ - Clusters: []types.Cluster{ - { - ClusterArn: aws.String(testArn1), - Status: aws.String("ACTIVE"), - ClusterName: aws.String(testName1), - }, - { - ClusterArn: aws.String(testArn2), - Status: aws.String("ACTIVE"), - ClusterName: aws.String(testName2), - }, + DescribeClustersOutput: ecs.DescribeClustersOutput{ + Clusters: []types.Cluster{ + { + ClusterArn: aws.String(testArn1), + Status: aws.String("ACTIVE"), + ClusterName: aws.String(testName1), + }, + { + ClusterArn: 
aws.String(testArn2), + Status: aws.String("ACTIVE"), + ClusterName: aws.String(testName2), }, }, + }, - ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{ - Tags: []types.Tag{ - { - Key: aws.String(util.FirstSeenTagKey), - Value: aws.String(util.FormatTimestamp(now)), - }, + ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{ + Tags: []types.Tag{ + { + Key: aws.String(util.FirstSeenTagKey), + Value: aws.String(util.FormatTimestamp(now)), }, }, - ListTasksOutput: ecs.ListTasksOutput{ - TaskArns: []string{ - "task-arn-001", - "task-arn-002", - }, + }, + ListTasksOutput: ecs.ListTasksOutput{ + TaskArns: []string{ + "task-arn-001", + "task-arn-002", }, }, } @@ -294,16 +304,14 @@ func TestEC2Cluster_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := ec.getAll(tc.ctx, config.Config{ - ECSCluster: tc.configObj, - }) + names, err := listECSClusters(tc.ctx, mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } } -func TestEC2Cluster_GetAll_InactiveClusters(t *testing.T) { +func TestListECSClusters_InactiveClusters(t *testing.T) { t.Parallel() ctx := context.WithValue(context.Background(), util.ExcludeFirstSeenTagKey, false) @@ -311,134 +319,154 @@ func TestEC2Cluster_GetAll_InactiveClusters(t *testing.T) { testArn2 := "arn:aws:ecs:us-east-1:123456789012:cluster/active1" now := time.Now() - ec := ECSClusters{ - Client: mockedEC2Cluster{ - ListClustersOutput: ecs.ListClustersOutput{ - ClusterArns: []string{testArn1, testArn2}, - }, - DescribeClustersOutput: ecs.DescribeClustersOutput{ - Clusters: []types.Cluster{ - { - ClusterArn: aws.String(testArn1), - Status: aws.String("INACTIVE"), - ClusterName: aws.String("inactive1"), - }, - { - ClusterArn: aws.String(testArn2), - Status: aws.String("ACTIVE"), - ClusterName: aws.String("active1"), - }, + mock := mockedEC2Cluster{ + ListClustersOutput: ecs.ListClustersOutput{ + ClusterArns: []string{testArn1, testArn2}, + }, + DescribeClustersOutput: ecs.DescribeClustersOutput{ + Clusters: []types.Cluster{ + { + ClusterArn: aws.String(testArn1), + Status: aws.String("INACTIVE"), + ClusterName: aws.String("inactive1"), + }, + { + ClusterArn: aws.String(testArn2), + Status: aws.String("ACTIVE"), + ClusterName: aws.String("active1"), }, }, - ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{ - Tags: []types.Tag{ - { - Key: aws.String(util.FirstSeenTagKey), - Value: aws.String(util.FormatTimestamp(now)), - }, + }, + ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{ + Tags: []types.Tag{ + { + Key: aws.String(util.FirstSeenTagKey), + Value: aws.String(util.FormatTimestamp(now)), }, }, }, } - names, err := ec.getAll(ctx, config.Config{ECSCluster: config.ResourceType{}}) + names, err := listECSClusters(ctx, mock, resource.Scope{}, config.ResourceType{}) require.NoError(t, err) // Only active cluster should be returned require.Equal(t, []string{testArn2}, aws.ToStringSlice(names)) } -func TestEC2Cluster_GetAll_NoFirstSeenTag(t *testing.T) { +func TestListECSClusters_NoFirstSeenTag(t *testing.T) { t.Parallel() ctx := context.WithValue(context.Background(), util.ExcludeFirstSeenTagKey, false) testArn := "arn:aws:ecs:us-east-1:123456789012:cluster/new-cluster" - ec := ECSClusters{ - Client: mockedEC2Cluster{ - ListClustersOutput: ecs.ListClustersOutput{ - ClusterArns: []string{testArn}, - }, - DescribeClustersOutput: ecs.DescribeClustersOutput{ - Clusters: []types.Cluster{ - { - ClusterArn: aws.String(testArn), - Status: 
aws.String("ACTIVE"), - ClusterName: aws.String("new-cluster"), - }, + mock := mockedEC2Cluster{ + ListClustersOutput: ecs.ListClustersOutput{ + ClusterArns: []string{testArn}, + }, + DescribeClustersOutput: ecs.DescribeClustersOutput{ + Clusters: []types.Cluster{ + { + ClusterArn: aws.String(testArn), + Status: aws.String("ACTIVE"), + ClusterName: aws.String("new-cluster"), }, }, - ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{ - Tags: []types.Tag{}, // No tags - }, - TagResourceOutput: ecs.TagResourceOutput{}, }, + ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{ + Tags: []types.Tag{}, // No tags + }, + TagResourceOutput: ecs.TagResourceOutput{}, } - names, err := ec.getAll(ctx, config.Config{ECSCluster: config.ResourceType{}}) + names, err := listECSClusters(ctx, mock, resource.Scope{}, config.ResourceType{}) require.NoError(t, err) // Should return empty since cluster gets tagged but not included until next run require.Empty(t, names) } -func TestEC2Cluster_GetAll_EmptyList(t *testing.T) { +func TestListECSClusters_EmptyList(t *testing.T) { t.Parallel() ctx := context.WithValue(context.Background(), util.ExcludeFirstSeenTagKey, false) - ec := ECSClusters{ - Client: mockedEC2Cluster{ - ListClustersOutput: ecs.ListClustersOutput{ClusterArns: []string{}}, - DescribeClustersOutput: ecs.DescribeClustersOutput{Clusters: []types.Cluster{}}, - ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{Tags: []types.Tag{}}, - }, + mock := mockedEC2Cluster{ + ListClustersOutput: ecs.ListClustersOutput{ClusterArns: []string{}}, + DescribeClustersOutput: ecs.DescribeClustersOutput{Clusters: []types.Cluster{}}, + ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{Tags: []types.Tag{}}, } - names, err := ec.getAll(ctx, config.Config{ECSCluster: config.ResourceType{}}) + names, err := listECSClusters(ctx, mock, resource.Scope{}, config.ResourceType{}) require.NoError(t, err) require.Empty(t, names) } -func TestEC2Cluster_NukeAll(t *testing.T) { +func TestDeleteECSCluster(t *testing.T) { t.Parallel() - ec := ECSClusters{ - Client: mockedEC2Cluster{ - DeleteClusterOutput: ecs.DeleteClusterOutput{}, - ListTasksOutput: ecs.ListTasksOutput{TaskArns: []string{}}, - }, + + mock := mockedEC2Cluster{ + DeleteClusterOutput: ecs.DeleteClusterOutput{}, } - err := ec.nukeAll([]*string{aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster1")}) + err := deleteECSCluster(context.Background(), mock, aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster1")) require.NoError(t, err) } -func TestEC2Cluster_NukeAll_EmptyList(t *testing.T) { +func TestStopClusterRunningTasks(t *testing.T) { t.Parallel() - ec := ECSClusters{ - Client: mockedEC2Cluster{}, + + mock := mockedEC2Cluster{ + ListTasksOutput: ecs.ListTasksOutput{TaskArns: []string{}}, } - err := ec.nukeAll([]*string{}) + err := stopClusterRunningTasks(context.Background(), mock, aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster1")) require.NoError(t, err) } -func TestEC2ClusterWithTasks_NukeAll(t *testing.T) { +func TestStopClusterRunningTasksWithTasks(t *testing.T) { t.Parallel() - ec := ECSClusters{ - Client: mockedEC2Cluster{ - DeleteClusterOutput: ecs.DeleteClusterOutput{}, - ListTasksOutput: ecs.ListTasksOutput{ - TaskArns: []string{ - "task-arn-001", - "task-arn-002", - }, + + mock := mockedEC2Cluster{ + ListTasksOutput: ecs.ListTasksOutput{ + TaskArns: []string{ + "task-arn-001", + "task-arn-002", + }, + }, + StopTaskOutput: ecs.StopTaskOutput{}, + } + + err := 
stopClusterRunningTasks(context.Background(), mock, aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster1")) + require.NoError(t, err) +} + +func TestECSClustersMultiStepDeleter(t *testing.T) { + t.Parallel() + + mock := mockedEC2Cluster{ + DeleteClusterOutput: ecs.DeleteClusterOutput{}, + ListTasksOutput: ecs.ListTasksOutput{ + TaskArns: []string{ + "task-arn-001", + "task-arn-002", }, - StopTaskOutput: ecs.StopTaskOutput{}, }, + StopTaskOutput: ecs.StopTaskOutput{}, } - err := ec.nukeAll([]*string{aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster1")}) + nuker := resource.MultiStepDeleter(stopClusterRunningTasks, deleteECSCluster) + err := nuker(context.Background(), mock, resource.Scope{Region: "us-east-1"}, "ecscluster", []*string{aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster1")}) + require.NoError(t, err) +} + +func TestECSClustersMultiStepDeleter_EmptyList(t *testing.T) { + t.Parallel() + + mock := mockedEC2Cluster{} + + nuker := resource.MultiStepDeleter(stopClusterRunningTasks, deleteECSCluster) + err := nuker(context.Background(), mock, resource.Scope{Region: "us-east-1"}, "ecscluster", []*string{}) require.NoError(t, err) } -func TestEC2Cluster_GetAll_MultipleRegexPatterns(t *testing.T) { +func TestListECSClusters_MultipleRegexPatterns(t *testing.T) { t.Parallel() ctx := context.WithValue(context.Background(), util.ExcludeFirstSeenTagKey, false) @@ -447,36 +475,34 @@ func TestEC2Cluster_GetAll_MultipleRegexPatterns(t *testing.T) { testArn3 := "arn:aws:ecs:us-east-1:123456789012:cluster/test-service" now := time.Now() - ec := ECSClusters{ - Client: mockedEC2Cluster{ - ListClustersOutput: ecs.ListClustersOutput{ - ClusterArns: []string{testArn1, testArn2, testArn3}, - }, - DescribeClustersOutput: ecs.DescribeClustersOutput{ - Clusters: []types.Cluster{ - { - ClusterArn: aws.String(testArn1), - Status: aws.String("ACTIVE"), - ClusterName: aws.String("prod-cluster"), - }, - { - ClusterArn: aws.String(testArn2), - Status: aws.String("ACTIVE"), - ClusterName: aws.String("dev-cluster"), - }, - { - ClusterArn: aws.String(testArn3), - Status: aws.String("ACTIVE"), - ClusterName: aws.String("test-service"), - }, + mock := mockedEC2Cluster{ + ListClustersOutput: ecs.ListClustersOutput{ + ClusterArns: []string{testArn1, testArn2, testArn3}, + }, + DescribeClustersOutput: ecs.DescribeClustersOutput{ + Clusters: []types.Cluster{ + { + ClusterArn: aws.String(testArn1), + Status: aws.String("ACTIVE"), + ClusterName: aws.String("prod-cluster"), + }, + { + ClusterArn: aws.String(testArn2), + Status: aws.String("ACTIVE"), + ClusterName: aws.String("dev-cluster"), + }, + { + ClusterArn: aws.String(testArn3), + Status: aws.String("ACTIVE"), + ClusterName: aws.String("test-service"), }, }, - ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{ - Tags: []types.Tag{ - { - Key: aws.String(util.FirstSeenTagKey), - Value: aws.String(util.FormatTimestamp(now)), - }, + }, + ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{ + Tags: []types.Tag{ + { + Key: aws.String(util.FirstSeenTagKey), + Value: aws.String(util.FormatTimestamp(now)), }, }, }, @@ -512,7 +538,7 @@ func TestEC2Cluster_GetAll_MultipleRegexPatterns(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := ec.getAll(ctx, config.Config{ECSCluster: tc.configObj}) + names, err := listECSClusters(ctx, mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) diff --git 
a/aws/resources/ecs_cluster_types.go b/aws/resources/ecs_cluster_types.go deleted file mode 100644 index 5f4ba0b2..00000000 --- a/aws/resources/ecs_cluster_types.go +++ /dev/null @@ -1,68 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ecs" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type ECSClustersAPI interface { - DescribeClusters(ctx context.Context, params *ecs.DescribeClustersInput, optFns ...func(*ecs.Options)) (*ecs.DescribeClustersOutput, error) - DeleteCluster(ctx context.Context, params *ecs.DeleteClusterInput, optFns ...func(*ecs.Options)) (*ecs.DeleteClusterOutput, error) - ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) - ListTagsForResource(ctx context.Context, params *ecs.ListTagsForResourceInput, optFns ...func(*ecs.Options)) (*ecs.ListTagsForResourceOutput, error) - ListTasks(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) - StopTask(ctx context.Context, params *ecs.StopTaskInput, optFns ...func(*ecs.Options)) (*ecs.StopTaskOutput, error) - TagResource(ctx context.Context, params *ecs.TagResourceInput, optFns ...func(*ecs.Options)) (*ecs.TagResourceOutput, error) -} - -// ECSClusters - Represents all ECS clusters found in a region -type ECSClusters struct { - BaseAwsResource - Client ECSClustersAPI - Region string - ClusterArns []string -} - -func (clusters *ECSClusters) Init(cfg aws.Config) { - clusters.Client = ecs.NewFromConfig(cfg) -} - -// ResourceName - The simple name of the aws resource -func (clusters *ECSClusters) ResourceName() string { - return "ecscluster" -} - -// ResourceIdentifiers - the collected ECS clusters -func (clusters *ECSClusters) ResourceIdentifiers() []string { - return clusters.ClusterArns -} - -func (clusters *ECSClusters) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.ECSCluster -} - -func (clusters *ECSClusters) MaxBatchSize() int { - return maxBatchSize -} - -func (clusters *ECSClusters) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := clusters.getAll(c, configObj) - if err != nil { - return nil, err - } - - clusters.ClusterArns = aws.ToStringSlice(identifiers) - return clusters.ClusterArns, nil -} - -// Nuke - nuke all ECS Cluster resources -func (clusters *ECSClusters) Nuke(ctx context.Context, identifiers []string) error { - if err := clusters.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - return nil -} diff --git a/aws/resources/ecs_service.go b/aws/resources/ecs_service.go index f0558c1c..236c118a 100644 --- a/aws/resources/ecs_service.go +++ b/aws/resources/ecs_service.go @@ -2,6 +2,8 @@ package resources import ( "context" + "sync" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ecs" @@ -9,15 +11,65 @@ import ( "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/gruntwork-io/go-commons/errors" ) -// getAllEcsClusters - Returns a string of ECS Cluster ARNs, which uniquely identifies the cluster. +// ECSServicesAPI defines the interface for ECS Services operations. 
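+// It covers only the subset of the ECS SDK client used by the service lister and nuker below,
+// so unit tests can substitute a lightweight mock for *ecs.Client.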
+type ECSServicesAPI interface { + ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) + ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) + DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) + DeleteService(ctx context.Context, params *ecs.DeleteServiceInput, optFns ...func(*ecs.Options)) (*ecs.DeleteServiceOutput, error) + UpdateService(ctx context.Context, params *ecs.UpdateServiceInput, optFns ...func(*ecs.Options)) (*ecs.UpdateServiceOutput, error) +} + +// ecsServicesState holds state that needs to be shared between list and nuke phases. +type ecsServicesState struct { + mu sync.Mutex + serviceClusterMap map[string]string + timeout time.Duration +} + +// globalECSServicesState is the global state for ECS Services operations. +var globalECSServicesState = &ecsServicesState{ + serviceClusterMap: make(map[string]string), + timeout: DefaultWaitTimeout, +} + +// NewECSServices creates a new ECSServices resource using the generic resource pattern. +func NewECSServices() AwsResource { + return NewAwsResource(&resource.Resource[ECSServicesAPI]{ + ResourceTypeName: "ecsserv", + BatchSize: 49, + InitClient: func(r *resource.Resource[ECSServicesAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for ECS client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = ecs.NewFromConfig(awsCfg) + // Reset global state on init + globalECSServicesState.mu.Lock() + globalECSServicesState.serviceClusterMap = make(map[string]string) + globalECSServicesState.timeout = DefaultWaitTimeout + globalECSServicesState.mu.Unlock() + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.ECSService + }, + Lister: listECSServices, + Nuker: deleteECSServices, + }) +} + +// getAllEcsClusterArnsForServices - Returns a string of ECS Cluster ARNs, which uniquely identifies the cluster. // We need to get all clusters before we can get all services. -func (services *ECSServices) getAllEcsClusters() ([]*string, error) { +func getAllEcsClusterArnsForServices(ctx context.Context, client ECSServicesAPI) ([]*string, error) { var clusterArns []string - result, err := services.Client.ListClusters(services.Context, &ecs.ListClustersInput{}) + result, err := client.ListClusters(ctx, &ecs.ListClustersInput{}) if err != nil { return nil, errors.WithStackTrace(err) } @@ -25,7 +77,7 @@ func (services *ECSServices) getAllEcsClusters() ([]*string, error) { // Handle pagination: continuously pull the next page if nextToken is set for aws.ToString(result.NextToken) != "" { - result, err = services.Client.ListClusters(services.Context, &ecs.ListClustersInput{NextToken: result.NextToken}) + result, err = client.ListClusters(ctx, &ecs.ListClustersInput{NextToken: result.NextToken}) if err != nil { return nil, errors.WithStackTrace(err) } @@ -38,7 +90,7 @@ func (services *ECSServices) getAllEcsClusters() ([]*string, error) { // filterOutRecentServices - Given a list of services and an excludeAfter // timestamp, filter out any services that were created after `excludeAfter. // Additionally, filter based on Config file patterns. 
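+// Inclusion is decided purely from the passed-in config.ResourceType, using the service
+// name, creation time, and tags returned by DescribeServices.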
-func (services *ECSServices) filterOutRecentServices(clusterArn *string, ecsServiceArns []string, configObj config.Config) ([]*string, error) { +func filterOutRecentServices(ctx context.Context, client ECSServicesAPI, clusterArn *string, ecsServiceArns []string, cfg config.ResourceType) ([]*string, error) { // Fetch descriptions in batches of 10, which is the max that AWS // accepts for describe service. var filteredEcsServiceArns []*string @@ -49,7 +101,7 @@ func (services *ECSServices) filterOutRecentServices(clusterArn *string, ecsServ Services: batch, Include: []types.ServiceField{types.ServiceFieldTags}, } - describeResult, err := services.Client.DescribeServices(services.Context, params) + describeResult, err := client.DescribeServices(ctx, params) if err != nil { return nil, errors.WithStackTrace(err) } @@ -61,7 +113,7 @@ func (services *ECSServices) filterOutRecentServices(clusterArn *string, ecsServ } } - if configObj.ECSService.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Name: service.ServiceName, Time: service.CreatedAt, Tags: tags, @@ -73,13 +125,13 @@ func (services *ECSServices) filterOutRecentServices(clusterArn *string, ecsServ return filteredEcsServiceArns, nil } -// getAllEcsServices - Returns a formatted string of ECS Service ARNs, which +// listECSServices - Returns a formatted string of ECS Service ARNs, which // uniquely identifies the service, in addition to a mapping of services to // clusters. For ECS, need to track ECS clusters of services as all service // level API endpoints require providing the corresponding cluster. // Note that this looks up services by ECS cluster ARNs. -func (services *ECSServices) getAll(c context.Context, configObj config.Config) ([]*string, error) { - ecsClusterArns, err := services.getAllEcsClusters() +func listECSServices(ctx context.Context, client ECSServicesAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + ecsClusterArns, err := getAllEcsClusterArnsForServices(ctx, client) if err != nil { return nil, errors.WithStackTrace(err) } @@ -90,11 +142,11 @@ func (services *ECSServices) getAll(c context.Context, configObj config.Config) // ones. var ecsServiceArns []*string for _, clusterArn := range ecsClusterArns { - result, err := services.Client.ListServices(services.Context, &ecs.ListServicesInput{Cluster: clusterArn}) + result, err := client.ListServices(ctx, &ecs.ListServicesInput{Cluster: clusterArn}) if err != nil { return nil, errors.WithStackTrace(err) } - filteredServiceArns, err := services.filterOutRecentServices(clusterArn, result.ServiceArns, configObj) + filteredServiceArns, err := filterOutRecentServices(ctx, client, clusterArn, result.ServiceArns, cfg) if err != nil { return nil, errors.WithStackTrace(err) } @@ -105,22 +157,26 @@ func (services *ECSServices) getAll(c context.Context, configObj config.Config) ecsServiceArns = append(ecsServiceArns, filteredServiceArns...) } - services.ServiceClusterMap = ecsServiceClusterMap + // Store mapping in global state for use during nuke phase + globalECSServicesState.mu.Lock() + globalECSServicesState.serviceClusterMap = ecsServiceClusterMap + globalECSServicesState.mu.Unlock() + return ecsServiceArns, nil } // drainEcsServices - Drain all tasks from all services requested. This will // return a list of service ARNs that have been successfully requested to be // drained. 
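+// Describe or update failures for an individual service are logged and skipped so the
+// remaining services can still be drained.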
-func (services *ECSServices) drainEcsServices(ecsServiceArns []*string) []*string { +func drainEcsServices(ctx context.Context, client ECSServicesAPI, ecsServiceArns []*string, serviceClusterMap map[string]string) []*string { var requestedDrains []*string for _, ecsServiceArn := range ecsServiceArns { describeParams := &ecs.DescribeServicesInput{ - Cluster: aws.String(services.ServiceClusterMap[*ecsServiceArn]), + Cluster: aws.String(serviceClusterMap[*ecsServiceArn]), Services: []string{*ecsServiceArn}, } - describeServicesOutput, err := services.Client.DescribeServices(services.Context, describeParams) + describeServicesOutput, err := client.DescribeServices(ctx, describeParams) if err != nil { logging.Errorf("[Failed] Failed to describe service %s: %s", *ecsServiceArn, err) } else { @@ -130,11 +186,11 @@ func (services *ECSServices) drainEcsServices(ecsServiceArns []*string) []*strin requestedDrains = append(requestedDrains, ecsServiceArn) } else { params := &ecs.UpdateServiceInput{ - Cluster: aws.String(services.ServiceClusterMap[*ecsServiceArn]), + Cluster: aws.String(serviceClusterMap[*ecsServiceArn]), Service: ecsServiceArn, DesiredCount: aws.Int32(0), } - _, err = services.Client.UpdateService(services.Context, params) + _, err = client.UpdateService(ctx, params) if err != nil { logging.Errorf("[Failed] Failed to drain service %s: %s", *ecsServiceArn, err) } else { @@ -150,16 +206,16 @@ func (services *ECSServices) drainEcsServices(ecsServiceArns []*string) []*strin // given list of services, by waiting for stability which is defined as // desiredCount == runningCount. This will return a list of service ARNs that // have successfully been drained. -func (services *ECSServices) waitUntilServicesDrained(ecsServiceArns []*string) []*string { +func waitUntilServicesDrained(ctx context.Context, client ECSServicesAPI, ecsServiceArns []*string, serviceClusterMap map[string]string, timeout time.Duration) []*string { var successfullyDrained []*string for _, ecsServiceArn := range ecsServiceArns { params := &ecs.DescribeServicesInput{ - Cluster: aws.String(services.ServiceClusterMap[*ecsServiceArn]), + Cluster: aws.String(serviceClusterMap[*ecsServiceArn]), Services: []string{*ecsServiceArn}, } - waiter := ecs.NewServicesStableWaiter(services.Client) - err := waiter.Wait(services.Context, params, services.Timeout) + waiter := ecs.NewServicesStableWaiter(client) + err := waiter.Wait(ctx, params, timeout) if err != nil { logging.Debugf("[Failed] Failed waiting for service to be stable %s: %s", *ecsServiceArn, err) } else { @@ -170,16 +226,16 @@ func (services *ECSServices) waitUntilServicesDrained(ecsServiceArns []*string) return successfullyDrained } -// deleteEcsServices - Deletes all services requested. Returns a list of +// deleteEcsServicesIndividually - Deletes all services requested. Returns a list of // service ARNs that have been accepted by AWS for deletion. 
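+// As with draining, a failed DeleteService call is logged and the loop moves on to the
+// next service instead of aborting the whole batch.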
-func (services *ECSServices) deleteEcsServices(ecsServiceArns []*string) []*string { +func deleteEcsServicesIndividually(ctx context.Context, client ECSServicesAPI, ecsServiceArns []*string, serviceClusterMap map[string]string) []*string { var requestedDeletes []*string for _, ecsServiceArn := range ecsServiceArns { params := &ecs.DeleteServiceInput{ - Cluster: aws.String(services.ServiceClusterMap[*ecsServiceArn]), + Cluster: aws.String(serviceClusterMap[*ecsServiceArn]), Service: ecsServiceArn, } - _, err := services.Client.DeleteService(services.Context, params) + _, err := client.DeleteService(ctx, params) if err != nil { logging.Debugf("[Failed] Failed deleting service %s: %s", *ecsServiceArn, err) } else { @@ -192,16 +248,16 @@ func (services *ECSServices) deleteEcsServices(ecsServiceArns []*string) []*stri // waitUntilServicesDeleted - Waits until the service has been actually deleted // from AWS. Returns a list of service ARNs that have been successfully // deleted. -func (services *ECSServices) waitUntilServicesDeleted(ecsServiceArns []*string) []*string { +func waitUntilServicesDeleted(ctx context.Context, client ECSServicesAPI, ecsServiceArns []*string, serviceClusterMap map[string]string, timeout time.Duration) []*string { var successfullyDeleted []*string for _, ecsServiceArn := range ecsServiceArns { params := &ecs.DescribeServicesInput{ - Cluster: aws.String(services.ServiceClusterMap[*ecsServiceArn]), + Cluster: aws.String(serviceClusterMap[*ecsServiceArn]), Services: []string{*ecsServiceArn}, } - waiter := ecs.NewServicesInactiveWaiter(services.Client) - err := waiter.Wait(services.Context, params, services.Timeout) + waiter := ecs.NewServicesInactiveWaiter(client) + err := waiter.Wait(ctx, params, timeout) // Record status of this resource e := report.Entry{ @@ -221,22 +277,25 @@ func (services *ECSServices) waitUntilServicesDeleted(ecsServiceArns []*string) return successfullyDeleted } -// Deletes all provided ECS Services. At a high level this involves two steps: -// 1.) Drain all tasks from the service so that nothing is -// -// running. -// +// deleteECSServices deletes all provided ECS Services. At a high level this involves two steps: +// 1.) Drain all tasks from the service so that nothing is running. // 2.) Delete service object once no tasks are running. // Note that this will swallow failed deletes and continue along, logging the // service ARN so that we can find it later. -func (services *ECSServices) nukeAll(ecsServiceArns []*string) error { +func deleteECSServices(ctx context.Context, client ECSServicesAPI, scope resource.Scope, resourceType string, ecsServiceArns []*string) error { numNuking := len(ecsServiceArns) if numNuking == 0 { - logging.Debugf("No ECS services to nuke in region %s", services.Region) + logging.Debugf("No ECS services to nuke in region %s", scope.Region) return nil } - logging.Debugf("Deleting %d ECS services in region %s", numNuking, services.Region) + // Get service cluster map from global state + globalECSServicesState.mu.Lock() + serviceClusterMap := globalECSServicesState.serviceClusterMap + timeout := globalECSServicesState.timeout + globalECSServicesState.mu.Unlock() + + logging.Debugf("Deleting %d ECS services in region %s", numNuking, scope.Region) // First, drain all the services to 0. You can't delete a // service that is running tasks. @@ -244,12 +303,12 @@ func (services *ECSServices) nukeAll(ecsServiceArns []*string) error { // wait for them in a separate loop because it will take a // while to drain the services. 
// Then, we delete the services that have been successfully drained. - requestedDrains := services.drainEcsServices(ecsServiceArns) - successfullyDrained := services.waitUntilServicesDrained(requestedDrains) - requestedDeletes := services.deleteEcsServices(successfullyDrained) - successfullyDeleted := services.waitUntilServicesDeleted(requestedDeletes) + requestedDrains := drainEcsServices(ctx, client, ecsServiceArns, serviceClusterMap) + successfullyDrained := waitUntilServicesDrained(ctx, client, requestedDrains, serviceClusterMap, timeout) + requestedDeletes := deleteEcsServicesIndividually(ctx, client, successfullyDrained, serviceClusterMap) + successfullyDeleted := waitUntilServicesDeleted(ctx, client, requestedDeletes, serviceClusterMap, timeout) numNuked := len(successfullyDeleted) - logging.Debugf("[OK] %d of %d ECS service(s) deleted in %s", numNuked, numNuking, services.Client) + logging.Debugf("[OK] %d of %d ECS service(s) deleted in %s", numNuked, numNuking, scope.Region) return nil } diff --git a/aws/resources/ecs_service_test.go b/aws/resources/ecs_service_test.go index 392ea943..8c08ab11 100644 --- a/aws/resources/ecs_service_test.go +++ b/aws/resources/ecs_service_test.go @@ -10,10 +10,11 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) -type mockedEC2Service struct { +type mockedECSService struct { ECSServicesAPI ListClustersOutput ecs.ListClustersOutput ListServicesOutput ecs.ListServicesOutput @@ -22,57 +23,55 @@ type mockedEC2Service struct { UpdateServiceOutput ecs.UpdateServiceOutput } -func (m mockedEC2Service) ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) { +func (m mockedECSService) ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) { return &m.ListClustersOutput, nil } -func (m mockedEC2Service) ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) { +func (m mockedECSService) ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) { return &m.ListServicesOutput, nil } -func (m mockedEC2Service) DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) { +func (m mockedECSService) DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) { return &m.DescribeServicesOutput, nil } -func (m mockedEC2Service) DeleteService(ctx context.Context, params *ecs.DeleteServiceInput, optFns ...func(*ecs.Options)) (*ecs.DeleteServiceOutput, error) { +func (m mockedECSService) DeleteService(ctx context.Context, params *ecs.DeleteServiceInput, optFns ...func(*ecs.Options)) (*ecs.DeleteServiceOutput, error) { return &m.DeleteServiceOutput, nil } -func (m mockedEC2Service) UpdateService(ctx context.Context, params *ecs.UpdateServiceInput, optFns ...func(*ecs.Options)) (*ecs.UpdateServiceOutput, error) { +func (m mockedECSService) UpdateService(ctx context.Context, params *ecs.UpdateServiceInput, optFns ...func(*ecs.Options)) (*ecs.UpdateServiceOutput, error) { return &m.UpdateServiceOutput, nil } -func TestEC2Service_GetAll(t *testing.T) 
{ +func TestECSService_GetAll(t *testing.T) { t.Parallel() testArn1 := "testArn1" testArn2 := "testArn2" testName1 := "testService1" testName2 := "testService2" now := time.Now() - es := ECSServices{ - Client: mockedEC2Service{ - ListClustersOutput: ecs.ListClustersOutput{ - ClusterArns: []string{ - testArn1, - }, + mockClient := mockedECSService{ + ListClustersOutput: ecs.ListClustersOutput{ + ClusterArns: []string{ + testArn1, }, - ListServicesOutput: ecs.ListServicesOutput{ - ServiceArns: []string{ - testArn1, - }, + }, + ListServicesOutput: ecs.ListServicesOutput{ + ServiceArns: []string{ + testArn1, }, - DescribeServicesOutput: ecs.DescribeServicesOutput{ - Services: []types.Service{ - { - ServiceArn: aws.String(testArn1), - ServiceName: aws.String(testName1), - CreatedAt: aws.Time(now), - }, - { - ServiceArn: aws.String(testArn2), - ServiceName: aws.String(testName2), - CreatedAt: aws.Time(now.Add(1)), - }, + }, + DescribeServicesOutput: ecs.DescribeServicesOutput{ + Services: []types.Service{ + { + ServiceArn: aws.String(testArn1), + ServiceName: aws.String(testName1), + CreatedAt: aws.Time(now), + }, + { + ServiceArn: aws.String(testArn2), + ServiceName: aws.String(testName2), + CreatedAt: aws.Time(now.Add(1)), }, }, }, @@ -105,9 +104,7 @@ func TestEC2Service_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := es.getAll(context.Background(), config.Config{ - ECSService: tc.configObj, - }) + names, err := listECSServices(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) @@ -115,27 +112,28 @@ func TestEC2Service_GetAll(t *testing.T) { } -func TestEC2Service_NukeAll(t *testing.T) { +func TestECSService_NukeAll(t *testing.T) { t.Parallel() - es := ECSServices{ - BaseAwsResource: BaseAwsResource{ - Context: context.Background(), - }, - Client: mockedEC2Service{ - DescribeServicesOutput: ecs.DescribeServicesOutput{ - Services: []types.Service{ - { - SchedulingStrategy: types.SchedulingStrategyDaemon, - ServiceArn: aws.String("testArn1"), - Status: aws.String("DRAINING"), - }, + + // Setup global state with service cluster map + globalECSServicesState.serviceClusterMap = map[string]string{ + "testArn1": "testCluster1", + } + + mockClient := mockedECSService{ + DescribeServicesOutput: ecs.DescribeServicesOutput{ + Services: []types.Service{ + { + SchedulingStrategy: types.SchedulingStrategyDaemon, + ServiceArn: aws.String("testArn1"), + Status: aws.String("DRAINING"), }, }, - UpdateServiceOutput: ecs.UpdateServiceOutput{}, - DeleteServiceOutput: ecs.DeleteServiceOutput{}, }, + UpdateServiceOutput: ecs.UpdateServiceOutput{}, + DeleteServiceOutput: ecs.DeleteServiceOutput{}, } - err := es.nukeAll([]*string{aws.String("testArn1")}) + err := deleteECSServices(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, "ecsserv", []*string{aws.String("testArn1")}) require.NoError(t, err) } diff --git a/aws/resources/ecs_service_types.go b/aws/resources/ecs_service_types.go deleted file mode 100644 index 4eb418a9..00000000 --- a/aws/resources/ecs_service_types.go +++ /dev/null @@ -1,67 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ecs" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type ECSServicesAPI interface { - ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns 
...func(*ecs.Options)) (*ecs.ListClustersOutput, error) - ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) - DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) - DeleteService(ctx context.Context, params *ecs.DeleteServiceInput, optFns ...func(*ecs.Options)) (*ecs.DeleteServiceOutput, error) - UpdateService(ctx context.Context, params *ecs.UpdateServiceInput, optFns ...func(*ecs.Options)) (*ecs.UpdateServiceOutput, error) -} - -// ECSServices - Represents all ECS services found in a region -type ECSServices struct { - BaseAwsResource - Client ECSServicesAPI - Region string - Services []string - ServiceClusterMap map[string]string -} - -func (services *ECSServices) Init(cfg aws.Config) { - services.Client = ecs.NewFromConfig(cfg) -} - -// ResourceName - The simple name of the aws resource -func (services *ECSServices) ResourceName() string { - return "ecsserv" -} - -// ResourceIdentifiers - The ARNs of the collected ECS services -func (services *ECSServices) ResourceIdentifiers() []string { - return services.Services -} - -func (services *ECSServices) MaxBatchSize() int { - return 49 -} - -func (services *ECSServices) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.ECSService -} - -func (services *ECSServices) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := services.getAll(c, configObj) - if err != nil { - return nil, err - } - - services.Services = aws.ToStringSlice(identifiers) - return services.Services, nil -} - -// Nuke - nuke all ECS service resources -func (services *ECSServices) Nuke(ctx context.Context, identifiers []string) error { - if err := services.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - return nil -} diff --git a/aws/resources/efs.go b/aws/resources/efs.go index 4502964c..dfe5d39d 100644 --- a/aws/resources/efs.go +++ b/aws/resources/efs.go @@ -2,7 +2,7 @@ package resources import ( "context" - "sync" + "fmt" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -10,22 +10,56 @@ import ( "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/go-commons/errors" "github.com/hashicorp/go-multierror" ) -func (ef *ElasticFileSystem) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// ElasticFileSystemAPI defines the interface for EFS operations. 
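+// Only the EFS calls needed by the lister and nuker below are included, so tests can supply mocks.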
+type ElasticFileSystemAPI interface { + DeleteAccessPoint(ctx context.Context, params *efs.DeleteAccessPointInput, optFns ...func(*efs.Options)) (*efs.DeleteAccessPointOutput, error) + DeleteFileSystem(ctx context.Context, params *efs.DeleteFileSystemInput, optFns ...func(*efs.Options)) (*efs.DeleteFileSystemOutput, error) + DeleteMountTarget(ctx context.Context, params *efs.DeleteMountTargetInput, optFns ...func(*efs.Options)) (*efs.DeleteMountTargetOutput, error) + DescribeAccessPoints(ctx context.Context, params *efs.DescribeAccessPointsInput, optFns ...func(*efs.Options)) (*efs.DescribeAccessPointsOutput, error) + DescribeMountTargets(ctx context.Context, params *efs.DescribeMountTargetsInput, optFns ...func(*efs.Options)) (*efs.DescribeMountTargetsOutput, error) + DescribeFileSystems(ctx context.Context, params *efs.DescribeFileSystemsInput, optFns ...func(*efs.Options)) (*efs.DescribeFileSystemsOutput, error) +} + +// NewElasticFileSystem creates a new ElasticFileSystem resource using the generic resource pattern. +func NewElasticFileSystem() AwsResource { + return NewAwsResource(&resource.Resource[ElasticFileSystemAPI]{ + ResourceTypeName: "efs", + BatchSize: 10, + InitClient: func(r *resource.Resource[ElasticFileSystemAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for EFS client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = efs.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.ElasticFileSystem + }, + Lister: listElasticFileSystems, + Nuker: deleteElasticFileSystems, + }) +} + +// listElasticFileSystems retrieves all Elastic File Systems that match the config filters. +func listElasticFileSystems(ctx context.Context, client ElasticFileSystemAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { var allEfs []*string - paginator := efs.NewDescribeFileSystemsPaginator(ef.Client, &efs.DescribeFileSystemsInput{}) + paginator := efs.NewDescribeFileSystemsPaginator(client, &efs.DescribeFileSystemsInput{}) for paginator.HasMorePages() { - page, err := paginator.NextPage(c) + page, err := paginator.NextPage(ctx) if err != nil { return nil, errors.WithStackTrace(err) } for _, system := range page.FileSystems { - if configObj.ElasticFileSystem.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Name: system.Name, Time: system.CreationTime, }) { @@ -37,9 +71,12 @@ func (ef *ElasticFileSystem) getAll(c context.Context, configObj config.Config) return allEfs, nil } -func (ef *ElasticFileSystem) nukeAll(identifiers []*string) error { +// deleteElasticFileSystems is a custom nuker for Elastic File Systems. +// It deletes access points, mount targets, and then the file system. 
+func deleteElasticFileSystems(ctx context.Context, client ElasticFileSystemAPI, scope resource.Scope, resourceType string, identifiers []*string) error { if len(identifiers) == 0 { - logging.Debugf("No Elastic FileSystems (efs) to nuke in region %s", ef.Region) + logging.Debugf("No Elastic FileSystems (efs) to nuke in %s", scope) + return nil } if len(identifiers) > 100 { @@ -47,24 +84,16 @@ func (ef *ElasticFileSystem) nukeAll(identifiers []*string) error { return TooManyElasticFileSystemsErr{} } - // There is no bulk delete EFS API, so we delete the batch of Elastic FileSystems concurrently using goroutines - logging.Debugf("Deleting Elastic FileSystems (efs) in region %s", ef.Region) - wg := new(sync.WaitGroup) - wg.Add(len(identifiers)) - errChans := make([]chan error, len(identifiers)) - for i, efsID := range identifiers { - errChans[i] = make(chan error, 1) - go ef.deleteAsync(wg, errChans[i], efsID) - } - wg.Wait() + logging.Debugf("Deleting Elastic FileSystems (efs) in %s", scope) var allErrs *multierror.Error - for _, errChan := range errChans { - if err := <-errChan; err != nil { + for _, efsID := range identifiers { + if err := deleteElasticFileSystem(ctx, client, scope, resourceType, efsID); err != nil { allErrs = multierror.Append(allErrs, err) logging.Debugf("[Failed] %s", err) } } + finalErr := allErrs.ErrorOrNil() if finalErr != nil { return errors.WithStackTrace(finalErr) @@ -72,12 +101,9 @@ func (ef *ElasticFileSystem) nukeAll(identifiers []*string) error { return nil } -func (ef *ElasticFileSystem) deleteAsync(wg *sync.WaitGroup, errChan chan error, efsID *string) { +func deleteElasticFileSystem(ctx context.Context, client ElasticFileSystemAPI, scope resource.Scope, resourceType string, efsID *string) error { var allErrs *multierror.Error - defer wg.Done() - defer func() { errChan <- allErrs.ErrorOrNil() }() - // First, we need to check if the Elastic FileSystem is "in-use", because an in-use file system cannot be deleted // An Elastic FileSystem is considered in-use if it has any access points, or any mount targets // Here, we first look up and delete any and all access points for the given Elastic FileSystem @@ -87,7 +113,7 @@ func (ef *ElasticFileSystem) deleteAsync(wg *sync.WaitGroup, errChan chan error, FileSystemId: efsID, } - out, err := ef.Client.DescribeAccessPoints(ef.Context, accessPointParam) + out, err := client.DescribeAccessPoints(ctx, accessPointParam) if err != nil { allErrs = multierror.Append(allErrs, err) } @@ -102,13 +128,13 @@ func (ef *ElasticFileSystem) deleteAsync(wg *sync.WaitGroup, errChan chan error, AccessPointId: apID, } - logging.Debugf("Deleting access point (id=%s) for Elastic FileSystem (%s) in region: %s", aws.ToString(apID), aws.ToString(efsID), ef.Region) + logging.Debugf("Deleting access point (id=%s) for Elastic FileSystem (%s) in %s", aws.ToString(apID), aws.ToString(efsID), scope) - _, err := ef.Client.DeleteAccessPoint(ef.Context, deleteParam) + _, err := client.DeleteAccessPoint(ctx, deleteParam) if err != nil { allErrs = multierror.Append(allErrs, err) } else { - logging.Debugf("[OK] Deleted access point (id=%s) for Elastic FileSystem (%s) in region: %s", aws.ToString(apID), aws.ToString(efsID), ef.Region) + logging.Debugf("[OK] Deleted access point (id=%s) for Elastic FileSystem (%s) in %s", aws.ToString(apID), aws.ToString(efsID), scope) } } @@ -132,9 +158,9 @@ func (ef *ElasticFileSystem) deleteAsync(wg *sync.WaitGroup, errChan chan error, mountTargetParam.Marker = marker } - mountTargetsOutput, describeMountsErr := 
ef.Client.DescribeMountTargets(ef.Context, mountTargetParam) + mountTargetsOutput, describeMountsErr := client.DescribeMountTargets(ctx, mountTargetParam) if describeMountsErr != nil { - allErrs = multierror.Append(allErrs, err) + allErrs = multierror.Append(allErrs, describeMountsErr) } for _, mountTarget := range mountTargetsOutput.MountTargets { @@ -155,30 +181,32 @@ func (ef *ElasticFileSystem) deleteAsync(wg *sync.WaitGroup, errChan chan error, MountTargetId: mtID, } - logging.Debugf("Deleting mount target (id=%s) for Elastic FileSystem (%s) in region: %s", aws.ToString(mtID), aws.ToString(efsID), ef.Region) + logging.Debugf("Deleting mount target (id=%s) for Elastic FileSystem (%s) in %s", aws.ToString(mtID), aws.ToString(efsID), scope) - _, err := ef.Client.DeleteMountTarget(ef.Context, deleteMtParam) + _, err := client.DeleteMountTarget(ctx, deleteMtParam) if err != nil { allErrs = multierror.Append(allErrs, err) } else { - logging.Debugf("[OK] Deleted mount target (id=%s) for Elastic FileSystem (%s) in region: %s", aws.ToString(mtID), aws.ToString(efsID), ef.Region) + logging.Debugf("[OK] Deleted mount target (id=%s) for Elastic FileSystem (%s) in %s", aws.ToString(mtID), aws.ToString(efsID), scope) } } - logging.Debug("Sleeping 20 seconds to allow AWS to realize the Elastic FileSystem is no longer in use...") - time.Sleep(20 * time.Second) + // Wait for mount targets to be fully deleted before attempting to delete the file system + if err := waitForMountTargetsDeleted(ctx, client, efsID); err != nil { + allErrs = multierror.Append(allErrs, err) + } // Now we can attempt to delete the Elastic FileSystem itself deleteEfsParam := &efs.DeleteFileSystemInput{ FileSystemId: efsID, } - _, deleteErr := ef.Client.DeleteFileSystem(ef.Context, deleteEfsParam) + _, deleteErr := client.DeleteFileSystem(ctx, deleteEfsParam) // Record status of this resource e := report.Entry{ Identifier: aws.ToString(efsID), - ResourceType: "Elastic FileSystem (EFS)", - Error: err, + ResourceType: resourceType, + Error: deleteErr, } report.Record(e) @@ -186,9 +214,39 @@ func (ef *ElasticFileSystem) deleteAsync(wg *sync.WaitGroup, errChan chan error, allErrs = multierror.Append(allErrs, deleteErr) } - if err == nil { - logging.Debugf("[OK] Elastic FileSystem (efs) %s deleted in %s", aws.ToString(efsID), ef.Region) + if deleteErr == nil { + logging.Debugf("[OK] Elastic FileSystem (efs) %s deleted in %s", aws.ToString(efsID), scope) } else { - logging.Debugf("[Failed] Error deleting Elastic FileSystem (efs) %s in %s", aws.ToString(efsID), ef.Region) + logging.Debugf("[Failed] Error deleting Elastic FileSystem (efs) %s in %s", aws.ToString(efsID), scope) } + + return allErrs.ErrorOrNil() +} + +// waitForMountTargetsDeleted polls until all mount targets for the given EFS are deleted. +// It returns nil if mount targets are deleted or an error occurs (indicating the file system may not exist). +// It times out after 30 attempts (60 seconds total with 2-second intervals). 
+func waitForMountTargetsDeleted(ctx context.Context, client ElasticFileSystemAPI, efsID *string) error { + for i := 0; i < 30; i++ { + output, err := client.DescribeMountTargets(ctx, &efs.DescribeMountTargetsInput{ + FileSystemId: efsID, + }) + if err != nil { + // If we get an error (like FileSystemNotFound), mount targets are gone + return nil //nolint:nilerr // Error here indicates file system was deleted, which is success + } + if len(output.MountTargets) == 0 { + return nil + } + time.Sleep(2 * time.Second) + } + return fmt.Errorf("timed out waiting for mount targets to be deleted for EFS %s", aws.ToString(efsID)) +} + +// custom errors + +type TooManyElasticFileSystemsErr struct{} + +func (err TooManyElasticFileSystemsErr) Error() string { + return "Too many Elastic FileSystems requested at once." } diff --git a/aws/resources/efs_test.go b/aws/resources/efs_test.go index 91f95acb..901e072b 100644 --- a/aws/resources/efs_test.go +++ b/aws/resources/efs_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/efs" "github.com/aws/aws-sdk-go-v2/service/efs/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) @@ -20,6 +21,9 @@ type mockedElasticFileSystem struct { DescribeAccessPointsOutput efs.DescribeAccessPointsOutput DescribeMountTargetsOutput efs.DescribeMountTargetsOutput DescribeFileSystemsOutput efs.DescribeFileSystemsOutput + + // Track calls to DescribeMountTargets to simulate deletion + describeMountTargetsCalls int } func (m mockedElasticFileSystem) DeleteAccessPoint(ctx context.Context, params *efs.DeleteAccessPointInput, optFns ...func(*efs.Options)) (*efs.DeleteAccessPointOutput, error) { @@ -38,8 +42,14 @@ func (m mockedElasticFileSystem) DescribeAccessPoints(ctx context.Context, param return &m.DescribeAccessPointsOutput, nil } -func (m mockedElasticFileSystem) DescribeMountTargets(ctx context.Context, params *efs.DescribeMountTargetsInput, optFns ...func(*efs.Options)) (*efs.DescribeMountTargetsOutput, error) { - return &m.DescribeMountTargetsOutput, nil +func (m *mockedElasticFileSystem) DescribeMountTargets(ctx context.Context, params *efs.DescribeMountTargetsInput, optFns ...func(*efs.Options)) (*efs.DescribeMountTargetsOutput, error) { + m.describeMountTargetsCalls++ + // First call returns mount targets (used during enumeration and first waiter check) + // Subsequent calls return empty list (simulating mount targets being deleted) + if m.describeMountTargetsCalls <= 1 { + return &m.DescribeMountTargetsOutput, nil + } + return &efs.DescribeMountTargetsOutput{MountTargets: []types.MountTargetDescription{}}, nil } func (m mockedElasticFileSystem) DescribeFileSystems(ctx context.Context, params *efs.DescribeFileSystemsInput, optFns ...func(*efs.Options)) (*efs.DescribeFileSystemsOutput, error) { @@ -53,20 +63,18 @@ func TestEFS_GetAll(t *testing.T) { testId2 := "testId2" testName2 := "test-efs2" now := time.Now() - ef := ElasticFileSystem{ - Client: mockedElasticFileSystem{ - DescribeFileSystemsOutput: efs.DescribeFileSystemsOutput{ - FileSystems: []types.FileSystemDescription{ - { - FileSystemId: aws.String(testId1), - Name: aws.String(testName1), - CreationTime: aws.Time(now), - }, - { - FileSystemId: aws.String(testId2), - Name: aws.String(testName2), - CreationTime: aws.Time(now.Add(1)), - }, + client := &mockedElasticFileSystem{ + DescribeFileSystemsOutput: efs.DescribeFileSystemsOutput{ + FileSystems: []types.FileSystemDescription{ + { + FileSystemId: 
aws.String(testId1), + Name: aws.String(testName1), + CreationTime: aws.Time(now), + }, + { + FileSystemId: aws.String(testId2), + Name: aws.String(testName2), + CreationTime: aws.Time(now.Add(1)), }, }, }, @@ -99,9 +107,7 @@ func TestEFS_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := ef.getAll(context.Background(), config.Config{ - ElasticFileSystem: tc.configObj, - }) + names, err := listElasticFileSystems(context.Background(), client, resource.Scope{Region: "us-east-1"}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) @@ -110,28 +116,26 @@ func TestEFS_GetAll(t *testing.T) { func TestEFS_NukeAll(t *testing.T) { t.Parallel() - ef := ElasticFileSystem{ - Client: mockedElasticFileSystem{ - DescribeAccessPointsOutput: efs.DescribeAccessPointsOutput{ - AccessPoints: []types.AccessPointDescription{ - { - AccessPointId: aws.String("fsap-1234567890abcdef0"), - }, + client := &mockedElasticFileSystem{ + DescribeAccessPointsOutput: efs.DescribeAccessPointsOutput{ + AccessPoints: []types.AccessPointDescription{ + { + AccessPointId: aws.String("fsap-1234567890abcdef0"), }, }, - DescribeMountTargetsOutput: efs.DescribeMountTargetsOutput{ - MountTargets: []types.MountTargetDescription{ - { - MountTargetId: aws.String("fsmt-1234567890abcdef0"), - }, + }, + DescribeMountTargetsOutput: efs.DescribeMountTargetsOutput{ + MountTargets: []types.MountTargetDescription{ + { + MountTargetId: aws.String("fsmt-1234567890abcdef0"), }, }, - DeleteAccessPointOutput: efs.DeleteAccessPointOutput{}, - DeleteMountTargetOutput: efs.DeleteMountTargetOutput{}, - DeleteFileSystemOutput: efs.DeleteFileSystemOutput{}, }, + DeleteAccessPointOutput: efs.DeleteAccessPointOutput{}, + DeleteMountTargetOutput: efs.DeleteMountTargetOutput{}, + DeleteFileSystemOutput: efs.DeleteFileSystemOutput{}, } - err := ef.nukeAll([]*string{aws.String("fs-1234567890abcdef0")}) + err := deleteElasticFileSystems(context.Background(), client, resource.Scope{Region: "us-east-1"}, "efs", []*string{aws.String("fs-1234567890abcdef0")}) require.NoError(t, err) } diff --git a/aws/resources/efs_types.go b/aws/resources/efs_types.go deleted file mode 100644 index ea988224..00000000 --- a/aws/resources/efs_types.go +++ /dev/null @@ -1,72 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/efs" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type ElasticFileSystemAPI interface { - DeleteAccessPoint(ctx context.Context, params *efs.DeleteAccessPointInput, optFns ...func(*efs.Options)) (*efs.DeleteAccessPointOutput, error) - DeleteFileSystem(ctx context.Context, params *efs.DeleteFileSystemInput, optFns ...func(*efs.Options)) (*efs.DeleteFileSystemOutput, error) - DeleteMountTarget(ctx context.Context, params *efs.DeleteMountTargetInput, optFns ...func(*efs.Options)) (*efs.DeleteMountTargetOutput, error) - DescribeAccessPoints(ctx context.Context, params *efs.DescribeAccessPointsInput, optFns ...func(*efs.Options)) (*efs.DescribeAccessPointsOutput, error) - DescribeMountTargets(ctx context.Context, params *efs.DescribeMountTargetsInput, optFns ...func(*efs.Options)) (*efs.DescribeMountTargetsOutput, error) - DescribeFileSystems(ctx context.Context, params *efs.DescribeFileSystemsInput, optFns ...func(*efs.Options)) (*efs.DescribeFileSystemsOutput, error) -} - -type ElasticFileSystem struct { - BaseAwsResource - Client 
ElasticFileSystemAPI - Region string - Ids []string -} - -func (ef *ElasticFileSystem) Init(cfg aws.Config) { - ef.Client = efs.NewFromConfig(cfg) -} - -func (ef *ElasticFileSystem) ResourceName() string { - return "efs" -} - -func (ef *ElasticFileSystem) ResourceIdentifiers() []string { - return ef.Ids -} - -func (ef *ElasticFileSystem) MaxBatchSize() int { - return 10 -} - -func (ef *ElasticFileSystem) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.ElasticFileSystem -} - -func (ef *ElasticFileSystem) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := ef.getAll(c, configObj) - if err != nil { - return nil, err - } - - ef.Ids = aws.ToStringSlice(identifiers) - return ef.Ids, nil -} - -func (ef *ElasticFileSystem) Nuke(ctx context.Context, identifiers []string) error { - if err := ef.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} - -// custom errors - -type TooManyElasticFileSystemsErr struct{} - -func (err TooManyElasticFileSystemsErr) Error() string { - return "Too many Elastic FileSystems requested at once." -} diff --git a/aws/resources/eks.go b/aws/resources/eks.go index 5916f429..d19ff5ef 100644 --- a/aws/resources/eks.go +++ b/aws/resources/eks.go @@ -9,37 +9,73 @@ import ( "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/go-commons/errors" "github.com/hashicorp/go-multierror" ) -// getAll returns a list of strings of EKS Cluster Names that uniquely identify each cluster. -func (clusters *EKSClusters) getAll(c context.Context, configObj config.Config) ([]*string, error) { - result, err := clusters.Client.ListClusters(clusters.Context, &eks.ListClustersInput{}) +// EKSClustersAPI defines the interface for EKS Clusters operations. +type EKSClustersAPI interface { + DeleteCluster(ctx context.Context, params *eks.DeleteClusterInput, optFns ...func(*eks.Options)) (*eks.DeleteClusterOutput, error) + DeleteFargateProfile(ctx context.Context, params *eks.DeleteFargateProfileInput, optFns ...func(*eks.Options)) (*eks.DeleteFargateProfileOutput, error) + DeleteNodegroup(ctx context.Context, params *eks.DeleteNodegroupInput, optFns ...func(*eks.Options)) (*eks.DeleteNodegroupOutput, error) + DescribeCluster(ctx context.Context, params *eks.DescribeClusterInput, optFns ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) + DescribeFargateProfile(ctx context.Context, params *eks.DescribeFargateProfileInput, optFns ...func(*eks.Options)) (*eks.DescribeFargateProfileOutput, error) + DescribeNodegroup(ctx context.Context, params *eks.DescribeNodegroupInput, optFns ...func(*eks.Options)) (*eks.DescribeNodegroupOutput, error) + ListClusters(ctx context.Context, params *eks.ListClustersInput, optFns ...func(*eks.Options)) (*eks.ListClustersOutput, error) + ListFargateProfiles(ctx context.Context, params *eks.ListFargateProfilesInput, optFns ...func(*eks.Options)) (*eks.ListFargateProfilesOutput, error) + ListNodegroups(ctx context.Context, params *eks.ListNodegroupsInput, optFns ...func(*eks.Options)) (*eks.ListNodegroupsOutput, error) +} + +// NewEKSClusters creates a new EKS Clusters resource using the generic resource pattern. 
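For context while reading the constructors in this change: they all populate the same generic descriptor from the new resource package, which is not shown in this diff. Below is a minimal sketch of the assumed shape of resource.Resource[T] and resource.Scope, inferred only from how their fields are used in this patch; the real definitions live in github.com/gruntwork-io/cloud-nuke/resource and may differ.

package resource

import (
	"context"
	"fmt"

	"github.com/gruntwork-io/cloud-nuke/config"
)

// Scope identifies where a batch of resources lives; its String form is what the
// "%s" log statements in the nukers print. Illustrative sketch only.
type Scope struct {
	Region string
}

func (s Scope) String() string { return fmt.Sprintf("region %s", s.Region) }

// Resource is the generic descriptor each NewXxx() constructor fills in before
// handing it to NewAwsResource. Illustrative sketch only; not part of this change.
type Resource[ClientT any] struct {
	ResourceTypeName string // simple resource name, e.g. "ekscluster"
	BatchSize        int    // maximum identifiers nuked per batch
	Scope            Scope
	Client           ClientT

	InitClient   func(r *Resource[ClientT], cfg any)
	ConfigGetter func(c config.Config) config.ResourceType
	Lister       func(ctx context.Context, client ClientT, scope Scope, cfg config.ResourceType) ([]*string, error)
	Nuker        func(ctx context.Context, client ClientT, scope Scope, resourceType string, identifiers []*string) error
}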
+func NewEKSClusters() AwsResource { + return NewAwsResource(&resource.Resource[EKSClustersAPI]{ + ResourceTypeName: "ekscluster", + // Tentative batch size to ensure AWS doesn't throttle. Note that deleting EKS clusters involves deleting many + // associated sub resources in tight loops, and they happen in parallel in go routines. We conservatively pick 10 + // here, both to limit overloading the runtime and to avoid AWS throttling with many API calls. + BatchSize: 10, + InitClient: func(r *resource.Resource[EKSClustersAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for EKS client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = eks.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.EKSCluster + }, + Lister: listEKSClusters, + Nuker: deleteEKSClusters, + }) +} + +// listEKSClusters retrieves all EKS clusters that match the config filters. +func listEKSClusters(ctx context.Context, client EKSClustersAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + result, err := client.ListClusters(ctx, &eks.ListClustersInput{}) if err != nil { return nil, errors.WithStackTrace(err) } - filteredClusters, err := clusters.filter(aws.StringSlice(result.Clusters), configObj) + filteredClusters, err := filterEKSClusters(ctx, client, aws.StringSlice(result.Clusters), cfg) if err != nil { return nil, errors.WithStackTrace(err) } return filteredClusters, nil } -// filter will take in the list of clusters and filter out any clusters that were created after -// `excludeAfter`, and those that are excluded by the config file. -func (clusters *EKSClusters) filter(clusterNames []*string, configObj config.Config) ([]*string, error) { +// filterEKSClusters filters EKS clusters based on the config. +func filterEKSClusters(ctx context.Context, client EKSClustersAPI, clusterNames []*string, cfg config.ResourceType) ([]*string, error) { var filteredEksClusterNames []*string for _, clusterName := range clusterNames { - describeResult, err := clusters.Client.DescribeCluster( - clusters.Context, - &eks.DescribeClusterInput{Name: clusterName}) + describeResult, err := client.DescribeCluster(ctx, &eks.DescribeClusterInput{Name: clusterName}) if err != nil { return nil, errors.WithStackTrace(err) } - if !configObj.EKSCluster.ShouldInclude(config.ResourceValue{ + if !cfg.ShouldInclude(config.ResourceValue{ Name: clusterName, Time: describeResult.Cluster.CreatedAt, Tags: describeResult.Cluster.Tags, @@ -53,10 +89,60 @@ func (clusters *EKSClusters) filter(clusterNames []*string, configObj config.Con return filteredEksClusterNames, nil } -// deleteAsync deletes the provided EKS Cluster asynchronously in a goroutine, using wait groups for +// deleteEKSClusters is a custom nuker function for EKS clusters. +// EKS clusters require deleting sub-resources (node groups, fargate profiles) before deletion. +func deleteEKSClusters(ctx context.Context, client EKSClustersAPI, scope resource.Scope, resourceType string, identifiers []*string) error { + numNuking := len(identifiers) + if numNuking == 0 { + logging.Debugf("No EKS clusters to nuke in region %s", scope.Region) + return nil + } + + // NOTE: we don't need to do pagination here, because the pagination is handled by the caller to this function, + // based on EKSCluster.MaxBatchSize, however we add a guard here to warn users when the batching fails and has a + // chance of throttling AWS. 
Since we concurrently make one call for each identifier, we pick 100 for the limit here + // because many APIs in AWS have a limit of 100 requests per second. + if numNuking > 100 { + logging.Debugf("Nuking too many EKS Clusters at once (100): halting to avoid hitting AWS API rate limiting") + return TooManyEKSClustersErr{} + } + + // We need to delete sub-resources associated with the EKS Cluster before being able to delete the cluster, so we + // spawn goroutines to drive the deletion of each EKS cluster. + logging.Debugf("Deleting %d EKS clusters in region %s", numNuking, scope.Region) + wg := new(sync.WaitGroup) + wg.Add(numNuking) + errChans := make([]chan error, numNuking) + for i, eksClusterName := range identifiers { + errChans[i] = make(chan error, 1) + go deleteEKSClusterAsync(ctx, client, wg, errChans[i], aws.ToString(eksClusterName)) + } + wg.Wait() + + // Collect all the errors from the async delete calls into a single error struct. + var allErrs *multierror.Error + for _, errChan := range errChans { + if err := <-errChan; err != nil { + allErrs = multierror.Append(allErrs, err) + logging.Debugf("[Failed] %s", err) + } + } + finalErr := allErrs.ErrorOrNil() + if finalErr != nil { + return errors.WithStackTrace(finalErr) + } + + // Now wait until the EKS Clusters are deleted + successfullyDeleted := waitUntilEksClustersDeleted(ctx, client, resourceType, identifiers) + numNuked := len(successfullyDeleted) + logging.Debugf("[OK] %d of %d EKS cluster(s) deleted in %s", numNuked, numNuking, scope.Region) + return nil +} + +// deleteEKSClusterAsync deletes the provided EKS Cluster asynchronously in a goroutine, using wait groups for // concurrency control and a return channel for errors. Note that this routine attempts to delete all managed compute // resources associated with the EKS cluster (Managed Node Groups and Fargate Profiles). -func (clusters *EKSClusters) deleteAsync(wg *sync.WaitGroup, errChan chan error, eksClusterName string) { +func deleteEKSClusterAsync(ctx context.Context, client EKSClustersAPI, wg *sync.WaitGroup, errChan chan error, eksClusterName string) { defer wg.Done() // Aggregate errors for each subresource being deleted @@ -65,23 +151,22 @@ func (clusters *EKSClusters) deleteAsync(wg *sync.WaitGroup, errChan chan error, // Since deleting node groups can take some time, we first schedule the deletion of them, and then move on to // deleting the Fargate profiles so that they can be done in parallel, before waiting for the node groups to be // deleted. - deletedNodeGroups, err := clusters.scheduleDeleteEKSClusterManagedNodeGroup(eksClusterName) + deletedNodeGroups, err := scheduleDeleteEKSClusterManagedNodeGroup(ctx, client, eksClusterName) if err != nil { allSubResourceErrs = multierror.Append(allSubResourceErrs, err) } - if err := clusters.deleteEKSClusterFargateProfiles(eksClusterName); err != nil { + if err := deleteEKSClusterFargateProfiles(ctx, client, eksClusterName); err != nil { allSubResourceErrs = multierror.Append(allSubResourceErrs, err) } // Make sure the node groups are actually deleted before returning. 
for _, nodeGroup := range deletedNodeGroups { - - waiter := eks.NewNodegroupDeletedWaiter(clusters.Client) - err := waiter.Wait(clusters.Context, &eks.DescribeNodegroupInput{ + waiter := eks.NewNodegroupDeletedWaiter(client) + err := waiter.Wait(ctx, &eks.DescribeNodegroupInput{ ClusterName: aws.String(eksClusterName), NodegroupName: nodeGroup, - }, clusters.Timeout) + }, DefaultWaitTimeout) if err != nil { logging.Debugf("[Failed] Failed waiting for Node Group %s associated with cluster %s to be deleted: %s", aws.ToString(nodeGroup), eksClusterName, err) allSubResourceErrs = multierror.Append(allSubResourceErrs, err) @@ -96,7 +181,7 @@ func (clusters *EKSClusters) deleteAsync(wg *sync.WaitGroup, errChan chan error, // At this point, all the sub resources of the EKS cluster has been confirmed to be deleted, so we can go ahead to // request to delete the EKS cluster. - _, deleteErr := clusters.Client.DeleteCluster(clusters.Context, &eks.DeleteClusterInput{Name: aws.String(eksClusterName)}) + _, deleteErr := client.DeleteCluster(ctx, &eks.DeleteClusterInput{Name: aws.String(eksClusterName)}) if deleteErr != nil { logging.Debugf("[Failed] Failed deleting EKS cluster %s: %s", eksClusterName, deleteErr) } @@ -106,14 +191,14 @@ func (clusters *EKSClusters) deleteAsync(wg *sync.WaitGroup, errChan chan error, // scheduleDeleteEKSClusterManagedNodeGroup looks up all the associated Managed Node Group resources on the EKS cluster // and requests each one to be deleted. Note that this function will not wait for the Node Groups to be deleted. This // will return the list of Node Groups that were successfully scheduled for deletion. -func (clusters *EKSClusters) scheduleDeleteEKSClusterManagedNodeGroup(eksClusterName string) ([]*string, error) { +func scheduleDeleteEKSClusterManagedNodeGroup(ctx context.Context, client EKSClustersAPI, eksClusterName string) ([]*string, error) { var allNodeGroups []*string - paginator := eks.NewListNodegroupsPaginator(clusters.Client, &eks.ListNodegroupsInput{ + paginator := eks.NewListNodegroupsPaginator(client, &eks.ListNodegroupsInput{ ClusterName: aws.String(eksClusterName), }) for paginator.HasMorePages() { - page, err := paginator.NextPage(context.Background()) + page, err := paginator.NextPage(ctx) if err != nil { return nil, errors.WithStackTrace(err) } @@ -128,12 +213,10 @@ func (clusters *EKSClusters) scheduleDeleteEKSClusterManagedNodeGroup(eksCluster var allDeleteErrs error var deletedNodeGroups []*string for _, nodeGroup := range allNodeGroups { - _, err := clusters.Client.DeleteNodegroup( - clusters.Context, - &eks.DeleteNodegroupInput{ - ClusterName: aws.String(eksClusterName), - NodegroupName: nodeGroup, - }) + _, err := client.DeleteNodegroup(ctx, &eks.DeleteNodegroupInput{ + ClusterName: aws.String(eksClusterName), + NodegroupName: nodeGroup, + }) if err != nil { logging.Debugf("[Failed] Failed deleting Node Group %s associated with cluster %s: %s", aws.ToString(nodeGroup), eksClusterName, err) allDeleteErrs = multierror.Append(allDeleteErrs, err) @@ -147,12 +230,12 @@ func (clusters *EKSClusters) scheduleDeleteEKSClusterManagedNodeGroup(eksCluster // deleteEKSClusterFargateProfiles looks up all the associated Fargate Profile resources on the EKS cluster and requests // each one to be deleted. Since only one Fargate Profile can be deleted at a time, this function will wait until the // Fargate Profile is actually deleted for each one before moving on to the next one. 
-func (clusters *EKSClusters) deleteEKSClusterFargateProfiles(eksClusterName string) error { +func deleteEKSClusterFargateProfiles(ctx context.Context, client EKSClustersAPI, eksClusterName string) error { var allFargateProfiles []*string - paginator := eks.NewListFargateProfilesPaginator(clusters.Client, &eks.ListFargateProfilesInput{ClusterName: aws.String(eksClusterName)}) + paginator := eks.NewListFargateProfilesPaginator(client, &eks.ListFargateProfilesInput{ClusterName: aws.String(eksClusterName)}) for paginator.HasMorePages() { - page, err := paginator.NextPage(context.Background()) + page, err := paginator.NextPage(ctx) if err != nil { return errors.WithStackTrace(err) } @@ -168,23 +251,21 @@ func (clusters *EKSClusters) deleteEKSClusterFargateProfiles(eksClusterName stri // Note that we aggregate deletion errors so that we at least attempt to delete all of them once. var allDeleteErrs error for _, fargateProfile := range allFargateProfiles { - _, err := clusters.Client.DeleteFargateProfile( - clusters.Context, - &eks.DeleteFargateProfileInput{ - ClusterName: aws.String(eksClusterName), - FargateProfileName: fargateProfile, - }) + _, err := client.DeleteFargateProfile(ctx, &eks.DeleteFargateProfileInput{ + ClusterName: aws.String(eksClusterName), + FargateProfileName: fargateProfile, + }) if err != nil { logging.Debugf("[Failed] Failed deleting Fargate Profile %s associated with cluster %s: %s", aws.ToString(fargateProfile), eksClusterName, err) allDeleteErrs = multierror.Append(allDeleteErrs, err) continue } - waiter := eks.NewFargateProfileDeletedWaiter(clusters.Client) - waitErr := waiter.Wait(clusters.Context, &eks.DescribeFargateProfileInput{ + waiter := eks.NewFargateProfileDeletedWaiter(client) + waitErr := waiter.Wait(ctx, &eks.DescribeFargateProfileInput{ ClusterName: aws.String(eksClusterName), FargateProfileName: fargateProfile, - }, clusters.Timeout) + }, DefaultWaitTimeout) if waitErr != nil { logging.Debugf("[Failed] Failed waiting for Fargate Profile %s associated with cluster %s to be deleted: %s", aws.ToString(fargateProfile), eksClusterName, waitErr) allDeleteErrs = multierror.Append(allDeleteErrs, waitErr) @@ -198,19 +279,18 @@ func (clusters *EKSClusters) deleteEKSClusterFargateProfiles(eksClusterName stri // waitUntilEksClustersDeleted waits until the EKS cluster has been actually deleted from AWS. Returns a list of EKS // cluster names that have been successfully deleted. -func (clusters *EKSClusters) waitUntilEksClustersDeleted(eksClusterNames []*string) []*string { +func waitUntilEksClustersDeleted(ctx context.Context, client EKSClustersAPI, resourceType string, eksClusterNames []*string) []*string { var successfullyDeleted []*string for _, eksClusterName := range eksClusterNames { - - waiter := eks.NewClusterDeletedWaiter(clusters.Client) - err := waiter.Wait(clusters.Context, &eks.DescribeClusterInput{ + waiter := eks.NewClusterDeletedWaiter(client) + err := waiter.Wait(ctx, &eks.DescribeClusterInput{ Name: eksClusterName, - }, clusters.Timeout) + }, DefaultWaitTimeout) // Record status of this resource e := report.Entry{ Identifier: aws.ToString(eksClusterName), - ResourceType: "EKS Cluster", + ResourceType: resourceType, Error: err, } report.Record(e) @@ -225,55 +305,6 @@ func (clusters *EKSClusters) waitUntilEksClustersDeleted(eksClusterNames []*stri return successfullyDeleted } -// nukeAll deletes all provided EKS clusters, waiting for them to be deleted before returning. 
-func (clusters *EKSClusters) nukeAll(eksClusterNames []*string) error { - numNuking := len(eksClusterNames) - if numNuking == 0 { - logging.Debugf("No EKS clusters to nuke in region %s", clusters.Region) - return nil - } - - // NOTE: we don't need to do pagination here, because the pagination is handled by the caller to this function, - // based on EKSCluster.MaxBatchSize, however we add a guard here to warn users when the batching fails and has a - // chance of throttling AWS. Since we concurrently make one call for each identifier, we pick 100 for the limit here - // because many APIs in AWS have a limit of 100 requests per second. - if numNuking > 100 { - logging.Debugf("Nuking too many EKS Clusters at once (100): halting to avoid hitting AWS API rate limiting") - return TooManyEKSClustersErr{} - } - - // We need to delete sub-resources associated with the EKS Cluster before being able to delete the cluster, so we - // spawn goroutines to drive the deletion of each EKS cluster. - logging.Debugf("Deleting %d EKS clusters in region %s", numNuking, clusters.Region) - wg := new(sync.WaitGroup) - wg.Add(numNuking) - errChans := make([]chan error, numNuking) - for i, eksClusterName := range eksClusterNames { - errChans[i] = make(chan error, 1) - go clusters.deleteAsync(wg, errChans[i], aws.ToString(eksClusterName)) - } - wg.Wait() - - // Collect all the errors from the async delete calls into a single error struct. - var allErrs *multierror.Error - for _, errChan := range errChans { - if err := <-errChan; err != nil { - allErrs = multierror.Append(allErrs, err) - logging.Debugf("[Failed] %s", err) - } - } - finalErr := allErrs.ErrorOrNil() - if finalErr != nil { - return errors.WithStackTrace(finalErr) - } - - // Now wait until the EKS Clusters are deleted - successfullyDeleted := clusters.waitUntilEksClustersDeleted(eksClusterNames) - numNuked := len(successfullyDeleted) - logging.Debugf("[OK] %d of %d EKS cluster(s) deleted in %s", numNuked, numNuking, clusters.Region) - return nil -} - // Custom errors type TooManyEKSClustersErr struct{} diff --git a/aws/resources/eks_test.go b/aws/resources/eks_test.go index 32a0ad60..b4183b1a 100644 --- a/aws/resources/eks_test.go +++ b/aws/resources/eks_test.go @@ -2,6 +2,7 @@ package resources import ( "context" + "fmt" "regexp" "testing" "time" @@ -9,7 +10,9 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/eks" "github.com/aws/aws-sdk-go-v2/service/eks/types" + "github.com/aws/smithy-go" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) @@ -19,6 +22,7 @@ type mockedEKSCluster struct { DeleteFargateProfileOutput eks.DeleteFargateProfileOutput DeleteNodegroupOutput eks.DeleteNodegroupOutput DescribeClusterOutputByName map[string]*eks.DescribeClusterOutput + DescribeClusterError error // Error to return for DescribeCluster (simulates deleted cluster) DescribeFargateProfileOutput eks.DescribeFargateProfileOutput DescribeNodegroupOutput eks.DescribeNodegroupOutput ListClustersOutput eks.ListClustersOutput @@ -39,6 +43,9 @@ func (m mockedEKSCluster) DeleteNodegroup(ctx context.Context, params *eks.Delet } func (m mockedEKSCluster) DescribeCluster(ctx context.Context, params *eks.DescribeClusterInput, optFns ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) { + if m.DescribeClusterError != nil { + return nil, m.DescribeClusterError + } return m.DescribeClusterOutputByName[aws.ToString(params.Name)], nil } @@ -61,39 +68,49 @@ func (m 
mockedEKSCluster) ListFargateProfiles(ctx context.Context, params *eks.L func (m mockedEKSCluster) ListNodegroups(ctx context.Context, params *eks.ListNodegroupsInput, optFns ...func(*eks.Options)) (*eks.ListNodegroupsOutput, error) { return &m.ListNodegroupsOutput, nil } -func TestEKSClusterGetAll(t *testing.T) { + +func TestEKSClusters_ResourceName(t *testing.T) { + r := NewEKSClusters() + require.Equal(t, "ekscluster", r.ResourceName()) +} + +func TestEKSClusters_MaxBatchSize(t *testing.T) { + r := NewEKSClusters() + require.Equal(t, 10, r.MaxBatchSize()) +} + +func TestListEKSClusters(t *testing.T) { t.Parallel() testClusterName1 := "test_cluster1" testClusterName2 := "test_cluster2" testClusterName3 := "test_cluster3" now := time.Now() - eksCluster := EKSClusters{ - Client: mockedEKSCluster{ - ListClustersOutput: eks.ListClustersOutput{ - Clusters: []string{testClusterName1, testClusterName2, testClusterName3}, - }, - DescribeClusterOutputByName: map[string]*eks.DescribeClusterOutput{ - testClusterName1: { - Cluster: &types.Cluster{ - Name: aws.String(testClusterName1), - CreatedAt: &now, - Tags: map[string]string{"foo": "bar"}, - }, + + mock := mockedEKSCluster{ + ListClustersOutput: eks.ListClustersOutput{ + Clusters: []string{testClusterName1, testClusterName2, testClusterName3}, + }, + DescribeClusterOutputByName: map[string]*eks.DescribeClusterOutput{ + testClusterName1: { + Cluster: &types.Cluster{ + Name: aws.String(testClusterName1), + CreatedAt: &now, + Tags: map[string]string{"foo": "bar"}, }, - testClusterName2: { - Cluster: &types.Cluster{ - Name: aws.String(testClusterName1), - CreatedAt: &now, - Tags: map[string]string{"foz": "boz"}, - }, + }, + testClusterName2: { + Cluster: &types.Cluster{ + Name: aws.String(testClusterName1), + CreatedAt: &now, + Tags: map[string]string{"foz": "boz"}, }, - testClusterName3: { - Cluster: &types.Cluster{ - Name: aws.String(testClusterName3), - CreatedAt: &now, - Tags: map[string]string{"faz": "baz"}, - }, + }, + testClusterName3: { + Cluster: &types.Cluster{ + Name: aws.String(testClusterName3), + CreatedAt: &now, + Tags: map[string]string{"faz": "baz"}, }, }, }, @@ -119,7 +136,7 @@ func TestEKSClusterGetAll(t *testing.T) { "tagInclusionFilter": { configObj: config.ResourceType{ IncludeRule: config.FilterRule{ - Tags: map[string]config.Expression{"foo": config.Expression{RE: *regexp.MustCompile("bar")}}, + Tags: map[string]config.Expression{"foo": {RE: *regexp.MustCompile("bar")}}, }, }, expected: []string{testClusterName1}, @@ -127,7 +144,7 @@ func TestEKSClusterGetAll(t *testing.T) { "tagExclusionFilter": { configObj: config.ResourceType{ ExcludeRule: config.FilterRule{ - Tags: map[string]config.Expression{"foo": config.Expression{RE: *regexp.MustCompile("bar")}}, + Tags: map[string]config.Expression{"foo": {RE: *regexp.MustCompile("bar")}}, }, }, expected: []string{testClusterName2, testClusterName3}, @@ -136,41 +153,48 @@ func TestEKSClusterGetAll(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := eksCluster.getAll(context.Background(), config.Config{ - EKSCluster: tc.configObj, - }) - + names, err := listEKSClusters(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } } -func TestEKSClusterNukeAll(t *testing.T) { +// errMockEKSClusterNotFound simulates ResourceNotFoundException for EKS +type errMockEKSClusterNotFound struct{} + +func (e errMockEKSClusterNotFound) Error() string { + return 
fmt.Sprintf("%s: %s", e.ErrorCode(), e.ErrorMessage()) +} + +func (e errMockEKSClusterNotFound) ErrorCode() string { + return "ResourceNotFoundException" +} + +func (e errMockEKSClusterNotFound) ErrorMessage() string { + return "The specified cluster does not exist." +} + +func (e errMockEKSClusterNotFound) ErrorFault() smithy.ErrorFault { + return smithy.FaultClient +} + +func TestDeleteEKSClusters(t *testing.T) { t.Parallel() testClusterName := "test_cluster1" - eksCluster := EKSClusters{ - BaseAwsResource: BaseAwsResource{ - Context: context.Background(), - }, - Client: mockedEKSCluster{ - ListNodegroupsOutput: eks.ListNodegroupsOutput{}, - DescribeClusterOutputByName: map[string]*eks.DescribeClusterOutput{ - testClusterName: { - Cluster: &types.Cluster{ - Name: aws.String(testClusterName), - CreatedAt: aws.Time(time.Now()), - }, - }, - }, - ListFargateProfilesOutput: eks.ListFargateProfilesOutput{}, - DescribeNodegroupOutput: eks.DescribeNodegroupOutput{}, - DeleteFargateProfileOutput: eks.DeleteFargateProfileOutput{}, - DeleteClusterOutput: eks.DeleteClusterOutput{}, - DescribeFargateProfileOutput: eks.DescribeFargateProfileOutput{}, - }, + + // Mock returns ResourceNotFoundException to simulate the cluster being deleted + // This is required for the SDK waiter to succeed + mock := mockedEKSCluster{ + ListNodegroupsOutput: eks.ListNodegroupsOutput{}, + ListFargateProfilesOutput: eks.ListFargateProfilesOutput{}, + DescribeNodegroupOutput: eks.DescribeNodegroupOutput{}, + DeleteFargateProfileOutput: eks.DeleteFargateProfileOutput{}, + DeleteClusterOutput: eks.DeleteClusterOutput{}, + DescribeFargateProfileOutput: eks.DescribeFargateProfileOutput{}, + DescribeClusterError: errMockEKSClusterNotFound{}, // Simulate deleted cluster } - err := eksCluster.nukeAll([]*string{&testClusterName}) + err := deleteEKSClusters(context.Background(), mock, resource.Scope{Region: "us-east-1"}, "ekscluster", []*string{&testClusterName}) require.NoError(t, err) } diff --git a/aws/resources/eks_types.go b/aws/resources/eks_types.go deleted file mode 100644 index 3435d8ff..00000000 --- a/aws/resources/eks_types.go +++ /dev/null @@ -1,73 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/eks" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type EKSClustersAPI interface { - DeleteCluster(ctx context.Context, params *eks.DeleteClusterInput, optFns ...func(*eks.Options)) (*eks.DeleteClusterOutput, error) - DeleteFargateProfile(ctx context.Context, params *eks.DeleteFargateProfileInput, optFns ...func(*eks.Options)) (*eks.DeleteFargateProfileOutput, error) - DeleteNodegroup(ctx context.Context, params *eks.DeleteNodegroupInput, optFns ...func(*eks.Options)) (*eks.DeleteNodegroupOutput, error) - DescribeCluster(ctx context.Context, params *eks.DescribeClusterInput, optFns ...func(*eks.Options)) (*eks.DescribeClusterOutput, error) - DescribeFargateProfile(ctx context.Context, params *eks.DescribeFargateProfileInput, optFns ...func(*eks.Options)) (*eks.DescribeFargateProfileOutput, error) - DescribeNodegroup(ctx context.Context, params *eks.DescribeNodegroupInput, optFns ...func(*eks.Options)) (*eks.DescribeNodegroupOutput, error) - ListClusters(ctx context.Context, params *eks.ListClustersInput, optFns ...func(*eks.Options)) (*eks.ListClustersOutput, error) - ListFargateProfiles(ctx context.Context, params *eks.ListFargateProfilesInput, optFns ...func(*eks.Options)) 
(*eks.ListFargateProfilesOutput, error) - ListNodegroups(ctx context.Context, params *eks.ListNodegroupsInput, optFns ...func(*eks.Options)) (*eks.ListNodegroupsOutput, error) -} - -// EKSClusters - Represents all EKS clusters found in a region -type EKSClusters struct { - BaseAwsResource - Client EKSClustersAPI - Region string - Clusters []string -} - -func (clusters *EKSClusters) Init(cfg aws.Config) { - clusters.Client = eks.NewFromConfig(cfg) -} - -// ResourceName - The simple name of the aws resource -func (clusters *EKSClusters) ResourceName() string { - return "ekscluster" -} - -// ResourceIdentifiers - The Name of the collected EKS clusters -func (clusters *EKSClusters) ResourceIdentifiers() []string { - return clusters.Clusters -} - -func (clusters *EKSClusters) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.EKSCluster -} - -func (clusters *EKSClusters) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle. Note that deleting EKS clusters involves deleting many - // associated sub resources in tight loops, and they happen in parallel in go routines. We conservatively pick 10 - // here, both to limit overloading the runtime and to avoid AWS throttling with many API calls. - return 10 -} - -func (clusters *EKSClusters) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := clusters.getAll(c, configObj) - if err != nil { - return nil, err - } - - clusters.Clusters = aws.ToStringSlice(identifiers) - return clusters.Clusters, nil -} - -// Nuke - nuke all EKS Cluster resources -func (clusters *EKSClusters) Nuke(ctx context.Context, identifiers []string) error { - if err := clusters.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - return nil -} diff --git a/aws/resources/elasticache.go b/aws/resources/elasticache.go index af390f21..88ce588f 100644 --- a/aws/resources/elasticache.go +++ b/aws/resources/elasticache.go @@ -4,21 +4,52 @@ import ( "context" "errors" "fmt" - "strings" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/elasticache" "github.com/aws/aws-sdk-go-v2/service/elasticache/types" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" goerrors "github.com/gruntwork-io/go-commons/errors" ) -// Returns a formatted string of Elasticache cluster Ids -func (cache *Elasticaches) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// ElasticachesAPI defines the interface for Elasticache operations. +type ElasticachesAPI interface { + DescribeReplicationGroups(ctx context.Context, params *elasticache.DescribeReplicationGroupsInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeReplicationGroupsOutput, error) + DescribeCacheClusters(ctx context.Context, params *elasticache.DescribeCacheClustersInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeCacheClustersOutput, error) + DeleteCacheCluster(ctx context.Context, params *elasticache.DeleteCacheClusterInput, optFns ...func(*elasticache.Options)) (*elasticache.DeleteCacheClusterOutput, error) + DeleteReplicationGroup(ctx context.Context, params *elasticache.DeleteReplicationGroupInput, optFns ...func(*elasticache.Options)) (*elasticache.DeleteReplicationGroupOutput, error) +} + +// NewElasticaches creates a new Elasticaches resource using the generic resource pattern. 
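Unlike the EFS and EKS nukers above, the Elasticache constructor below does not hand-roll its batch loop: it wraps a per-cluster delete function with resource.SequentialDeleter, which is also not shown in this diff. The following is a minimal sketch of what that helper is assumed to do, reusing the Scope type sketched earlier and the same report/multierror conventions the hand-written nukers use; the actual implementation may differ.

package resource

import (
	"context"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/gruntwork-io/cloud-nuke/report"
	"github.com/hashicorp/go-multierror"
)

// SequentialDeleter adapts a per-item delete function into a batch Nuker that
// deletes identifiers one at a time and records the outcome of each deletion.
// Illustrative sketch only; not part of this change.
func SequentialDeleter[ClientT any](
	deleteOne func(ctx context.Context, client ClientT, id *string) error,
) func(ctx context.Context, client ClientT, scope Scope, resourceType string, identifiers []*string) error {
	return func(ctx context.Context, client ClientT, scope Scope, resourceType string, identifiers []*string) error {
		var allErrs *multierror.Error
		for _, id := range identifiers {
			err := deleteOne(ctx, client, id)

			// Record status of this resource, mirroring the report.Entry records
			// emitted by the hand-written nukers in this package.
			report.Record(report.Entry{
				Identifier:   aws.ToString(id),
				ResourceType: resourceType,
				Error:        err,
			})

			if err != nil {
				allErrs = multierror.Append(allErrs, err)
			}
		}
		return allErrs.ErrorOrNil()
	}
}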
+func NewElasticaches() AwsResource { + return NewAwsResource(&resource.Resource[ElasticachesAPI]{ + ResourceTypeName: "elasticache", + // Tentative batch size to ensure AWS doesn't throttle + BatchSize: 49, + InitClient: func(r *resource.Resource[ElasticachesAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for Elasticache client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = elasticache.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.Elasticache + }, + Lister: listElasticaches, + // Use SequentialDeleter since each deletion involves waiters + Nuker: resource.SequentialDeleter(deleteElasticacheCluster), + }) +} + +// listElasticaches retrieves all Elasticache clusters that match the config filters. +func listElasticaches(ctx context.Context, client ElasticachesAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { // First, get any cache clusters that are replication groups, which will be the case for all multi-node Redis clusters - replicationGroupsResult, replicationGroupsErr := cache.Client.DescribeReplicationGroups(cache.Context, &elasticache.DescribeReplicationGroupsInput{}) + replicationGroupsResult, replicationGroupsErr := client.DescribeReplicationGroups(ctx, &elasticache.DescribeReplicationGroupsInput{}) if replicationGroupsErr != nil { return nil, goerrors.WithStackTrace(replicationGroupsErr) } @@ -26,18 +57,16 @@ func (cache *Elasticaches) getAll(c context.Context, configObj config.Config) ([ // Next, get any cache clusters that are not members of a replication group: meaning: // 1. any cache clusters with an Engine of "memcached" // 2. any single node Redis clusters - cacheClustersResult, cacheClustersErr := cache.Client.DescribeCacheClusters( - cache.Context, - &elasticache.DescribeCacheClustersInput{ - ShowCacheClustersNotInReplicationGroups: aws.Bool(true), - }) + cacheClustersResult, cacheClustersErr := client.DescribeCacheClusters(ctx, &elasticache.DescribeCacheClustersInput{ + ShowCacheClustersNotInReplicationGroups: aws.Bool(true), + }) if cacheClustersErr != nil { return nil, goerrors.WithStackTrace(cacheClustersErr) } var clusterIds []*string for _, replicationGroup := range replicationGroupsResult.ReplicationGroups { - if configObj.Elasticache.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Name: replicationGroup.ReplicationGroupId, Time: replicationGroup.ReplicationGroupCreateTime, }) { @@ -46,7 +75,7 @@ func (cache *Elasticaches) getAll(c context.Context, configObj config.Config) ([ } for _, cluster := range cacheClustersResult.CacheClusters { - if configObj.Elasticache.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Name: cluster.CacheClusterId, Time: cluster.CacheClusterCreateTime, }) { @@ -64,12 +93,12 @@ const ( Single CacheClusterType = "single" ) -func (cache *Elasticaches) determineCacheClusterType(clusterId *string) (*string, CacheClusterType, error) { +func determineCacheClusterType(ctx context.Context, client ElasticachesAPI, clusterId *string) (*string, CacheClusterType, error) { replicationGroupDescribeParams := &elasticache.DescribeReplicationGroupsInput{ ReplicationGroupId: clusterId, } - replicationGroupOutput, describeReplicationGroupsErr := cache.Client.DescribeReplicationGroups(cache.Context, replicationGroupDescribeParams) + replicationGroupOutput, describeReplicationGroupsErr := client.DescribeReplicationGroups(ctx, 
replicationGroupDescribeParams) if describeReplicationGroupsErr != nil { // GlobalReplicationGroupNotFoundFault var eRG404 *types.ReplicationGroupNotFoundFault @@ -91,7 +120,7 @@ func (cache *Elasticaches) determineCacheClusterType(clusterId *string) (*string CacheClusterId: clusterId, } - cacheClustersOutput, describeErr := cache.Client.DescribeCacheClusters(cache.Context, describeParams) + cacheClustersOutput, describeErr := client.DescribeCacheClusters(ctx, describeParams) if describeErr != nil { var eC404 *types.CacheClusterNotFoundFault if errors.As(describeErr, &eC404) { @@ -108,44 +137,36 @@ func (cache *Elasticaches) determineCacheClusterType(clusterId *string) (*string return nil, Single, CouldNotLookupCacheClusterErr{ClusterId: clusterId} } -func (cache *Elasticaches) nukeNonReplicationGroupElasticacheCluster(clusterId *string) error { +func nukeNonReplicationGroupElasticacheCluster(ctx context.Context, client ElasticachesAPI, clusterId *string) error { logging.Debugf("Deleting Elasticache cluster Id: %s which is not a member of a replication group", aws.ToString(clusterId)) params := elasticache.DeleteCacheClusterInput{ CacheClusterId: clusterId, } - _, err := cache.Client.DeleteCacheCluster(cache.Context, ¶ms) + _, err := client.DeleteCacheCluster(ctx, ¶ms) if err != nil { return err } - waiter := elasticache.NewCacheClusterDeletedWaiter(cache.Client) + waiter := elasticache.NewCacheClusterDeletedWaiter(client) - return waiter.Wait( - cache.Context, - &elasticache.DescribeCacheClustersInput{ - CacheClusterId: clusterId, - }, - cache.Timeout, - ) + return waiter.Wait(ctx, &elasticache.DescribeCacheClustersInput{ + CacheClusterId: clusterId, + }, DefaultWaitTimeout) } -func (cache *Elasticaches) nukeReplicationGroupMemberElasticacheCluster(clusterId *string) error { +func nukeReplicationGroupMemberElasticacheCluster(ctx context.Context, client ElasticachesAPI, clusterId *string) error { logging.Debugf("Elasticache cluster Id: %s is a member of a replication group. Therefore, deleting its replication group", aws.ToString(clusterId)) params := &elasticache.DeleteReplicationGroupInput{ ReplicationGroupId: clusterId, } - _, err := cache.Client.DeleteReplicationGroup(cache.Context, params) + _, err := client.DeleteReplicationGroup(ctx, params) if err != nil { return err } - waiter := elasticache.NewReplicationGroupDeletedWaiter(cache.Client) - waitErr := waiter.Wait( - cache.Context, - &elasticache.DescribeReplicationGroupsInput{ReplicationGroupId: clusterId}, - cache.Timeout, - ) + waiter := elasticache.NewReplicationGroupDeletedWaiter(client) + waitErr := waiter.Wait(ctx, &elasticache.DescribeReplicationGroupsInput{ReplicationGroupId: clusterId}, DefaultWaitTimeout) if waitErr != nil { return waitErr @@ -156,50 +177,23 @@ func (cache *Elasticaches) nukeReplicationGroupMemberElasticacheCluster(clusterI return nil } -func (cache *Elasticaches) nukeAll(clusterIds []*string) error { - if len(clusterIds) == 0 { - logging.Debugf("No Elasticache clusters to nuke in region %s", cache.Region) - return nil +// deleteElasticacheCluster deletes a single Elasticache cluster. +// It determines whether the cluster is standalone or part of a replication group +// and calls the appropriate delete function. +func deleteElasticacheCluster(ctx context.Context, client ElasticachesAPI, clusterId *string) error { + // We need to look up the cache cluster to determine if it is a member of a replication group or not, + // because there are two separate codepaths for deleting a cluster. 
Cache clusters that are not members of a + // replication group can be deleted via DeleteCacheCluster, whereas those that are members require a call to + // DeleteReplicationGroup, which will destroy both the replication group and its member clusters + resolvedClusterId, clusterType, err := determineCacheClusterType(ctx, client, clusterId) + if err != nil { + return err } - logging.Debugf("Deleting %d Elasticache clusters in region %s", len(clusterIds), cache.Region) - - var deletedClusterIds []*string - for _, clusterId := range clusterIds { - // We need to look up the cache cluster again to determine if it is a member of a replication group or not, - // because there are two separate codepaths for deleting a cluster. Cache clusters that are not members of a - // replication group can be deleted via DeleteCacheCluster, whereas those that are members require a call to - // DeleteReplicationGroup, which will destroy both the replication group and its member clusters - clusterId, clusterType, describeErr := cache.determineCacheClusterType(clusterId) - if describeErr != nil { - return describeErr - } - - var err error - if clusterType == Single { - err = cache.nukeNonReplicationGroupElasticacheCluster(clusterId) - } else if clusterType == Replication { - err = cache.nukeReplicationGroupMemberElasticacheCluster(clusterId) - } - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(clusterId), - ResourceType: "Elasticache", - Error: err, - } - report.Record(e) - - if err != nil { - logging.Debugf("[Failed] %s", err) - } else { - deletedClusterIds = append(deletedClusterIds, clusterId) - logging.Debugf("Deleted Elasticache cluster: %s", *clusterId) - } + if clusterType == Single { + return nukeNonReplicationGroupElasticacheCluster(ctx, client, resolvedClusterId) } - - logging.Debugf("[OK] %d Elasticache clusters deleted in %s", len(deletedClusterIds), cache.Region) - return nil + return nukeReplicationGroupMemberElasticacheCluster(ctx, client, resolvedClusterId) } // Custom errors @@ -211,121 +205,3 @@ type CouldNotLookupCacheClusterErr struct { func (err CouldNotLookupCacheClusterErr) Error() string { return fmt.Sprintf("Failed to lookup clusterId: %s", aws.ToString(err.ClusterId)) } - -/* -Elasticache Parameter Groups -*/ - -func (pg *ElasticacheParameterGroups) getAll(c context.Context, configObj config.Config) ([]*string, error) { - var paramGroupNames []*string - - paginator := elasticache.NewDescribeCacheParameterGroupsPaginator(pg.Client, &elasticache.DescribeCacheParameterGroupsInput{}) - for paginator.HasMorePages() { - page, err := paginator.NextPage(c) - if err != nil { - return nil, goerrors.WithStackTrace(err) - } - - for _, paramGroup := range page.CacheParameterGroups { - if pg.shouldInclude(¶mGroup, configObj) { - paramGroupNames = append(paramGroupNames, paramGroup.CacheParameterGroupName) - } - } - } - - return paramGroupNames, nil -} - -func (pg *ElasticacheParameterGroups) shouldInclude(paramGroup *types.CacheParameterGroup, configObj config.Config) bool { - if paramGroup == nil { - return false - } - // Exclude AWS managed resources. user defined resources are unable to begin with "default." 
- if strings.HasPrefix(aws.ToString(paramGroup.CacheParameterGroupName), "default.") { - return false - } - - return configObj.ElasticacheParameterGroups.ShouldInclude(config.ResourceValue{ - Name: paramGroup.CacheParameterGroupName, - }) -} - -func (pg *ElasticacheParameterGroups) nukeAll(paramGroupNames []*string) error { - if len(paramGroupNames) == 0 { - logging.Debugf("No Elasticache parameter groups to nuke in region %s", pg.Region) - return nil - } - var deletedGroupNames []*string - for _, pgName := range paramGroupNames { - _, err := pg.Client.DeleteCacheParameterGroup(pg.Context, &elasticache.DeleteCacheParameterGroupInput{CacheParameterGroupName: pgName}) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(pgName), - ResourceType: "Elasticache Parameter Group", - Error: err, - } - report.Record(e) - if err != nil { - logging.Debugf("[Failed] %s", err) - } else { - deletedGroupNames = append(deletedGroupNames, pgName) - logging.Debugf("Deleted Elasticache parameter group: %s", aws.ToString(pgName)) - } - } - logging.Debugf("[OK] %d Elasticache parameter groups deleted in %s", len(deletedGroupNames), pg.Region) - return nil -} - -/* -Elasticache Subnet Groups -*/ -func (sg *ElasticacheSubnetGroups) getAll(c context.Context, configObj config.Config) ([]*string, error) { - var subnetGroupNames []*string - - paginator := elasticache.NewDescribeCacheSubnetGroupsPaginator(sg.Client, &elasticache.DescribeCacheSubnetGroupsInput{}) - for paginator.HasMorePages() { - page, err := paginator.NextPage(c) - if err != nil { - return nil, goerrors.WithStackTrace(err) - } - - for _, subnetGroup := range page.CacheSubnetGroups { - if !strings.Contains(*subnetGroup.CacheSubnetGroupName, "default") && - configObj.ElasticacheSubnetGroups.ShouldInclude(config.ResourceValue{ - Name: subnetGroup.CacheSubnetGroupName, - }) { - subnetGroupNames = append(subnetGroupNames, subnetGroup.CacheSubnetGroupName) - } - } - } - - return subnetGroupNames, nil -} - -func (sg *ElasticacheSubnetGroups) nukeAll(subnetGroupNames []*string) error { - if len(subnetGroupNames) == 0 { - logging.Debugf("No Elasticache subnet groups to nuke in region %s", sg.Region) - return nil - } - var deletedGroupNames []*string - for _, sgName := range subnetGroupNames { - _, err := sg.Client.DeleteCacheSubnetGroup(sg.Context, &elasticache.DeleteCacheSubnetGroupInput{CacheSubnetGroupName: sgName}) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(sgName), - ResourceType: "Elasticache Subnet Group", - Error: err, - } - report.Record(e) - if err != nil { - logging.Debugf("[Failed] %s", err) - } else { - deletedGroupNames = append(deletedGroupNames, sgName) - logging.Debugf("Deleted Elasticache subnet group: %s", aws.ToString(sgName)) - } - } - logging.Debugf("[OK] %d Elasticache subnet groups deleted in %s", len(deletedGroupNames), sg.Region) - return nil -} diff --git a/aws/resources/elasticache_parameter_group.go b/aws/resources/elasticache_parameter_group.go new file mode 100644 index 00000000..442b3000 --- /dev/null +++ b/aws/resources/elasticache_parameter_group.go @@ -0,0 +1,85 @@ +package resources + +import ( + "context" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/elasticache" + "github.com/aws/aws-sdk-go-v2/service/elasticache/types" + "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/logging" + "github.com/gruntwork-io/cloud-nuke/resource" + "github.com/gruntwork-io/go-commons/errors" 
+) + +// ElasticacheParameterGroupsAPI defines the interface for Elasticache Parameter Group operations. +type ElasticacheParameterGroupsAPI interface { + DescribeCacheParameterGroups(ctx context.Context, params *elasticache.DescribeCacheParameterGroupsInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeCacheParameterGroupsOutput, error) + DeleteCacheParameterGroup(ctx context.Context, params *elasticache.DeleteCacheParameterGroupInput, optFns ...func(*elasticache.Options)) (*elasticache.DeleteCacheParameterGroupOutput, error) +} + +// NewElasticacheParameterGroups creates a new Elasticache Parameter Groups resource using the generic resource pattern. +func NewElasticacheParameterGroups() AwsResource { + return NewAwsResource(&resource.Resource[ElasticacheParameterGroupsAPI]{ + ResourceTypeName: "elasticacheParameterGroups", + BatchSize: 49, + InitClient: func(r *resource.Resource[ElasticacheParameterGroupsAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for Elasticache client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = elasticache.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.ElasticacheParameterGroups + }, + Lister: listElasticacheParameterGroups, + Nuker: resource.SimpleBatchDeleter(deleteElasticacheParameterGroup), + }) +} + +// listElasticacheParameterGroups retrieves all Elasticache parameter groups that match the config filters. +func listElasticacheParameterGroups(ctx context.Context, client ElasticacheParameterGroupsAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + var paramGroupNames []*string + + paginator := elasticache.NewDescribeCacheParameterGroupsPaginator(client, &elasticache.DescribeCacheParameterGroupsInput{}) + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return nil, errors.WithStackTrace(err) + } + + for _, paramGroup := range page.CacheParameterGroups { + if shouldIncludeElasticacheParameterGroup(¶mGroup, cfg) { + paramGroupNames = append(paramGroupNames, paramGroup.CacheParameterGroupName) + } + } + } + + return paramGroupNames, nil +} + +func shouldIncludeElasticacheParameterGroup(paramGroup *types.CacheParameterGroup, cfg config.ResourceType) bool { + if paramGroup == nil { + return false + } + // Exclude AWS managed resources. user defined resources are unable to begin with "default." + if strings.HasPrefix(aws.ToString(paramGroup.CacheParameterGroupName), "default.") { + return false + } + + return cfg.ShouldInclude(config.ResourceValue{ + Name: paramGroup.CacheParameterGroupName, + }) +} + +// deleteElasticacheParameterGroup deletes a single Elasticache parameter group. 
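The new lister above can be exercised with a hand-rolled mock in the same style as the elasticache and ELB tests further down in this diff. The following is a minimal sketch only; the mock type, test name, and group names are illustrative and not part of the change, and it assumes an empty config.ResourceType includes everything by default.

package resources

import (
    "context"
    "testing"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/service/elasticache"
    "github.com/aws/aws-sdk-go-v2/service/elasticache/types"
    "github.com/gruntwork-io/cloud-nuke/config"
    "github.com/gruntwork-io/cloud-nuke/resource"
    "github.com/stretchr/testify/require"
)

// mockedParameterGroups is an illustrative mock of ElasticacheParameterGroupsAPI.
type mockedParameterGroups struct {
    DescribeCacheParameterGroupsOutput elasticache.DescribeCacheParameterGroupsOutput
}

func (m mockedParameterGroups) DescribeCacheParameterGroups(ctx context.Context, params *elasticache.DescribeCacheParameterGroupsInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeCacheParameterGroupsOutput, error) {
    return &m.DescribeCacheParameterGroupsOutput, nil
}

func (m mockedParameterGroups) DeleteCacheParameterGroup(ctx context.Context, params *elasticache.DeleteCacheParameterGroupInput, optFns ...func(*elasticache.Options)) (*elasticache.DeleteCacheParameterGroupOutput, error) {
    return &elasticache.DeleteCacheParameterGroupOutput{}, nil
}

func TestListElasticacheParameterGroups_SkipsDefaultGroups(t *testing.T) {
    mock := mockedParameterGroups{
        DescribeCacheParameterGroupsOutput: elasticache.DescribeCacheParameterGroupsOutput{
            CacheParameterGroups: []types.CacheParameterGroup{
                {CacheParameterGroupName: aws.String("custom-group")},
                {CacheParameterGroupName: aws.String("default.redis7")}, // AWS managed, should be excluded
            },
        },
    }

    names, err := listElasticacheParameterGroups(context.Background(), mock, resource.Scope{}, config.ResourceType{})
    require.NoError(t, err)
    require.Equal(t, []string{"custom-group"}, aws.ToStringSlice(names))
}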
+func deleteElasticacheParameterGroup(ctx context.Context, client ElasticacheParameterGroupsAPI, identifier *string) error { + _, err := client.DeleteCacheParameterGroup(ctx, &elasticache.DeleteCacheParameterGroupInput{ + CacheParameterGroupName: identifier, + }) + return err +} diff --git a/aws/resources/elasticache_subnet_group.go b/aws/resources/elasticache_subnet_group.go new file mode 100644 index 00000000..62a20276 --- /dev/null +++ b/aws/resources/elasticache_subnet_group.go @@ -0,0 +1,73 @@ +package resources + +import ( + "context" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/elasticache" + "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/logging" + "github.com/gruntwork-io/cloud-nuke/resource" + "github.com/gruntwork-io/go-commons/errors" +) + +// ElasticacheSubnetGroupsAPI defines the interface for Elasticache Subnet Group operations. +type ElasticacheSubnetGroupsAPI interface { + DescribeCacheSubnetGroups(ctx context.Context, params *elasticache.DescribeCacheSubnetGroupsInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeCacheSubnetGroupsOutput, error) + DeleteCacheSubnetGroup(ctx context.Context, params *elasticache.DeleteCacheSubnetGroupInput, optFns ...func(*elasticache.Options)) (*elasticache.DeleteCacheSubnetGroupOutput, error) +} + +// NewElasticacheSubnetGroups creates a new Elasticache Subnet Groups resource using the generic resource pattern. +func NewElasticacheSubnetGroups() AwsResource { + return NewAwsResource(&resource.Resource[ElasticacheSubnetGroupsAPI]{ + ResourceTypeName: "elasticacheSubnetGroups", + BatchSize: 49, + InitClient: func(r *resource.Resource[ElasticacheSubnetGroupsAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for Elasticache client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = elasticache.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.ElasticacheSubnetGroups + }, + Lister: listElasticacheSubnetGroups, + Nuker: resource.SimpleBatchDeleter(deleteElasticacheSubnetGroup), + }) +} + +// listElasticacheSubnetGroups retrieves all Elasticache subnet groups that match the config filters. +func listElasticacheSubnetGroups(ctx context.Context, client ElasticacheSubnetGroupsAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + var subnetGroupNames []*string + + paginator := elasticache.NewDescribeCacheSubnetGroupsPaginator(client, &elasticache.DescribeCacheSubnetGroupsInput{}) + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return nil, errors.WithStackTrace(err) + } + + for _, subnetGroup := range page.CacheSubnetGroups { + if !strings.Contains(aws.ToString(subnetGroup.CacheSubnetGroupName), "default") && + cfg.ShouldInclude(config.ResourceValue{ + Name: subnetGroup.CacheSubnetGroupName, + }) { + subnetGroupNames = append(subnetGroupNames, subnetGroup.CacheSubnetGroupName) + } + } + } + + return subnetGroupNames, nil +} + +// deleteElasticacheSubnetGroup deletes a single Elasticache subnet group. 
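The hand-written nukeAll loops are replaced throughout by adapters from the resource package: resource.SimpleBatchDeleter here, resource.SequentialDeleter for the ELB resources below. That package is not part of this diff, so the following is only a rough sketch of what such an adapter presumably does; the real implementation likely also records report entries and logging per identifier.

package resource // hypothetical sketch only; not the actual cloud-nuke code

import (
    "context"
    "errors"
)

// SimpleBatchDeleter (sketch) adapts a per-identifier delete function into a
// batch deleter that keeps going on individual failures and returns the
// combined error.
func SimpleBatchDeleter[ClientT any](
    deleteOne func(ctx context.Context, client ClientT, id *string) error,
) func(ctx context.Context, client ClientT, ids []*string) error {
    return func(ctx context.Context, client ClientT, ids []*string) error {
        var allErrs error
        for _, id := range ids {
            if err := deleteOne(ctx, client, id); err != nil {
                allErrs = errors.Join(allErrs, err)
            }
        }
        return allErrs
    }
}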
+func deleteElasticacheSubnetGroup(ctx context.Context, client ElasticacheSubnetGroupsAPI, identifier *string) error { + _, err := client.DeleteCacheSubnetGroup(ctx, &elasticache.DeleteCacheSubnetGroupInput{ + CacheSubnetGroupName: identifier, + }) + return err +} diff --git a/aws/resources/elasticache_test.go b/aws/resources/elasticache_test.go index 9618972c..f44520d7 100644 --- a/aws/resources/elasticache_test.go +++ b/aws/resources/elasticache_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/elasticache" "github.com/aws/aws-sdk-go-v2/service/elasticache/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) @@ -37,30 +38,39 @@ func (m mockedElasticache) DeleteReplicationGroup(ctx context.Context, params *e return &m.DeleteReplicationGroupOutput, nil } -func TestElasticache_GetAll(t *testing.T) { +func TestElasticaches_ResourceName(t *testing.T) { + r := NewElasticaches() + require.Equal(t, "elasticache", r.ResourceName()) +} + +func TestElasticaches_MaxBatchSize(t *testing.T) { + r := NewElasticaches() + require.Equal(t, 49, r.MaxBatchSize()) +} + +func TestListElasticaches(t *testing.T) { t.Parallel() now := time.Now() testName1 := "test-name-1" testName2 := "test-name-2" - ec := Elasticaches{ - Client: mockedElasticache{ - DescribeReplicationGroupsOutput: elasticache.DescribeReplicationGroupsOutput{ - ReplicationGroups: []types.ReplicationGroup{ - { - ReplicationGroupId: aws.String(testName1), - ReplicationGroupCreateTime: aws.Time(now), - }, - { - ReplicationGroupId: aws.String(testName2), - ReplicationGroupCreateTime: aws.Time(now.Add(1)), - }, + + mock := mockedElasticache{ + DescribeReplicationGroupsOutput: elasticache.DescribeReplicationGroupsOutput{ + ReplicationGroups: []types.ReplicationGroup{ + { + ReplicationGroupId: aws.String(testName1), + ReplicationGroupCreateTime: aws.Time(now), + }, + { + ReplicationGroupId: aws.String(testName2), + ReplicationGroupCreateTime: aws.Time(now.Add(1)), }, - }, - DescribeCacheClustersOutput: elasticache.DescribeCacheClustersOutput{ - CacheClusters: []types.CacheCluster{}, }, }, + DescribeCacheClustersOutput: elasticache.DescribeCacheClustersOutput{ + CacheClusters: []types.CacheCluster{}, + }, } tests := map[string]struct { @@ -90,37 +100,30 @@ func TestElasticache_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := ec.getAll(context.Background(), config.Config{ - Elasticache: tc.configObj, - }) + names, err := listElasticaches(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } } -func TestElasticache_NukeAll(t *testing.T) { +func TestDeleteElasticacheCluster(t *testing.T) { t.Parallel() - ec := Elasticaches{ - BaseAwsResource: BaseAwsResource{ - Context: context.Background(), - }, - Client: mockedElasticache{ - DescribeReplicationGroupsOutput: elasticache.DescribeReplicationGroupsOutput{ - ReplicationGroups: []types.ReplicationGroup{ - { - ReplicationGroupId: aws.String("test-name-1"), - ReplicationGroupCreateTime: aws.Time(time.Now()), - Status: aws.String("deleted"), - }, + mock := mockedElasticache{ + DescribeReplicationGroupsOutput: elasticache.DescribeReplicationGroupsOutput{ + ReplicationGroups: []types.ReplicationGroup{ + { + ReplicationGroupId: aws.String("test-name-1"), + ReplicationGroupCreateTime: aws.Time(time.Now()), + Status: aws.String("deleted"), }, }, - 
DescribeCacheClustersOutput: elasticache.DescribeCacheClustersOutput{}, - DeleteReplicationGroupOutput: elasticache.DeleteReplicationGroupOutput{}, }, + DescribeCacheClustersOutput: elasticache.DescribeCacheClustersOutput{}, + DeleteReplicationGroupOutput: elasticache.DeleteReplicationGroupOutput{}, } - err := ec.nukeAll(aws.StringSlice([]string{"test-name-1"})) + err := deleteElasticacheCluster(context.Background(), mock, aws.String("test-name-1")) require.NoError(t, err) } diff --git a/aws/resources/elasticache_types.go b/aws/resources/elasticache_types.go deleted file mode 100644 index 589b04c8..00000000 --- a/aws/resources/elasticache_types.go +++ /dev/null @@ -1,172 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/elasticache" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type ElasticachesAPI interface { - DescribeReplicationGroups(ctx context.Context, params *elasticache.DescribeReplicationGroupsInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeReplicationGroupsOutput, error) - DescribeCacheClusters(ctx context.Context, params *elasticache.DescribeCacheClustersInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeCacheClustersOutput, error) - DeleteCacheCluster(ctx context.Context, params *elasticache.DeleteCacheClusterInput, optFns ...func(*elasticache.Options)) (*elasticache.DeleteCacheClusterOutput, error) - DeleteReplicationGroup(ctx context.Context, params *elasticache.DeleteReplicationGroupInput, optFns ...func(*elasticache.Options)) (*elasticache.DeleteReplicationGroupOutput, error) -} - -// Elasticaches - represents all Elasticache clusters -type Elasticaches struct { - BaseAwsResource - Client ElasticachesAPI - Region string - ClusterIds []string -} - -func (cache *Elasticaches) Init(cfg aws.Config) { - cache.Client = elasticache.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (cache *Elasticaches) ResourceName() string { - return "elasticache" -} - -// ResourceIdentifiers - The instance ids of the elasticache clusters -func (cache *Elasticaches) ResourceIdentifiers() []string { - return cache.ClusterIds -} - -func (cache *Elasticaches) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (cache *Elasticaches) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.Elasticache -} - -func (cache *Elasticaches) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := cache.getAll(c, configObj) - if err != nil { - return nil, err - } - - cache.ClusterIds = aws.ToStringSlice(identifiers) - return cache.ClusterIds, nil -} - -// Nuke - nuke 'em all!!! 
-func (cache *Elasticaches) Nuke(ctx context.Context, identifiers []string) error { - if err := cache.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} - -/* -Elasticache Parameter Groups -*/ -type ElasticacheParameterGroupsAPI interface { - DescribeCacheParameterGroups(ctx context.Context, params *elasticache.DescribeCacheParameterGroupsInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeCacheParameterGroupsOutput, error) - DeleteCacheParameterGroup(ctx context.Context, params *elasticache.DeleteCacheParameterGroupInput, optFns ...func(*elasticache.Options)) (*elasticache.DeleteCacheParameterGroupOutput, error) -} - -type ElasticacheParameterGroups struct { - BaseAwsResource - Client ElasticacheParameterGroupsAPI - Region string - GroupNames []string -} - -func (pg *ElasticacheParameterGroups) Init(cfg aws.Config) { - pg.Client = elasticache.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (pg *ElasticacheParameterGroups) ResourceName() string { - return "elasticacheParameterGroups" -} - -// ResourceIdentifiers - The instance ids of the ec2 instances -func (pg *ElasticacheParameterGroups) ResourceIdentifiers() []string { - return pg.GroupNames -} - -func (pg *ElasticacheParameterGroups) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (pg *ElasticacheParameterGroups) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := pg.getAll(c, configObj) - if err != nil { - return nil, err - } - - pg.GroupNames = aws.ToStringSlice(identifiers) - return pg.GroupNames, nil -} - -// Nuke - nuke 'em all!!! -func (pg *ElasticacheParameterGroups) Nuke(ctx context.Context, identifiers []string) error { - if err := pg.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} - -/* -Elasticache Subnet Groups -*/ -type ElasticacheSubnetGroupsAPI interface { - DescribeCacheSubnetGroups(ctx context.Context, params *elasticache.DescribeCacheSubnetGroupsInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeCacheSubnetGroupsOutput, error) - DeleteCacheSubnetGroup(ctx context.Context, params *elasticache.DeleteCacheSubnetGroupInput, optFns ...func(*elasticache.Options)) (*elasticache.DeleteCacheSubnetGroupOutput, error) -} - -type ElasticacheSubnetGroups struct { - BaseAwsResource - Client ElasticacheSubnetGroupsAPI - Region string - GroupNames []string -} - -func (sg *ElasticacheSubnetGroups) Init(cfg aws.Config) { - sg.Client = elasticache.NewFromConfig(cfg) -} - -func (sg *ElasticacheSubnetGroups) ResourceName() string { - return "elasticacheSubnetGroups" -} - -// ResourceIdentifiers - The instance ids of the ec2 instances -func (sg *ElasticacheSubnetGroups) ResourceIdentifiers() []string { - return sg.GroupNames -} - -func (sg *ElasticacheSubnetGroups) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (sg *ElasticacheSubnetGroups) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := sg.getAll(c, configObj) - if err != nil { - return nil, err - } - - sg.GroupNames = aws.ToStringSlice(identifiers) - return sg.GroupNames, nil -} - -// Nuke - nuke 'em all!!! 
-func (sg *ElasticacheSubnetGroups) Nuke(ctx context.Context, identifiers []string) error { - if err := sg.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/elb.go b/aws/resources/elb.go index 52262b49..22bc1823 100644 --- a/aws/resources/elb.go +++ b/aws/resources/elb.go @@ -8,38 +8,48 @@ import ( "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/go-commons/errors" ) -func (balancer *LoadBalancers) waitUntilElbDeleted(input *elasticloadbalancing.DescribeLoadBalancersInput) error { - for i := 0; i < 30; i++ { - output, err := balancer.Client.DescribeLoadBalancers(balancer.Context, input) - if err != nil { - return err - } - - if len(output.LoadBalancerDescriptions) == 0 { - return nil - } - - time.Sleep(1 * time.Second) - logging.Debug("Waiting for ELB to be deleted") - } +// LoadBalancersAPI defines the interface for ELB operations. +type LoadBalancersAPI interface { + DescribeLoadBalancers(ctx context.Context, params *elasticloadbalancing.DescribeLoadBalancersInput, optFns ...func(*elasticloadbalancing.Options)) (*elasticloadbalancing.DescribeLoadBalancersOutput, error) + DeleteLoadBalancer(ctx context.Context, params *elasticloadbalancing.DeleteLoadBalancerInput, optFns ...func(*elasticloadbalancing.Options)) (*elasticloadbalancing.DeleteLoadBalancerOutput, error) +} - return ElbDeleteError{} +// NewLoadBalancers creates a new LoadBalancers resource using the generic resource pattern. +func NewLoadBalancers() AwsResource { + return NewAwsResource(&resource.Resource[LoadBalancersAPI]{ + ResourceTypeName: "elb", + BatchSize: 49, + InitClient: func(r *resource.Resource[LoadBalancersAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for LoadBalancers client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = elasticloadbalancing.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.ELBv1 + }, + Lister: listLoadBalancers, + Nuker: resource.SequentialDeleter(deleteLoadBalancer), + }) } -// Returns a formatted string of ELB names -func (balancer *LoadBalancers) getAll(c context.Context, configObj config.Config) ([]*string, error) { - result, err := balancer.Client.DescribeLoadBalancers(balancer.Context, &elasticloadbalancing.DescribeLoadBalancersInput{}) +// listLoadBalancers retrieves all Classic ELB load balancers that match the config filters. 
+func listLoadBalancers(ctx context.Context, client LoadBalancersAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + result, err := client.DescribeLoadBalancers(ctx, &elasticloadbalancing.DescribeLoadBalancersInput{}) if err != nil { return nil, errors.WithStackTrace(err) } var names []*string for _, balancer := range result.LoadBalancerDescriptions { - if configObj.ELBv1.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Name: balancer.LoadBalancerName, Time: balancer.CreatedTime, }) { @@ -50,49 +60,39 @@ func (balancer *LoadBalancers) getAll(c context.Context, configObj config.Config return names, nil } -// Deletes all Elastic Load Balancers -func (balancer *LoadBalancers) nukeAll(names []*string) error { - if len(names) == 0 { - logging.Debugf("No Elastic Load Balancers to nuke in region %s", balancer.Region) - return nil +// deleteLoadBalancer deletes a single Classic ELB load balancer. +func deleteLoadBalancer(ctx context.Context, client LoadBalancersAPI, name *string) error { + _, err := client.DeleteLoadBalancer(ctx, &elasticloadbalancing.DeleteLoadBalancerInput{ + LoadBalancerName: name, + }) + if err != nil { + return err } - logging.Debugf("Deleting all Elastic Load Balancers in region %s", balancer.Region) - var deletedNames []*string - - for _, name := range names { - params := &elasticloadbalancing.DeleteLoadBalancerInput{ - LoadBalancerName: name, - } - - _, err := balancer.Client.DeleteLoadBalancer(balancer.Context, params) + return waitUntilElbDeleted(ctx, client, name) +} - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(name), - ResourceType: "Load Balancer (v1)", - Error: err, +// waitUntilElbDeleted waits until the ELB is deleted. +func waitUntilElbDeleted(ctx context.Context, client LoadBalancersAPI, name *string) error { + for i := 0; i < 30; i++ { + output, err := client.DescribeLoadBalancers(ctx, &elasticloadbalancing.DescribeLoadBalancersInput{ + LoadBalancerNames: []string{aws.ToString(name)}, + }) + // Any error from DescribeLoadBalancers (including AccessPointNotFound) means the ELB is deleted + if err != nil || len(output.LoadBalancerDescriptions) == 0 { + return nil //nolint:nilerr // Error here (e.g., AccessPointNotFound) indicates successful deletion } - report.Record(e) - if err != nil { - logging.Debugf("[Failed] %s", err) - } else { - deletedNames = append(deletedNames, name) - logging.Debugf("Deleted ELB: %s", *name) - } + time.Sleep(1 * time.Second) + logging.Debug("Waiting for ELB to be deleted") } - if len(deletedNames) > 0 { - err := balancer.waitUntilElbDeleted(&elasticloadbalancing.DescribeLoadBalancersInput{ - LoadBalancerNames: aws.ToStringSlice(deletedNames), - }) - if err != nil { - logging.Debugf("[Failed] %s", err) - return errors.WithStackTrace(err) - } - } + return ElbDeleteError{} +} + +// ElbDeleteError represents an error when deleting ELB. 
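The rewritten waiter polls a single load balancer by name and treats any DescribeLoadBalancers error (for example AccessPointNotFound) or an empty result as confirmation that the ELB is gone. Below is a minimal sketch of how that behavior could be checked with a hand-rolled mock; the mock type, test name, and error text are illustrative and not part of this change (same imports as elb_test.go, plus "fmt").

// deletedELBMock reports every load balancer as already gone.
type deletedELBMock struct{}

func (deletedELBMock) DescribeLoadBalancers(ctx context.Context, params *elasticloadbalancing.DescribeLoadBalancersInput, optFns ...func(*elasticloadbalancing.Options)) (*elasticloadbalancing.DescribeLoadBalancersOutput, error) {
    return nil, fmt.Errorf("AccessPointNotFound: load balancer does not exist")
}

func (deletedELBMock) DeleteLoadBalancer(ctx context.Context, params *elasticloadbalancing.DeleteLoadBalancerInput, optFns ...func(*elasticloadbalancing.Options)) (*elasticloadbalancing.DeleteLoadBalancerOutput, error) {
    return &elasticloadbalancing.DeleteLoadBalancerOutput{}, nil
}

func TestWaitUntilElbDeleted_NotFoundMeansDeleted(t *testing.T) {
    // The very first poll returns an error, which the waiter treats as proof of deletion.
    err := waitUntilElbDeleted(context.Background(), deletedELBMock{}, aws.String("test-elb"))
    require.NoError(t, err)
}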
+type ElbDeleteError struct{} - logging.Debugf("[OK] %d Elastic Load Balancer(s) deleted in %s", len(deletedNames), balancer.Region) - return nil +func (e ElbDeleteError) Error() string { + return "ELB was not deleted" } diff --git a/aws/resources/elb_test.go b/aws/resources/elb_test.go index b96757e0..600d944c 100644 --- a/aws/resources/elb_test.go +++ b/aws/resources/elb_test.go @@ -10,11 +10,11 @@ import ( "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) type mockedLoadBalancers struct { - LoadBalancersAPI DescribeLoadBalancersOutput elasticloadbalancing.DescribeLoadBalancersOutput DeleteLoadBalancerOutput elasticloadbalancing.DeleteLoadBalancerOutput } @@ -33,18 +33,16 @@ func TestElb_GetAll(t *testing.T) { testName1 := "test-name-1" testName2 := "test-name-2" now := time.Now() - balancer := LoadBalancers{ - Client: mockedLoadBalancers{ - DescribeLoadBalancersOutput: elasticloadbalancing.DescribeLoadBalancersOutput{ - LoadBalancerDescriptions: []types.LoadBalancerDescription{ - { - LoadBalancerName: aws.String(testName1), - CreatedTime: aws.Time(now), - }, - { - LoadBalancerName: aws.String(testName2), - CreatedTime: aws.Time(now.Add(1)), - }, + mock := mockedLoadBalancers{ + DescribeLoadBalancersOutput: elasticloadbalancing.DescribeLoadBalancersOutput{ + LoadBalancerDescriptions: []types.LoadBalancerDescription{ + { + LoadBalancerName: aws.String(testName1), + CreatedTime: aws.Time(now), + }, + { + LoadBalancerName: aws.String(testName2), + CreatedTime: aws.Time(now.Add(1)), }, }, }, @@ -77,9 +75,7 @@ func TestElb_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := balancer.getAll(context.Background(), config.Config{ - ELBv1: tc.configObj, - }) + names, err := listLoadBalancers(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) @@ -89,12 +85,10 @@ func TestElb_GetAll(t *testing.T) { func TestElb_NukeAll(t *testing.T) { t.Parallel() - balancer := LoadBalancers{ - Client: mockedLoadBalancers{ - DeleteLoadBalancerOutput: elasticloadbalancing.DeleteLoadBalancerOutput{}, - }, + mock := mockedLoadBalancers{ + DeleteLoadBalancerOutput: elasticloadbalancing.DeleteLoadBalancerOutput{}, } - err := balancer.nukeAll([]*string{aws.String("test-arn-1"), aws.String("test-arn-2")}) + err := deleteLoadBalancer(context.Background(), mock, aws.String("test-arn-1")) require.NoError(t, err) } diff --git a/aws/resources/elb_types.go b/aws/resources/elb_types.go deleted file mode 100644 index e8ce2ec8..00000000 --- a/aws/resources/elb_types.go +++ /dev/null @@ -1,71 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type LoadBalancersAPI interface { - DescribeLoadBalancers(ctx context.Context, params *elasticloadbalancing.DescribeLoadBalancersInput, optFns ...func(*elasticloadbalancing.Options)) (*elasticloadbalancing.DescribeLoadBalancersOutput, error) - DeleteLoadBalancer(ctx context.Context, params *elasticloadbalancing.DeleteLoadBalancerInput, optFns ...func(*elasticloadbalancing.Options)) (*elasticloadbalancing.DeleteLoadBalancerOutput, error) -} - -// 
LoadBalancers - represents all load balancers -type LoadBalancers struct { - BaseAwsResource - Client LoadBalancersAPI - Region string - Names []string -} - -func (balancer *LoadBalancers) Init(cfg aws.Config) { - balancer.Client = elasticloadbalancing.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (balancer *LoadBalancers) ResourceName() string { - return "elb" -} - -// ResourceIdentifiers - The names of the load balancers -func (balancer *LoadBalancers) ResourceIdentifiers() []string { - return balancer.Names -} - -func (balancer *LoadBalancers) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (balancer *LoadBalancers) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.ELBv1 -} - -func (balancer *LoadBalancers) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := balancer.getAll(c, configObj) - if err != nil { - return nil, err - } - - balancer.Names = aws.ToStringSlice(identifiers) - return balancer.Names, nil -} - -// Nuke - nuke 'em all!!! -func (balancer *LoadBalancers) Nuke(ctx context.Context, identifiers []string) error { - if err := balancer.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} - -type ElbDeleteError struct{} - -func (e ElbDeleteError) Error() string { - return "ELB was not deleted" -} diff --git a/aws/resources/elbv2.go b/aws/resources/elbv2.go index 1f4b740a..a15cb283 100644 --- a/aws/resources/elbv2.go +++ b/aws/resources/elbv2.go @@ -2,25 +2,54 @@ package resources import ( "context" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/go-commons/errors" ) -// Returns a formatted string of ELBv2 ARNs -func (balancer *LoadBalancersV2) getAll(c context.Context, configObj config.Config) ([]*string, error) { - result, err := balancer.Client.DescribeLoadBalancers(balancer.Context, &elasticloadbalancingv2.DescribeLoadBalancersInput{}) +// LoadBalancersV2API defines the interface for ELBv2 operations. +type LoadBalancersV2API interface { + DescribeLoadBalancers(ctx context.Context, params *elasticloadbalancingv2.DescribeLoadBalancersInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeLoadBalancersOutput, error) + DeleteLoadBalancer(ctx context.Context, params *elasticloadbalancingv2.DeleteLoadBalancerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) +} + +// NewLoadBalancersV2 creates a new LoadBalancersV2 resource using the generic resource pattern. 
+func NewLoadBalancersV2() AwsResource { + return NewAwsResource(&resource.Resource[LoadBalancersV2API]{ + ResourceTypeName: "elbv2", + BatchSize: 49, + InitClient: func(r *resource.Resource[LoadBalancersV2API], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for LoadBalancersV2 client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = elasticloadbalancingv2.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.ELBv2 + }, + Lister: listLoadBalancersV2, + Nuker: resource.SequentialDeleter(resource.DeleteThenWait(deleteLoadBalancerV2, waitForLoadBalancerV2Deleted)), + }) +} + +// listLoadBalancersV2 retrieves all ELBv2 load balancers that match the config filters. +func listLoadBalancersV2(ctx context.Context, client LoadBalancersV2API, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + result, err := client.DescribeLoadBalancers(ctx, &elasticloadbalancingv2.DescribeLoadBalancersInput{}) if err != nil { return nil, errors.WithStackTrace(err) } var arns []*string for _, balancer := range result.LoadBalancers { - if configObj.ELBv2.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Name: balancer.LoadBalancerName, Time: balancer.CreatedTime, }) { @@ -31,50 +60,18 @@ func (balancer *LoadBalancersV2) getAll(c context.Context, configObj config.Conf return arns, nil } -// Deletes all Elastic Load Balancers -func (balancer *LoadBalancersV2) nukeAll(arns []*string) error { - if len(arns) == 0 { - logging.Debugf("No V2 Elastic Load Balancers to nuke in region %s", balancer.Region) - return nil - } - - logging.Debugf("Deleting all V2 Elastic Load Balancers in region %s", balancer.Region) - var deletedArns []*string - - for _, arn := range arns { - params := &elasticloadbalancingv2.DeleteLoadBalancerInput{ - LoadBalancerArn: arn, - } - - _, err := balancer.Client.DeleteLoadBalancer(balancer.Context, params) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(arn), - ResourceType: "Load Balancer (v2)", - Error: err, - } - report.Record(e) - - if err != nil { - logging.Debugf("[Failed] %s", err) - } else { - deletedArns = append(deletedArns, arn) - logging.Debugf("Deleted ELBv2: %s", *arn) - } - } - - if len(deletedArns) > 0 { - waiter := elasticloadbalancingv2.NewLoadBalancersDeletedWaiter(balancer.Client) - err := waiter.Wait(balancer.Context, &elasticloadbalancingv2.DescribeLoadBalancersInput{ - LoadBalancerArns: aws.ToStringSlice(deletedArns), - }, balancer.Timeout) - if err != nil { - logging.Debugf("[Failed] %s", err) - return errors.WithStackTrace(err) - } - } +// deleteLoadBalancerV2 deletes a single ELBv2 load balancer. +func deleteLoadBalancerV2(ctx context.Context, client LoadBalancersV2API, arn *string) error { + _, err := client.DeleteLoadBalancer(ctx, &elasticloadbalancingv2.DeleteLoadBalancerInput{ + LoadBalancerArn: arn, + }) + return err +} - logging.Debugf("[OK] %d V2 Elastic Load Balancer(s) deleted in %s", len(deletedArns), balancer.Region) - return nil +// waitForLoadBalancerV2Deleted waits for an ELBv2 load balancer to be deleted. 
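For ELBv2 the Nuker is built by composing the per-ARN delete and wait steps with resource.DeleteThenWait. The helper itself is not shown in this diff; presumably it just chains the two functions, roughly as in this hedged sketch (not the actual implementation; only "context" from the standard library is needed):

// DeleteThenWait (sketch) runs the delete step and, if it succeeds, blocks
// until the wait step confirms removal.
func DeleteThenWait[ClientT any](
    deleteFn func(ctx context.Context, client ClientT, id *string) error,
    waitFn func(ctx context.Context, client ClientT, id *string) error,
) func(ctx context.Context, client ClientT, id *string) error {
    return func(ctx context.Context, client ClientT, id *string) error {
        if err := deleteFn(ctx, client, id); err != nil {
            return err
        }
        return waitFn(ctx, client, id)
    }
}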
+func waitForLoadBalancerV2Deleted(ctx context.Context, client LoadBalancersV2API, arn *string) error { + waiter := elasticloadbalancingv2.NewLoadBalancersDeletedWaiter(client) + return waiter.Wait(ctx, &elasticloadbalancingv2.DescribeLoadBalancersInput{ + LoadBalancerArns: []string{aws.ToString(arn)}, + }, 5*time.Minute) } diff --git a/aws/resources/elbv2_test.go b/aws/resources/elbv2_test.go index db031273..37c4308e 100644 --- a/aws/resources/elbv2_test.go +++ b/aws/resources/elbv2_test.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/aws/smithy-go" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) @@ -37,20 +38,18 @@ func TestElbV2_GetAll(t *testing.T) { testName2 := "test-name-2" testArn2 := "test-arn-2" now := time.Now() - balancer := LoadBalancersV2{ - Client: mockedElbV2{ - DescribeLoadBalancersOutput: elasticloadbalancingv2.DescribeLoadBalancersOutput{ - LoadBalancers: []types.LoadBalancer{ - { - LoadBalancerArn: aws.String(testArn1), - LoadBalancerName: aws.String(testName1), - CreatedTime: aws.Time(now), - }, - { - LoadBalancerArn: aws.String(testArn2), - LoadBalancerName: aws.String(testName2), - CreatedTime: aws.Time(now.Add(1)), - }, + mock := mockedElbV2{ + DescribeLoadBalancersOutput: elasticloadbalancingv2.DescribeLoadBalancersOutput{ + LoadBalancers: []types.LoadBalancer{ + { + LoadBalancerArn: aws.String(testArn1), + LoadBalancerName: aws.String(testName1), + CreatedTime: aws.Time(now), + }, + { + LoadBalancerArn: aws.String(testArn2), + LoadBalancerName: aws.String(testName2), + CreatedTime: aws.Time(now.Add(1)), }, }, }, @@ -83,9 +82,7 @@ func TestElbV2_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := balancer.getAll(context.Background(), config.Config{ - ELBv2: tc.configObj, - }) + names, err := listLoadBalancersV2(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) @@ -113,18 +110,17 @@ func (e errMockLoadBalancerNotFound) ErrorFault() smithy.ErrorFault { func TestElbV2_NukeAll(t *testing.T) { t.Parallel() var eLBNotFound errMockLoadBalancerNotFound - balancer := LoadBalancersV2{ - BaseAwsResource: BaseAwsResource{ - Context: context.Background(), - Timeout: DefaultWaitTimeout, - }, - Client: mockedElbV2{ - DescribeLoadBalancersOutput: elasticloadbalancingv2.DescribeLoadBalancersOutput{}, - ErrDescribeLoadBalancersOutput: eLBNotFound, - DeleteLoadBalancerOutput: elasticloadbalancingv2.DeleteLoadBalancerOutput{}, - }, + mock := mockedElbV2{ + DescribeLoadBalancersOutput: elasticloadbalancingv2.DescribeLoadBalancersOutput{}, + ErrDescribeLoadBalancersOutput: eLBNotFound, + DeleteLoadBalancerOutput: elasticloadbalancingv2.DeleteLoadBalancerOutput{}, } - err := balancer.nukeAll([]*string{aws.String("test-arn-1")}) + // Test the delete function + err := deleteLoadBalancerV2(context.Background(), mock, aws.String("test-arn-1")) + require.NoError(t, err) + + // Test the wait function (returns nil because LoadBalancerNotFound means it's deleted) + err = waitForLoadBalancerV2Deleted(context.Background(), mock, aws.String("test-arn-1")) require.NoError(t, err) } diff --git a/aws/resources/elbv2_types.go b/aws/resources/elbv2_types.go deleted file mode 100644 index 0d0c93bc..00000000 --- a/aws/resources/elbv2_types.go +++ /dev/null @@ -1,65 +0,0 @@ -package resources - -import ( - "context" - - 
"github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type LoadBalancersV2API interface { - DescribeLoadBalancers(ctx context.Context, params *elasticloadbalancingv2.DescribeLoadBalancersInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DescribeLoadBalancersOutput, error) - DeleteLoadBalancer(ctx context.Context, params *elasticloadbalancingv2.DeleteLoadBalancerInput, optFns ...func(*elasticloadbalancingv2.Options)) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) -} - -// LoadBalancersV2 - represents all load balancers -type LoadBalancersV2 struct { - BaseAwsResource - Client LoadBalancersV2API - Region string - Arns []string -} - -func (balancer *LoadBalancersV2) Init(cfg aws.Config) { - balancer.Client = elasticloadbalancingv2.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (balancer *LoadBalancersV2) ResourceName() string { - return "elbv2" -} - -func (balancer *LoadBalancersV2) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -// ResourceIdentifiers - The arns of the load balancers -func (balancer *LoadBalancersV2) ResourceIdentifiers() []string { - return balancer.Arns -} - -func (balancer *LoadBalancersV2) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.ELBv2 -} - -func (balancer *LoadBalancersV2) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := balancer.getAll(c, configObj) - if err != nil { - return nil, err - } - - balancer.Arns = aws.ToStringSlice(identifiers) - return balancer.Arns, nil -} - -// Nuke - nuke 'em all!!! -func (balancer *LoadBalancersV2) Nuke(ctx context.Context, identifiers []string) error { - if err := balancer.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/kms_customer_key.go b/aws/resources/kms_customer_key.go index 8a825c5b..525ebea7 100644 --- a/aws/resources/kms_customer_key.go +++ b/aws/resources/kms_customer_key.go @@ -9,17 +9,68 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/go-commons/errors" - "github.com/hashicorp/go-multierror" ) -func (kck *KmsCustomerKeys) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// https://docs.aws.amazon.com/sdk-for-go/api/service/kms/#ScheduleKeyDeletionInput +// must be between 7 and 30, inclusive +const kmsRemovalWindow = 7 + +// Context key for passing IncludeUnaliasedKeys config +type kmsIncludeUnaliasedKeysKeyType struct{} + +var kmsIncludeUnaliasedKeysKey = kmsIncludeUnaliasedKeysKeyType{} + +// KmsCustomerKeysAPI defines the interface for KMS operations. 
+type KmsCustomerKeysAPI interface { + ListKeys(ctx context.Context, params *kms.ListKeysInput, optFns ...func(*kms.Options)) (*kms.ListKeysOutput, error) + ListAliases(ctx context.Context, params *kms.ListAliasesInput, optFns ...func(*kms.Options)) (*kms.ListAliasesOutput, error) + DescribeKey(ctx context.Context, params *kms.DescribeKeyInput, optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) + ScheduleKeyDeletion(ctx context.Context, params *kms.ScheduleKeyDeletionInput, optFns ...func(*kms.Options)) (*kms.ScheduleKeyDeletionOutput, error) +} + +// NewKmsCustomerKeys creates a new KMS Customer Keys resource using the generic resource pattern. +func NewKmsCustomerKeys() AwsResource { + return NewAwsResource(&resource.Resource[KmsCustomerKeysAPI]{ + ResourceTypeName: "kmscustomerkeys", + BatchSize: 49, + InitClient: func(r *resource.Resource[KmsCustomerKeysAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for KMS client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = kms.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.KMSCustomerKeys.ResourceType + }, + Lister: listKmsCustomerKeys, + Nuker: resource.SimpleBatchDeleter(deleteKmsCustomerKey), + }) +} + +// KmsCheckIncludeResult - structure used for results of evaluation +type KmsCheckIncludeResult struct { + KeyId string + Error error +} + +// listKmsCustomerKeys retrieves all KMS customer keys that match the config filters. +func listKmsCustomerKeys(ctx context.Context, client KmsCustomerKeysAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + // Try to get IncludeUnaliasedKeys from context (set by the outer caller if needed) + includeUnaliasedKeys := false + if val, ok := ctx.Value(kmsIncludeUnaliasedKeysKey).(bool); ok { + includeUnaliasedKeys = val + } + // Collect all keys in the account var keys []string - listKeysPaginator := kms.NewListKeysPaginator(kck.Client, &kms.ListKeysInput{}) + listKeysPaginator := kms.NewListKeysPaginator(client, &kms.ListKeysInput{}) for listKeysPaginator.HasMorePages() { - page, err := listKeysPaginator.NextPage(c) + page, err := listKeysPaginator.NextPage(ctx) if err != nil { return nil, errors.WithStackTrace(err) } @@ -31,9 +82,9 @@ func (kck *KmsCustomerKeys) getAll(c context.Context, configObj config.Config) ( // Collect key to alias mapping keyAliases := map[string][]string{} - listAliasesPaginator := kms.NewListAliasesPaginator(kck.Client, &kms.ListAliasesInput{}) + listAliasesPaginator := kms.NewListAliasesPaginator(client, &kms.ListAliasesInput{}) for listAliasesPaginator.HasMorePages() { - page, err := listAliasesPaginator.NextPage(c) + page, err := listAliasesPaginator.NextPage(ctx) if err != nil { return nil, errors.WithStackTrace(err) } @@ -66,13 +117,12 @@ func (kck *KmsCustomerKeys) getAll(c context.Context, configObj config.Config) ( // If the keyId isn't found in the map, this returns an empty array aliasesForKey := keyAliases[keyId] - go kck.shouldInclude(&wg, resultsChan[id], keyId, aliasesForKey, configObj) + go shouldIncludeKmsKey(ctx, client, &wg, resultsChan[id], keyId, aliasesForKey, cfg, includeUnaliasedKeys) id++ } wg.Wait() var kmsIds []*string - aliases := map[string][]string{} for _, channel := range resultsChan { result := <-channel @@ -82,41 +132,36 @@ func (kck *KmsCustomerKeys) getAll(c context.Context, configObj config.Config) ( continue } if result.KeyId != "" { - aliases[result.KeyId] = 
keyAliases[result.KeyId] kmsIds = append(kmsIds, &result.KeyId) } } - kck.KeyAliases = aliases return kmsIds, nil } -// KmsCheckIncludeResult - structure used results of evaluation: not null KeyId - key should be included -type KmsCheckIncludeResult struct { - KeyId string - Error error -} - -func (kck *KmsCustomerKeys) shouldInclude( +func shouldIncludeKmsKey( + ctx context.Context, + client KmsCustomerKeysAPI, wg *sync.WaitGroup, resultsChan chan *KmsCheckIncludeResult, key string, aliases []string, - configObj config.Config) { + cfg config.ResourceType, + includeUnaliasedKeys bool) { defer wg.Done() includedByName := false // verify if key aliases matches configurations for _, alias := range aliases { - if config.ShouldInclude(&alias, configObj.KMSCustomerKeys.IncludeRule.NamesRegExp, - configObj.KMSCustomerKeys.ExcludeRule.NamesRegExp) { + if config.ShouldInclude(&alias, cfg.IncludeRule.NamesRegExp, + cfg.ExcludeRule.NamesRegExp) { includedByName = true } } // Only delete keys without aliases if the user explicitly says so if len(aliases) == 0 { - if !configObj.KMSCustomerKeys.IncludeUnaliasedKeys { + if !includeUnaliasedKeys { resultsChan <- &KmsCheckIncludeResult{KeyId: ""} return } else { @@ -131,7 +176,7 @@ func (kck *KmsCustomerKeys) shouldInclude( return } // additional request to describe key and get information about creation date, removal status - details, err := kck.Client.DescribeKey(kck.Context, &kms.DescribeKeyInput{KeyId: &key}) + details, err := client.DescribeKey(ctx, &kms.DescribeKeyInput{KeyId: &key}) if err != nil { resultsChan <- &KmsCheckIncludeResult{Error: err} return @@ -153,7 +198,7 @@ func (kck *KmsCustomerKeys) shouldInclude( } // Check time-based filtering (name filtering was already done above) referenceTime := metadata.CreationDate - if !configObj.KMSCustomerKeys.ShouldIncludeBasedOnTime(*referenceTime) { + if !cfg.ShouldIncludeBasedOnTime(*referenceTime) { resultsChan <- &KmsCheckIncludeResult{KeyId: ""} return } @@ -161,69 +206,12 @@ func (kck *KmsCustomerKeys) shouldInclude( resultsChan <- &KmsCheckIncludeResult{KeyId: key} } -func (kck *KmsCustomerKeys) nukeAll(keyIds []*string) error { - if len(keyIds) == 0 { - logging.Debugf("No Customer Keys to nuke in region %s", kck.Region) - return nil - } - - // usage of go routines for parallel keys removal - // https://docs.aws.amazon.com/sdk-for-go/api/service/kms/#KMS.ScheduleKeyDeletion - logging.Debugf("Deleting Keys secrets in region %s", kck.Region) - wg := new(sync.WaitGroup) - wg.Add(len(keyIds)) - errChans := make([]chan error, len(keyIds)) - for i, secretID := range keyIds { - errChans[i] = make(chan error, 1) - go kck.requestKeyDeletion(wg, errChans[i], secretID) - } - wg.Wait() - - wgAlias := new(sync.WaitGroup) - wgAlias.Add(len(kck.KeyAliases)) - for _, aliases := range kck.KeyAliases { - go kck.deleteAliases(wgAlias, aliases) - } - wgAlias.Wait() - - // collect errors from each channel - var allErrs *multierror.Error - for _, errChan := range errChans { - if err := <-errChan; err != nil { - allErrs = multierror.Append(allErrs, err) - logging.Debugf("[Failed] %s", err) - } - } - return errors.WithStackTrace(allErrs.ErrorOrNil()) -} - -func (kck *KmsCustomerKeys) deleteAliases(wg *sync.WaitGroup, aliases []string) { - defer wg.Done() - - for _, aliasName := range aliases { - input := &kms.DeleteAliasInput{AliasName: &aliasName} - _, err := kck.Client.DeleteAlias(kck.Context, input) - - if err != nil { - logging.Errorf("[Failed] Failed deleting alias: %s", aliasName) - } else { - 
logging.Debugf("Deleted alias %s", aliasName) - } - } -} - -func (kck *KmsCustomerKeys) requestKeyDeletion(wg *sync.WaitGroup, errChan chan error, key *string) { - defer wg.Done() - input := &kms.ScheduleKeyDeletionInput{KeyId: key, PendingWindowInDays: aws.Int32(int32(kmsRemovalWindow))} - _, err := kck.Client.ScheduleKeyDeletion(kck.Context, input) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(key), - ResourceType: "Key Management Service (KMS) Key", - Error: err, - } - report.Record(e) - - errChan <- err +// deleteKmsCustomerKey schedules a single KMS customer key for deletion. +// AWS automatically deletes all aliases associated with a key when the key is deleted. +func deleteKmsCustomerKey(ctx context.Context, client KmsCustomerKeysAPI, keyId *string) error { + _, err := client.ScheduleKeyDeletion(ctx, &kms.ScheduleKeyDeletionInput{ + KeyId: keyId, + PendingWindowInDays: aws.Int32(int32(kmsRemovalWindow)), + }) + return err } diff --git a/aws/resources/kms_customer_key_test.go b/aws/resources/kms_customer_key_test.go index 8ac7da0f..ead82368 100644 --- a/aws/resources/kms_customer_key_test.go +++ b/aws/resources/kms_customer_key_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) @@ -18,7 +19,6 @@ type mockedKmsCustomerKeys struct { ListKeysOutput kms.ListKeysOutput ListAliasesOutput kms.ListAliasesOutput DescribeKeyOutput map[string]kms.DescribeKeyOutput - DeleteAliasOutput kms.DeleteAliasOutput ScheduleKeyDeletionOutput kms.ScheduleKeyDeletionOutput } @@ -36,10 +36,6 @@ func (m mockedKmsCustomerKeys) DescribeKey(ctx context.Context, params *kms.Desc return &output, nil } -func (m mockedKmsCustomerKeys) DeleteAlias(ctx context.Context, params *kms.DeleteAliasInput, optFns ...func(*kms.Options)) (*kms.DeleteAliasOutput, error) { - return &m.DeleteAliasOutput, nil -} - func (m mockedKmsCustomerKeys) ScheduleKeyDeletion(ctx context.Context, params *kms.ScheduleKeyDeletionInput, optFns ...func(*kms.Options)) (*kms.ScheduleKeyDeletionOutput, error) { return &m.ScheduleKeyDeletionOutput, nil } @@ -52,44 +48,42 @@ func TestKMS_GetAll(t *testing.T) { alias1 := "alias/key1" alias2 := "alias/key2" now := time.Now() - kck := KmsCustomerKeys{ - Client: mockedKmsCustomerKeys{ - ListKeysOutput: kms.ListKeysOutput{ - Keys: []types.KeyListEntry{ - { - KeyId: aws.String(key1), - }, - { - KeyId: aws.String(key2), - }, + mockClient := mockedKmsCustomerKeys{ + ListKeysOutput: kms.ListKeysOutput{ + Keys: []types.KeyListEntry{ + { + KeyId: aws.String(key1), + }, + { + KeyId: aws.String(key2), }, }, - ListAliasesOutput: kms.ListAliasesOutput{ - Aliases: []types.AliasListEntry{ - { - AliasName: aws.String(alias1), - TargetKeyId: aws.String(key1), - }, - { - AliasName: aws.String(alias2), - TargetKeyId: aws.String(key2), - }, + }, + ListAliasesOutput: kms.ListAliasesOutput{ + Aliases: []types.AliasListEntry{ + { + AliasName: aws.String(alias1), + TargetKeyId: aws.String(key1), + }, + { + AliasName: aws.String(alias2), + TargetKeyId: aws.String(key2), }, }, - DescribeKeyOutput: map[string]kms.DescribeKeyOutput{ - key1: { - KeyMetadata: &types.KeyMetadata{ - KeyId: aws.String(key1), - KeyManager: types.KeyManagerTypeCustomer, - CreationDate: aws.Time(now), - }, + }, + DescribeKeyOutput: map[string]kms.DescribeKeyOutput{ + key1: { + KeyMetadata: &types.KeyMetadata{ + 
KeyId: aws.String(key1), + KeyManager: types.KeyManagerTypeCustomer, + CreationDate: aws.Time(now), }, - key2: { - KeyMetadata: &types.KeyMetadata{ - KeyId: aws.String(key2), - KeyManager: types.KeyManagerTypeCustomer, - CreationDate: aws.Time(now.Add(1)), - }, + }, + key2: { + KeyMetadata: &types.KeyMetadata{ + KeyId: aws.String(key2), + KeyManager: types.KeyManagerTypeCustomer, + CreationDate: aws.Time(now.Add(1)), }, }, }, @@ -137,9 +131,7 @@ func TestKMS_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := kck.getAll(context.Background(), config.Config{ - KMSCustomerKeys: tc.configObj, - }) + names, err := listKmsCustomerKeys(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, tc.configObj.ResourceType) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) @@ -153,37 +145,35 @@ func TestKMS_GetAll_IncludeUnaliased(t *testing.T) { key2 := "key-without-alias" alias1 := "alias/my-key" now := time.Now() - kck := KmsCustomerKeys{ - Client: mockedKmsCustomerKeys{ - ListKeysOutput: kms.ListKeysOutput{ - Keys: []types.KeyListEntry{ - {KeyId: aws.String(key1)}, - {KeyId: aws.String(key2)}, - }, + mockClient := mockedKmsCustomerKeys{ + ListKeysOutput: kms.ListKeysOutput{ + Keys: []types.KeyListEntry{ + {KeyId: aws.String(key1)}, + {KeyId: aws.String(key2)}, }, - ListAliasesOutput: kms.ListAliasesOutput{ - Aliases: []types.AliasListEntry{ - { - AliasName: aws.String(alias1), - TargetKeyId: aws.String(key1), - }, - // key2 has no alias + }, + ListAliasesOutput: kms.ListAliasesOutput{ + Aliases: []types.AliasListEntry{ + { + AliasName: aws.String(alias1), + TargetKeyId: aws.String(key1), }, + // key2 has no alias }, - DescribeKeyOutput: map[string]kms.DescribeKeyOutput{ - key1: { - KeyMetadata: &types.KeyMetadata{ - KeyId: aws.String(key1), - KeyManager: types.KeyManagerTypeCustomer, - CreationDate: aws.Time(now), - }, + }, + DescribeKeyOutput: map[string]kms.DescribeKeyOutput{ + key1: { + KeyMetadata: &types.KeyMetadata{ + KeyId: aws.String(key1), + KeyManager: types.KeyManagerTypeCustomer, + CreationDate: aws.Time(now), }, - key2: { - KeyMetadata: &types.KeyMetadata{ - KeyId: aws.String(key2), - KeyManager: types.KeyManagerTypeCustomer, - CreationDate: aws.Time(now), - }, + }, + key2: { + KeyMetadata: &types.KeyMetadata{ + KeyId: aws.String(key2), + KeyManager: types.KeyManagerTypeCustomer, + CreationDate: aws.Time(now), }, }, }, @@ -209,9 +199,9 @@ func TestKMS_GetAll_IncludeUnaliased(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := kck.getAll(context.Background(), config.Config{ - KMSCustomerKeys: tc.configObj, - }) + // Pass IncludeUnaliasedKeys via context + ctx := context.WithValue(context.Background(), kmsIncludeUnaliasedKeysKey, tc.configObj.IncludeUnaliasedKeys) + names, err := listKmsCustomerKeys(ctx, mockClient, resource.Scope{Region: "us-east-1"}, tc.configObj.ResourceType) require.NoError(t, err) require.ElementsMatch(t, tc.expected, aws.ToStringSlice(names)) }) @@ -221,13 +211,15 @@ func TestKMS_GetAll_IncludeUnaliased(t *testing.T) { func TestKMS_NukeAll(t *testing.T) { t.Parallel() - kck := KmsCustomerKeys{ - Client: mockedKmsCustomerKeys{ - DeleteAliasOutput: kms.DeleteAliasOutput{}, - ScheduleKeyDeletionOutput: kms.ScheduleKeyDeletionOutput{}, - }, + mockClient := mockedKmsCustomerKeys{ + ScheduleKeyDeletionOutput: kms.ScheduleKeyDeletionOutput{}, } - err := kck.nukeAll([]*string{aws.String("key1"), aws.String("key2")}) + // Test 
deleting a single key + err := deleteKmsCustomerKey(context.Background(), mockClient, aws.String("key1")) + require.NoError(t, err) + + // Test deleting another key + err = deleteKmsCustomerKey(context.Background(), mockClient, aws.String("key2")) require.NoError(t, err) } diff --git a/aws/resources/kms_customer_key_types.go b/aws/resources/kms_customer_key_types.go deleted file mode 100644 index 1a9801ea..00000000 --- a/aws/resources/kms_customer_key_types.go +++ /dev/null @@ -1,68 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/kms" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -// https://docs.aws.amazon.com/sdk-for-go/api/service/kms/#ScheduleKeyDeletionInput -// must be between 7 and 30, inclusive -const kmsRemovalWindow = 7 - -type KmsCustomerKeysAPI interface { - ListKeys(ctx context.Context, params *kms.ListKeysInput, optFns ...func(*kms.Options)) (*kms.ListKeysOutput, error) - ListAliases(ctx context.Context, params *kms.ListAliasesInput, optFns ...func(*kms.Options)) (*kms.ListAliasesOutput, error) - DescribeKey(ctx context.Context, params *kms.DescribeKeyInput, optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) - DeleteAlias(ctx context.Context, params *kms.DeleteAliasInput, optFns ...func(*kms.Options)) (*kms.DeleteAliasOutput, error) - ScheduleKeyDeletion(ctx context.Context, params *kms.ScheduleKeyDeletionInput, optFns ...func(*kms.Options)) (*kms.ScheduleKeyDeletionOutput, error) -} - -type KmsCustomerKeys struct { - BaseAwsResource - Client KmsCustomerKeysAPI - Region string - KeyIds []string - KeyAliases map[string][]string -} - -func (kck *KmsCustomerKeys) Init(cfg aws.Config) { - kck.Client = kms.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (kck *KmsCustomerKeys) ResourceName() string { - return "kmscustomerkeys" -} - -// ResourceIdentifiers - The KMS Key IDs -func (kck *KmsCustomerKeys) ResourceIdentifiers() []string { - return kck.KeyIds -} - -// MaxBatchSize - Requests batch size -func (kck *KmsCustomerKeys) MaxBatchSize() int { - return 49 -} - -func (kck *KmsCustomerKeys) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := kck.getAll(c, configObj) - if err != nil { - return nil, err - } - - kck.KeyIds = aws.ToStringSlice(identifiers) - return kck.KeyIds, nil -} - -// Nuke - remove all customer managed keys -func (kck *KmsCustomerKeys) Nuke(ctx context.Context, keyIds []string) error { - if err := kck.nukeAll(aws.StringSlice(keyIds)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/lambda.go b/aws/resources/lambda.go index dba6e8be..08c03b49 100644 --- a/aws/resources/lambda.go +++ b/aws/resources/lambda.go @@ -9,23 +9,52 @@ import ( "github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" - "github.com/gruntwork-io/go-commons/errors" + "github.com/gruntwork-io/cloud-nuke/resource" ) -func (lf *LambdaFunctions) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// LambdaFunctionsAPI defines the interface for Lambda operations. 
+type LambdaFunctionsAPI interface { + DeleteFunction(ctx context.Context, params *lambda.DeleteFunctionInput, optFns ...func(*lambda.Options)) (*lambda.DeleteFunctionOutput, error) + ListFunctions(ctx context.Context, params *lambda.ListFunctionsInput, optFns ...func(*lambda.Options)) (*lambda.ListFunctionsOutput, error) + ListTags(ctx context.Context, params *lambda.ListTagsInput, optFns ...func(*lambda.Options)) (*lambda.ListTagsOutput, error) +} + +// NewLambdaFunctions creates a new Lambda Functions resource using the generic resource pattern. +func NewLambdaFunctions() AwsResource { + return NewAwsResource(&resource.Resource[LambdaFunctionsAPI]{ + ResourceTypeName: "lambda", + BatchSize: 49, + InitClient: func(r *resource.Resource[LambdaFunctionsAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for Lambda client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = lambda.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.LambdaFunction + }, + Lister: listLambdaFunctions, + Nuker: resource.SimpleBatchDeleter(deleteLambdaFunction), + }) +} + +// listLambdaFunctions retrieves all Lambda functions that match the config filters. +func listLambdaFunctions(ctx context.Context, client LambdaFunctionsAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { var names []*string - paginator := lambda.NewListFunctionsPaginator(lf.Client, &lambda.ListFunctionsInput{}) + paginator := lambda.NewListFunctionsPaginator(client, &lambda.ListFunctionsInput{}) for paginator.HasMorePages() { - page, err := paginator.NextPage(c) + page, err := paginator.NextPage(ctx) if err != nil { - return nil, errors.WithStackTrace(err) + return nil, err } - for _, name := range page.Functions { - if lf.shouldInclude(&name, configObj) { - names = append(names, name.FunctionName) + for _, fn := range page.Functions { + if shouldIncludeLambdaFunction(ctx, client, &fn, cfg) { + names = append(names, fn.FunctionName) } } } @@ -33,7 +62,8 @@ func (lf *LambdaFunctions) getAll(c context.Context, configObj config.Config) ([ return names, nil } -func (lf *LambdaFunctions) shouldInclude(lambdaFn *types.FunctionConfiguration, configObj config.Config) bool { +// shouldIncludeLambdaFunction determines if a Lambda function should be included for deletion. 
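In NewLambdaFunctions above, the Nuker is assembled by wrapping the per-function deleteLambdaFunction in resource.SimpleBatchDeleter. That helper's implementation is not part of this diff, so the sketch below is only a guess at its general shape: lift a per-identifier delete into a batch delete that keeps going on failures and reports them together at the end.

package main

import (
	"context"
	"errors"
	"fmt"
)

// deleteFunc is the per-identifier delete shape used by the refactored
// resources: (ctx, client, identifier) -> error.
type deleteFunc[T any] func(ctx context.Context, client T, id *string) error

// simpleBatchDeleter lifts a per-identifier delete into a batch delete that
// collects every failure instead of stopping at the first one. This is only an
// illustration of what resource.SimpleBatchDeleter could do, not its code.
func simpleBatchDeleter[T any](deleteOne deleteFunc[T]) func(ctx context.Context, client T, ids []*string) error {
	return func(ctx context.Context, client T, ids []*string) error {
		var errs []error
		for _, id := range ids {
			if err := deleteOne(ctx, client, id); err != nil {
				errs = append(errs, fmt.Errorf("%s: %w", *id, err))
			}
		}
		return errors.Join(errs...)
	}
}

func main() {
	// A toy client and delete function just to show the adapter compiling and running.
	del := func(ctx context.Context, client struct{}, id *string) error {
		if *id == "bad" {
			return errors.New("boom")
		}
		return nil
	}
	batch := simpleBatchDeleter[struct{}](del)
	good, bad := "good", "bad"
	fmt.Println(batch(context.Background(), struct{}{}, []*string{&good, &bad}))
}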
+func shouldIncludeLambdaFunction(ctx context.Context, client LambdaFunctionsAPI, lambdaFn *types.FunctionConfiguration, cfg config.ResourceType) bool { if lambdaFn == nil { return false } @@ -50,50 +80,27 @@ func (lf *LambdaFunctions) shouldInclude(lambdaFn *types.FunctionConfiguration, params := &lambda.ListTagsInput{ Resource: lambdaFn.FunctionArn, } - tagsOutput, err := lf.Client.ListTags(lf.Context, params) + tagsOutput, err := client.ListTags(ctx, params) if err != nil { logging.Errorf("failed to list tags for %s: %s", aws.ToString(lambdaFn.FunctionArn), err) } - return configObj.LambdaFunction.ShouldInclude(config.ResourceValue{ + var tags map[string]string + if tagsOutput != nil { + tags = tagsOutput.Tags + } + + return cfg.ShouldInclude(config.ResourceValue{ Time: &lastModifiedDateTime, Name: fnName, - Tags: tagsOutput.Tags, + Tags: tags, }) } -func (lf *LambdaFunctions) nukeAll(names []*string) error { - if len(names) == 0 { - logging.Debugf("No Lambda Functions to nuke in region %s", lf.Region) - return nil - } - - logging.Debugf("Deleting all Lambda Functions in region %s", lf.Region) - var deletedNames []*string - - for _, name := range names { - params := &lambda.DeleteFunctionInput{ - FunctionName: name, - } - - _, err := lf.Client.DeleteFunction(lf.Context, params) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(name), - ResourceType: "Lambda function", - Error: err, - } - report.Record(e) - - if err != nil { - logging.Errorf("[Failed] %s: %s", *name, err) - } else { - deletedNames = append(deletedNames, name) - logging.Debugf("Deleted Lambda Function: %s", aws.ToString(name)) - } - } - - logging.Debugf("[OK] %d Lambda Function(s) deleted in %s", len(deletedNames), lf.Region) - return nil +// deleteLambdaFunction deletes a single Lambda function. +func deleteLambdaFunction(ctx context.Context, client LambdaFunctionsAPI, name *string) error { + _, err := client.DeleteFunction(ctx, &lambda.DeleteFunctionInput{ + FunctionName: name, + }) + return err } diff --git a/aws/resources/lambda_layer.go b/aws/resources/lambda_layer.go index 6bb7c3ea..511fd30a 100644 --- a/aws/resources/lambda_layer.go +++ b/aws/resources/lambda_layer.go @@ -2,6 +2,9 @@ package resources import ( "context" + "fmt" + "strconv" + "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -9,56 +12,82 @@ import ( "github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/go-commons/errors" ) -func (ll *LambdaLayers) getAll(c context.Context, configObj config.Config) ([]*string, error) { - var layers []types.LayersListItem - var names []*string +// LambdaLayersAPI defines the interface for Lambda layer operations. +type LambdaLayersAPI interface { + DeleteLayerVersion(ctx context.Context, params *lambda.DeleteLayerVersionInput, optFns ...func(*lambda.Options)) (*lambda.DeleteLayerVersionOutput, error) + ListLayers(ctx context.Context, params *lambda.ListLayersInput, optFns ...func(*lambda.Options)) (*lambda.ListLayersOutput, error) + ListLayerVersions(ctx context.Context, params *lambda.ListLayerVersionsInput, optFns ...func(*lambda.Options)) (*lambda.ListLayerVersionsOutput, error) +} - paginator := lambda.NewListLayersPaginator(ll.Client, &lambda.ListLayersInput{}) +// NewLambdaLayers creates a new LambdaLayers resource using the generic resource pattern. 
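The lister registered here emits one identifier per layer version rather than one per layer, using the composite layerName:versionNumber form that deleteLambdaLayerVersion parses back apart further down. A small round-trip sketch of that encoding (helper names are illustrative; only the format itself comes from the diff):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// encodeLayerVersion builds the composite identifier used by the new lister:
// "<layerName>:<versionNumber>".
func encodeLayerVersion(layerName string, version int64) string {
	return fmt.Sprintf("%s:%d", layerName, version)
}

// decodeLayerVersion splits the identifier back into its parts, mirroring the
// SplitN/ParseInt logic in deleteLambdaLayerVersion.
func decodeLayerVersion(id string) (string, int64, error) {
	parts := strings.SplitN(id, ":", 2)
	if len(parts) != 2 {
		return "", 0, fmt.Errorf("invalid layer version identifier: %s", id)
	}
	version, err := strconv.ParseInt(parts[1], 10, 64)
	if err != nil {
		return "", 0, fmt.Errorf("invalid version number in identifier %s: %w", id, err)
	}
	return parts[0], version, nil
}

func main() {
	id := encodeLayerVersion("test-lambda-layer1", 3)
	name, version, err := decodeLayerVersion(id)
	fmt.Println(id, name, version, err) // test-lambda-layer1:3 test-lambda-layer1 3 <nil>
}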
+func NewLambdaLayers() AwsResource { + return NewAwsResource(&resource.Resource[LambdaLayersAPI]{ + ResourceTypeName: "lambda_layer", + BatchSize: 49, + InitClient: func(r *resource.Resource[LambdaLayersAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for LambdaLayers client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = lambda.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.LambdaLayer + }, + Lister: listLambdaLayers, + Nuker: resource.SimpleBatchDeleter(deleteLambdaLayerVersion), + }) +} + +// listLambdaLayers retrieves all Lambda layers that match the config filters. +// Returns composite identifiers in the format "layerName:versionNumber" for each layer version. +func listLambdaLayers(ctx context.Context, client LambdaLayersAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + var identifiers []*string + + paginator := lambda.NewListLayersPaginator(client, &lambda.ListLayersInput{}) for paginator.HasMorePages() { - page, err := paginator.NextPage(c) + page, err := paginator.NextPage(ctx) if err != nil { return nil, errors.WithStackTrace(err) } for _, layer := range page.Layers { - logging.Logger.Debugf("Found layer! %s", *layer.LayerName) - - if ll.shouldInclude(&layer, configObj) { - layers = append(layers, layer) - } - } - } - - for _, layer := range layers { + logging.Debugf("Found layer: %s", *layer.LayerName) - versionsPaginator := lambda.NewListLayerVersionsPaginator(ll.Client, &lambda.ListLayerVersionsInput{ - LayerName: layer.LayerName, - }) - for versionsPaginator.HasMorePages() { - page, err := versionsPaginator.NextPage(c) - if err != nil { - return nil, errors.WithStackTrace(err) + if !shouldIncludeLambdaLayer(&layer, cfg) { + continue } - for _, version := range page.LayerVersions { - logging.Logger.Debugf("Found layer version! %d", version.Version) + // List all versions for this layer + versionsPaginator := lambda.NewListLayerVersionsPaginator(client, &lambda.ListLayerVersionsInput{ + LayerName: layer.LayerName, + }) + for versionsPaginator.HasMorePages() { + versionsPage, err := versionsPaginator.NextPage(ctx) + if err != nil { + return nil, errors.WithStackTrace(err) + } - // Currently the output is just the identifier which is the layer's name. - // There could be potentially multiple rows of the same identifier or - // layer name since there can be multiple versions of it. - names = append(names, layer.LayerName) + for _, version := range versionsPage.LayerVersions { + // Create composite identifier: layerName:versionNumber + id := fmt.Sprintf("%s:%d", *layer.LayerName, version.Version) + logging.Debugf("Found layer version: %s", id) + identifiers = append(identifiers, aws.String(id)) + } } } } - return names, nil + return identifiers, nil } -func (ll *LambdaLayers) shouldInclude(lambdaLayer *types.LayersListItem, configObj config.Config) bool { +func shouldIncludeLambdaLayer(lambdaLayer *types.LayersListItem, cfg config.ResourceType) bool { if lambdaLayer == nil { return false } @@ -70,67 +99,33 @@ func (ll *LambdaLayers) shouldInclude(lambdaLayer *types.LayersListItem, configO layout := "2006-01-02T15:04:05.000+0000" lastModifiedDateTime, err := time.Parse(layout, fnLastModified) if err != nil { - logging.Logger.Debugf("Could not parse last modified timestamp (%s) of Lambda layer %s. 
Excluding from delete.", fnLastModified, *fnName) + logging.Debugf("Could not parse last modified timestamp (%s) of Lambda layer %s. Excluding from delete.", fnLastModified, *fnName) return false } - return configObj.LambdaLayer.ShouldInclude(config.ResourceValue{ + return cfg.ShouldInclude(config.ResourceValue{ Time: &lastModifiedDateTime, Name: fnName, }) } -func (ll *LambdaLayers) nukeAll(names []*string) error { - if len(names) == 0 { - logging.Logger.Debugf("No Lambda Layers to nuke in region %s", ll.Region) - return nil - } - - logging.Logger.Debugf("Deleting all Lambda Layers in region %s", ll.Region) - var deletedNames []*string - var deleteLayerVersions []*lambda.DeleteLayerVersionInput - - for _, name := range names { - paginator := lambda.NewListLayerVersionsPaginator(ll.Client, &lambda.ListLayerVersionsInput{ - LayerName: name, - }) - for paginator.HasMorePages() { - page, err := paginator.NextPage(ll.Context) - if err != nil { - return errors.WithStackTrace(err) - } - - for _, version := range page.LayerVersions { - logging.Logger.Debugf("Found layer version! %s", *version.LayerVersionArn) - params := &lambda.DeleteLayerVersionInput{ - LayerName: name, - VersionNumber: aws.Int64(version.Version), - } - deleteLayerVersions = append(deleteLayerVersions, params) - } - } +// deleteLambdaLayerVersion deletes a single Lambda layer version. +// The id parameter is a composite identifier in the format "layerName:versionNumber". +func deleteLambdaLayerVersion(ctx context.Context, client LambdaLayersAPI, id *string) error { + parts := strings.SplitN(aws.ToString(id), ":", 2) + if len(parts) != 2 { + return fmt.Errorf("invalid layer version identifier: %s", aws.ToString(id)) } - for _, params := range deleteLayerVersions { - - _, err := ll.Client.DeleteLayerVersion(ll.Context, params) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(params.LayerName), - ResourceType: "Lambda layer", - Error: err, - } - report.Record(e) - - if err != nil { - logging.Logger.Errorf("[Failed] %s: %s", *params.LayerName, err) - } else { - deletedNames = append(deletedNames, params.LayerName) - logging.Logger.Debugf("Deleted Lambda Layer: %s", aws.ToString(params.LayerName)) - } + layerName := parts[0] + versionNum, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return fmt.Errorf("invalid version number in identifier %s: %w", aws.ToString(id), err) } - logging.Logger.Debugf("[OK] %d Lambda Layer(s) deleted in %s", len(deletedNames), ll.Region) - return nil + _, err = client.DeleteLayerVersion(ctx, &lambda.DeleteLayerVersionInput{ + LayerName: aws.String(layerName), + VersionNumber: aws.Int64(versionNum), + }) + return err } diff --git a/aws/resources/lambda_layer_test.go b/aws/resources/lambda_layer_test.go index bdfc075b..3c7c9543 100644 --- a/aws/resources/lambda_layer_test.go +++ b/aws/resources/lambda_layer_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/lambda" "github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) @@ -36,40 +37,34 @@ func TestLambdaLayer_GetAll(t *testing.T) { t.Parallel() testName1 := "test-lambda-layer1" - testName1Version1 := int64(1) - testName2 := "test-lambda-layer2" - testTime := time.Now() - layout := "2006-01-02T15:04:05.000+0000" testTimeStr := "2023-07-28T12:34:56.789+0000" testTime, err := time.Parse(layout, testTimeStr) require.NoError(t, err) - ll := LambdaLayers{ - 
Client: mockedLambdaLayer{ - ListLayersOutput: lambda.ListLayersOutput{ - Layers: []types.LayersListItem{ - { - LayerName: aws.String(testName1), - LatestMatchingVersion: &types.LayerVersionsListItem{ - CreatedDate: aws.String(testTimeStr), - }, + client := mockedLambdaLayer{ + ListLayersOutput: lambda.ListLayersOutput{ + Layers: []types.LayersListItem{ + { + LayerName: aws.String(testName1), + LatestMatchingVersion: &types.LayerVersionsListItem{ + CreatedDate: aws.String(testTimeStr), }, - { - LayerName: aws.String(testName2), - LatestMatchingVersion: &types.LayerVersionsListItem{ - CreatedDate: aws.String(testTimeStr), - }, + }, + { + LayerName: aws.String(testName2), + LatestMatchingVersion: &types.LayerVersionsListItem{ + CreatedDate: aws.String(testTimeStr), }, }, }, - ListLayerVersionsOutput: lambda.ListLayerVersionsOutput{ - LayerVersions: []types.LayerVersionsListItem{ - { - Version: testName1Version1, - }, + }, + ListLayerVersionsOutput: lambda.ListLayerVersionsOutput{ + LayerVersions: []types.LayerVersionsListItem{ + { + Version: 1, }, }, }, @@ -81,7 +76,8 @@ func TestLambdaLayer_GetAll(t *testing.T) { }{ "emptyFilter": { configObj: config.ResourceType{}, - expected: []string{testName1, testName2}, + // Each layer has one version (1), so identifiers are "layerName:1" + expected: []string{testName1 + ":1", testName2 + ":1"}, }, "nameExclusionFilter": { configObj: config.ResourceType{ @@ -90,7 +86,7 @@ func TestLambdaLayer_GetAll(t *testing.T) { RE: *regexp.MustCompile(testName1), }}}, }, - expected: []string{testName2}, + expected: []string{testName2 + ":1"}, }, "timeAfterExclusionFilter": { configObj: config.ResourceType{ @@ -102,9 +98,7 @@ func TestLambdaLayer_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := ll.getAll(context.Background(), config.Config{ - LambdaLayer: tc.configObj, - }) + names, err := listLambdaLayers(context.Background(), client, resource.Scope{Region: "us-east-1"}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) @@ -114,12 +108,25 @@ func TestLambdaLayer_GetAll(t *testing.T) { func TestLambdaLayer_NukeAll(t *testing.T) { t.Parallel() - ll := LambdaLayers{ - Client: mockedLambdaLayer{ - DeleteLayerVersionOutput: lambda.DeleteLayerVersionOutput{}, - }, + client := mockedLambdaLayer{ + DeleteLayerVersionOutput: lambda.DeleteLayerVersionOutput{}, } - err := ll.nukeAll([]*string{aws.String("test")}) + // Test deleting a layer version with composite identifier + err := deleteLambdaLayerVersion(context.Background(), client, aws.String("test-layer:1")) + require.NoError(t, err) + + // Test with multiple versions + err = deleteLambdaLayerVersion(context.Background(), client, aws.String("test-layer:2")) require.NoError(t, err) + + // Test with invalid identifier (no colon) + err = deleteLambdaLayerVersion(context.Background(), client, aws.String("invalid-identifier")) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid layer version identifier") + + // Test with invalid version number + err = deleteLambdaLayerVersion(context.Background(), client, aws.String("test-layer:invalid")) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid version number") } diff --git a/aws/resources/lambda_layer_types.go b/aws/resources/lambda_layer_types.go deleted file mode 100644 index b1d60ed9..00000000 --- a/aws/resources/lambda_layer_types.go +++ /dev/null @@ -1,72 +0,0 @@ -package resources - -import ( - "context" - - 
"github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/lambda" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type LambdaLayersAPI interface { - DeleteLayerVersion(ctx context.Context, params *lambda.DeleteLayerVersionInput, optFns ...func(*lambda.Options)) (*lambda.DeleteLayerVersionOutput, error) - ListLayers(ctx context.Context, params *lambda.ListLayersInput, optFns ...func(*lambda.Options)) (*lambda.ListLayersOutput, error) - ListLayerVersions(ctx context.Context, params *lambda.ListLayerVersionsInput, optFns ...func(*lambda.Options)) (*lambda.ListLayerVersionsOutput, error) -} - -type LambdaLayers struct { - BaseAwsResource - Client LambdaLayersAPI - Region string - LambdaFunctionNames []string -} - -func (ll *LambdaLayers) Init(cfg aws.Config) { - ll.Client = lambda.NewFromConfig(cfg) -} - -func (ll *LambdaLayers) ResourceName() string { - return "lambda_layer" -} - -// ResourceIdentifiers - The names of the lambda functions -func (ll *LambdaLayers) ResourceIdentifiers() []string { - return ll.LambdaFunctionNames -} - -func (ll *LambdaLayers) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (ll *LambdaLayers) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.LambdaLayer -} - -func (ll *LambdaLayers) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := ll.getAll(c, configObj) - if err != nil { - return nil, err - } - - ll.LambdaFunctionNames = aws.ToStringSlice(identifiers) - return ll.LambdaFunctionNames, nil -} - -// Nuke - nuke 'em all!!! -func (ll *LambdaLayers) Nuke(ctx context.Context, identifiers []string) error { - if err := ll.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} - -type LambdaVersionDeleteError struct { - name string -} - -func (e LambdaVersionDeleteError) Error() string { - return "Lambda Function:" + e.name + "was not deleted" -} diff --git a/aws/resources/lambda_test.go b/aws/resources/lambda_test.go index 882441f7..a8c382ee 100644 --- a/aws/resources/lambda_test.go +++ b/aws/resources/lambda_test.go @@ -10,52 +10,49 @@ import ( "github.com/aws/aws-sdk-go-v2/service/lambda" "github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) -type mockedLambda struct { - LambdaFunctionsAPI +type mockLambdaClient struct { DeleteFunctionOutput lambda.DeleteFunctionOutput ListFunctionsOutput lambda.ListFunctionsOutput ListTagsOutput lambda.ListTagsOutput } -func (m mockedLambda) DeleteFunction(ctx context.Context, params *lambda.DeleteFunctionInput, optFns ...func(*lambda.Options)) (*lambda.DeleteFunctionOutput, error) { +func (m *mockLambdaClient) DeleteFunction(ctx context.Context, params *lambda.DeleteFunctionInput, optFns ...func(*lambda.Options)) (*lambda.DeleteFunctionOutput, error) { return &m.DeleteFunctionOutput, nil } -func (m mockedLambda) ListFunctions(ctx context.Context, params *lambda.ListFunctionsInput, optFns ...func(*lambda.Options)) (*lambda.ListFunctionsOutput, error) { +func (m *mockLambdaClient) ListFunctions(ctx context.Context, params *lambda.ListFunctionsInput, optFns ...func(*lambda.Options)) (*lambda.ListFunctionsOutput, error) { return &m.ListFunctionsOutput, nil } -func (m mockedLambda) ListTags(ctx context.Context, params *lambda.ListTagsInput, 
optFns ...func(*lambda.Options)) (*lambda.ListTagsOutput, error) { +func (m *mockLambdaClient) ListTags(ctx context.Context, params *lambda.ListTagsInput, optFns ...func(*lambda.Options)) (*lambda.ListTagsOutput, error) { return &m.ListTagsOutput, nil } -func TestLambdaFunction_GetAll(t *testing.T) { +func TestListLambdaFunctions(t *testing.T) { t.Parallel() testName1 := "test-lambda-function1" testName2 := "test-lambda-function2" - testTime := time.Now() layout := "2006-01-02T15:04:05.000+0000" testTimeStr := "2023-07-28T12:34:56.789+0000" testTime, err := time.Parse(layout, testTimeStr) require.NoError(t, err) - lf := LambdaFunctions{ - Client: mockedLambda{ - ListFunctionsOutput: lambda.ListFunctionsOutput{ - Functions: []types.FunctionConfiguration{ - { - FunctionName: aws.String(testName1), - LastModified: aws.String(testTimeStr), - }, - { - FunctionName: aws.String(testName2), - LastModified: aws.String(testTimeStr), - }, + mock := &mockLambdaClient{ + ListFunctionsOutput: lambda.ListFunctionsOutput{ + Functions: []types.FunctionConfiguration{ + { + FunctionName: aws.String(testName1), + LastModified: aws.String(testTimeStr), + }, + { + FunctionName: aws.String(testName2), + LastModified: aws.String(testTimeStr), }, }, }, @@ -88,24 +85,17 @@ func TestLambdaFunction_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := lf.getAll(context.Background(), config.Config{ - LambdaFunction: tc.configObj, - }) + names, err := listLambdaFunctions(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } } -func TestLambdaFunction_NukeAll(t *testing.T) { +func TestDeleteLambdaFunction(t *testing.T) { t.Parallel() - lf := LambdaFunctions{ - Client: mockedLambda{ - DeleteFunctionOutput: lambda.DeleteFunctionOutput{}, - }, - } - - err := lf.nukeAll([]*string{aws.String("test")}) + mock := &mockLambdaClient{} + err := deleteLambdaFunction(context.Background(), mock, aws.String("test")) require.NoError(t, err) } diff --git a/aws/resources/lambda_types.go b/aws/resources/lambda_types.go deleted file mode 100644 index 6e987133..00000000 --- a/aws/resources/lambda_types.go +++ /dev/null @@ -1,72 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/lambda" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type LambdaFunctionsAPI interface { - DeleteFunction(ctx context.Context, params *lambda.DeleteFunctionInput, optFns ...func(*lambda.Options)) (*lambda.DeleteFunctionOutput, error) - ListFunctions(ctx context.Context, params *lambda.ListFunctionsInput, optFns ...func(*lambda.Options)) (*lambda.ListFunctionsOutput, error) - ListTags(ctx context.Context, params *lambda.ListTagsInput, optFns ...func(*lambda.Options)) (*lambda.ListTagsOutput, error) -} - -type LambdaFunctions struct { - BaseAwsResource - Client LambdaFunctionsAPI - Region string - LambdaFunctionNames []string -} - -func (lf *LambdaFunctions) Init(cfg aws.Config) { - lf.Client = lambda.NewFromConfig(cfg) -} - -func (lf *LambdaFunctions) ResourceName() string { - return "lambda" -} - -// ResourceIdentifiers - The names of the lambda functions -func (lf *LambdaFunctions) ResourceIdentifiers() []string { - return lf.LambdaFunctionNames -} - -func (lf *LambdaFunctions) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (lf 
*LambdaFunctions) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.LambdaFunction -} - -func (lf *LambdaFunctions) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := lf.getAll(c, configObj) - if err != nil { - return nil, err - } - - lf.LambdaFunctionNames = aws.ToStringSlice(identifiers) - return lf.LambdaFunctionNames, nil -} - -// Nuke - nuke 'em all!!! -func (lf *LambdaFunctions) Nuke(ctx context.Context, identifiers []string) error { - if err := lf.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} - -type LambdaDeleteError struct { - name string -} - -func (e LambdaDeleteError) Error() string { - return "Lambda Function:" + e.name + "was not deleted" -} diff --git a/aws/resources/msk_cluster.go b/aws/resources/msk_cluster.go index ace858b2..66248100 100644 --- a/aws/resources/msk_cluster.go +++ b/aws/resources/msk_cluster.go @@ -8,31 +8,60 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kafka/types" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" - "github.com/gruntwork-io/go-commons/errors" + "github.com/gruntwork-io/cloud-nuke/resource" ) -func (m *MSKCluster) getAll(c context.Context, configObj config.Config) ([]*string, error) { - var clusterIDs []*string +// MSKClusterAPI defines the interface for MSK Cluster operations. +type MSKClusterAPI interface { + ListClustersV2(ctx context.Context, params *kafka.ListClustersV2Input, optFns ...func(*kafka.Options)) (*kafka.ListClustersV2Output, error) + DeleteCluster(ctx context.Context, params *kafka.DeleteClusterInput, optFns ...func(*kafka.Options)) (*kafka.DeleteClusterOutput, error) +} + +// NewMSKCluster creates a new MSKCluster resource using the generic resource pattern. +func NewMSKCluster() AwsResource { + return NewAwsResource(&resource.Resource[MSKClusterAPI]{ + ResourceTypeName: "msk-cluster", + BatchSize: 10, + InitClient: func(r *resource.Resource[MSKClusterAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for MSKCluster client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = kafka.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.MSKCluster + }, + Lister: listMSKClusters, + Nuker: resource.SimpleBatchDeleter(deleteMSKCluster), + }) +} + +// listMSKClusters retrieves all MSK clusters that match the config filters. 
+func listMSKClusters(ctx context.Context, client MSKClusterAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + var clusterArns []*string - paginator := kafka.NewListClustersV2Paginator(m.Client, &kafka.ListClustersV2Input{}) + paginator := kafka.NewListClustersV2Paginator(client, &kafka.ListClustersV2Input{}) for paginator.HasMorePages() { - page, err := paginator.NextPage(c) + page, err := paginator.NextPage(ctx) if err != nil { - return nil, errors.WithStackTrace(err) + return nil, err } for _, cluster := range page.ClusterInfoList { - if m.shouldInclude(cluster, configObj) { - clusterIDs = append(clusterIDs, cluster.ClusterArn) + if shouldIncludeMSKCluster(cluster, cfg) { + clusterArns = append(clusterArns, cluster.ClusterArn) } } } - return clusterIDs, nil + return clusterArns, nil } -func (m *MSKCluster) shouldInclude(cluster types.Cluster, configObj config.Config) bool { +// shouldIncludeMSKCluster determines if a cluster should be included based on state and config. +func shouldIncludeMSKCluster(cluster types.Cluster, cfg config.ResourceType) bool { if cluster.State == types.ClusterStateDeleting { return false } @@ -49,33 +78,16 @@ func (m *MSKCluster) shouldInclude(cluster types.Cluster, configObj config.Confi return false } - return configObj.MSKCluster.ShouldInclude(config.ResourceValue{ + return cfg.ShouldInclude(config.ResourceValue{ Name: cluster.ClusterName, Time: cluster.CreationTime, }) } -func (m *MSKCluster) nukeAll(identifiers []*string) error { - if len(identifiers) == 0 { - return nil - } - - for _, clusterArn := range identifiers { - _, err := m.Client.DeleteCluster(m.Context, &kafka.DeleteClusterInput{ - ClusterArn: clusterArn, - }) - if err != nil { - logging.Errorf("[Failed] %s", err) - } - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(clusterArn), - ResourceType: "MSKCluster", - Error: err, - } - report.Record(e) - } - - return nil +// deleteMSKCluster deletes a single MSK cluster. 
+func deleteMSKCluster(ctx context.Context, client MSKClusterAPI, clusterArn *string) error { + _, err := client.DeleteCluster(ctx, &kafka.DeleteClusterInput{ + ClusterArn: clusterArn, + }) + return err } diff --git a/aws/resources/msk_cluster_test.go b/aws/resources/msk_cluster_test.go index f9ff787a..ecc00e5d 100644 --- a/aws/resources/msk_cluster_test.go +++ b/aws/resources/msk_cluster_test.go @@ -12,24 +12,25 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kafka" "github.com/aws/aws-sdk-go-v2/service/kafka/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" + "github.com/stretchr/testify/require" ) type mockMSKClient struct { - MSKClusterAPI ListClustersV2Output kafka.ListClustersV2Output DeleteClusterOutput kafka.DeleteClusterOutput } -func (m mockMSKClient) ListClustersV2(ctx context.Context, params *kafka.ListClustersV2Input, optFns ...func(*kafka.Options)) (*kafka.ListClustersV2Output, error) { +func (m *mockMSKClient) ListClustersV2(ctx context.Context, params *kafka.ListClustersV2Input, optFns ...func(*kafka.Options)) (*kafka.ListClustersV2Output, error) { return &m.ListClustersV2Output, nil } -func (m mockMSKClient) DeleteCluster(ctx context.Context, params *kafka.DeleteClusterInput, optFns ...func(*kafka.Options)) (*kafka.DeleteClusterOutput, error) { +func (m *mockMSKClient) DeleteCluster(ctx context.Context, params *kafka.DeleteClusterInput, optFns ...func(*kafka.Options)) (*kafka.DeleteClusterOutput, error) { return &m.DeleteClusterOutput, nil } func TestListMSKClustersSingle(t *testing.T) { - mockMskClient := mockMSKClient{ + mock := &mockMSKClient{ ListClustersV2Output: kafka.ListClustersV2Output{ ClusterInfoList: []types.Cluster{ { @@ -42,11 +43,7 @@ func TestListMSKClustersSingle(t *testing.T) { }, } - msk := MSKCluster{ - Client: &mockMskClient, - } - - clusterIDs, err := msk.getAll(context.Background(), config.Config{}) + clusterIDs, err := listMSKClusters(context.Background(), mock, resource.Scope{}, config.ResourceType{}) if err != nil { t.Fatalf("Unable to list MSK Clusters: %v", err) } @@ -61,7 +58,7 @@ func TestListMSKClustersSingle(t *testing.T) { } func TestListMSKClustersMultiple(t *testing.T) { - mockMskClient := mockMSKClient{ + mock := &mockMSKClient{ ListClustersV2Output: kafka.ListClustersV2Output{ ClusterInfoList: []types.Cluster{ { @@ -84,11 +81,7 @@ func TestListMSKClustersMultiple(t *testing.T) { }, } - msk := MSKCluster{ - Client: &mockMskClient, - } - - clusterIDs, err := msk.getAll(context.Background(), config.Config{}) + clusterIDs, err := listMSKClusters(context.Background(), mock, resource.Scope{}, config.ResourceType{}) if err != nil { t.Fatalf("Unable to list MSK Clusters: %v", err) } @@ -111,7 +104,7 @@ func TestShouldIncludeMSKCluster(t *testing.T) { tests := map[string]struct { cluster types.Cluster - configObj config.Config + configObj config.ResourceType expected bool }{ "cluster is in deleting state": { @@ -120,7 +113,7 @@ func TestShouldIncludeMSKCluster(t *testing.T) { State: types.ClusterStateDeleting, CreationTime: &creationTime, }, - configObj: config.Config{}, + configObj: config.ResourceType{}, expected: false, }, "cluster is in creating state": { @@ -129,7 +122,7 @@ func TestShouldIncludeMSKCluster(t *testing.T) { State: types.ClusterStateCreating, CreationTime: &creationTime, }, - configObj: config.Config{}, + configObj: config.ResourceType{}, expected: false, }, "cluster is in active state": { @@ -138,7 +131,7 @@ func TestShouldIncludeMSKCluster(t *testing.T) { State: 
types.ClusterStateActive, CreationTime: &creationTime, }, - configObj: config.Config{}, + configObj: config.ResourceType{}, expected: true, }, "cluster excluded by name": { @@ -147,13 +140,11 @@ func TestShouldIncludeMSKCluster(t *testing.T) { State: types.ClusterStateActive, CreationTime: &creationTime, }, - configObj: config.Config{ - MSKCluster: config.ResourceType{ - ExcludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ - { - RE: *regexp.MustCompile("test-cluster"), - }, + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{ + { + RE: *regexp.MustCompile("test-cluster"), }, }, }, @@ -166,13 +157,11 @@ func TestShouldIncludeMSKCluster(t *testing.T) { State: types.ClusterStateActive, CreationTime: &creationTime, }, - configObj: config.Config{ - MSKCluster: config.ResourceType{ - IncludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ - { - RE: *regexp.MustCompile("^test-cluster"), - }, + configObj: config.ResourceType{ + IncludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{ + { + RE: *regexp.MustCompile("^test-cluster"), }, }, }, @@ -183,8 +172,7 @@ func TestShouldIncludeMSKCluster(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - msk := MSKCluster{} - actual := msk.shouldInclude(tc.cluster, tc.configObj) + actual := shouldIncludeMSKCluster(tc.cluster, tc.configObj) if actual != tc.expected { t.Fatalf("Expected %v, got %v", tc.expected, actual) } @@ -192,17 +180,11 @@ func TestShouldIncludeMSKCluster(t *testing.T) { } } -func TestNukeMSKCluster(t *testing.T) { - mockMskClient := mockMSKClient{ +func TestDeleteMSKCluster(t *testing.T) { + mock := &mockMSKClient{ DeleteClusterOutput: kafka.DeleteClusterOutput{}, } - msk := MSKCluster{ - Client: &mockMskClient, - } - - err := msk.Nuke(context.TODO(), []string{}) - if err != nil { - t.Fatalf("Unable to nuke MSK Clusters: %v", err) - } + err := deleteMSKCluster(context.Background(), mock, aws.String("test-arn")) + require.NoError(t, err) } diff --git a/aws/resources/msk_cluster_types.go b/aws/resources/msk_cluster_types.go deleted file mode 100644 index 4dfdb8f3..00000000 --- a/aws/resources/msk_cluster_types.go +++ /dev/null @@ -1,67 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/kafka" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type MSKClusterAPI interface { - ListClustersV2(ctx context.Context, params *kafka.ListClustersV2Input, optFns ...func(*kafka.Options)) (*kafka.ListClustersV2Output, error) - DeleteCluster(ctx context.Context, params *kafka.DeleteClusterInput, optFns ...func(*kafka.Options)) (*kafka.DeleteClusterOutput, error) -} - -// MSKCluster - represents all AWS Managed Streaming for Kafka clusters that should be deleted. -type MSKCluster struct { - BaseAwsResource - Client MSKClusterAPI - Region string - ClusterArns []string -} - -func (m *MSKCluster) Init(cfg aws.Config) { - m.Client = kafka.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (m *MSKCluster) ResourceName() string { - return "msk-cluster" -} - -// ResourceIdentifiers - The instance ids of the AWS Managed Streaming for Kafka clusters -func (m *MSKCluster) ResourceIdentifiers() []string { - return m.ClusterArns -} - -func (m *MSKCluster) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle. 
Note that nat gateway does not support bulk delete, so - // we will be deleting this many in parallel using go routines. We conservatively pick 10 here, both to limit - // overloading the runtime and to avoid AWS throttling with many API calls. - return 10 -} - -func (m *MSKCluster) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.MSKCluster -} - -func (m *MSKCluster) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := m.getAll(c, configObj) - if err != nil { - return nil, err - } - - m.ClusterArns = aws.ToStringSlice(identifiers) - return m.ClusterArns, nil -} - -// Nuke - nuke 'em all!!! -func (m *MSKCluster) Nuke(ctx context.Context, identifiers []string) error { - if err := m.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/opensearch.go b/aws/resources/opensearch.go index f87b0a4e..df42cde2 100644 --- a/aws/resources/opensearch.go +++ b/aws/resources/opensearch.go @@ -3,7 +3,6 @@ package resources import ( "context" "fmt" - "sync" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -12,26 +11,56 @@ import ( "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/gruntwork-io/go-commons/errors" "github.com/gruntwork-io/go-commons/retry" - "github.com/hashicorp/go-multierror" ) -// getAll queries AWS for all active domains in the account that meet the nuking criteria based on +// OpenSearchDomainsAPI defines the interface for OpenSearch operations. +type OpenSearchDomainsAPI interface { + AddTags(ctx context.Context, params *opensearch.AddTagsInput, optFns ...func(*opensearch.Options)) (*opensearch.AddTagsOutput, error) + DeleteDomain(ctx context.Context, params *opensearch.DeleteDomainInput, optFns ...func(*opensearch.Options)) (*opensearch.DeleteDomainOutput, error) + DescribeDomains(ctx context.Context, params *opensearch.DescribeDomainsInput, optFns ...func(*opensearch.Options)) (*opensearch.DescribeDomainsOutput, error) + ListDomainNames(ctx context.Context, params *opensearch.ListDomainNamesInput, optFns ...func(*opensearch.Options)) (*opensearch.ListDomainNamesOutput, error) + ListTags(ctx context.Context, params *opensearch.ListTagsInput, optFns ...func(*opensearch.Options)) (*opensearch.ListTagsOutput, error) +} + +// NewOpenSearchDomains creates a new OpenSearchDomains resource using the generic resource pattern. +func NewOpenSearchDomains() AwsResource { + return NewAwsResource(&resource.Resource[OpenSearchDomainsAPI]{ + ResourceTypeName: "opensearchdomain", + BatchSize: 10, + InitClient: func(r *resource.Resource[OpenSearchDomainsAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for OpenSearch client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = opensearch.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.OpenSearchDomain + }, + Lister: listOpenSearchDomains, + Nuker: resource.ConcurrentDeleteThenWaitAll(deleteOpenSearchDomain, waitForOpenSearchDomainsDeleted), + }) +} + +// listOpenSearchDomains queries AWS for all active domains in the account that meet the nuking criteria based on // the excludeAfter and configObj configurations. 
Note that OpenSearch Domains do not have resource timestamps, so we // use the first-seen tagging pattern to track which OpenSearch Domains should be nuked based on time. This routine will // tag resources with the first-seen tag if it does not have one. -func (osd *OpenSearchDomains) getAll(c context.Context, configObj config.Config) ([]*string, error) { +func listOpenSearchDomains(ctx context.Context, client OpenSearchDomainsAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { var firstSeenTime *time.Time var err error - domains, err := osd.getAllActiveOpenSearchDomains() + domains, err := getAllActiveOpenSearchDomains(ctx, client) if err != nil { return nil, errors.WithStackTrace(err) } - excludeFirstSeenTag, err := util.GetBoolFromContext(c, util.ExcludeFirstSeenTagKey) + excludeFirstSeenTag, err := util.GetBoolFromContext(ctx, util.ExcludeFirstSeenTagKey) if err != nil { return nil, errors.WithStackTrace(err) } @@ -39,20 +68,20 @@ func (osd *OpenSearchDomains) getAll(c context.Context, configObj config.Config) domainsToNuke := []*string{} for _, domain := range domains { if !excludeFirstSeenTag { - firstSeenTime, err = osd.getFirstSeenTag(domain.ARN) + firstSeenTime, err = getOpenSearchFirstSeenTag(ctx, client, domain.ARN) if err != nil { return nil, errors.WithStackTrace(err) } if firstSeenTime == nil { - err := osd.setFirstSeenTag(domain.ARN, time.Now().UTC()) + err := setOpenSearchFirstSeenTag(ctx, client, domain.ARN, time.Now().UTC()) if err != nil { logging.Errorf("Error tagging the OpenSearch Domain with ARN %s with error: %s", aws.ToString(domain.ARN), err.Error()) return nil, errors.WithStackTrace(err) } } } - if configObj.OpenSearchDomain.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Name: domain.DomainName, Time: firstSeenTime, }) { @@ -64,9 +93,9 @@ func (osd *OpenSearchDomains) getAll(c context.Context, configObj config.Config) } // getAllActiveOpenSearchDomains filters all active OpenSearch domains, which are those that have the `Created` flag true and `Deleted` flag false. 
-func (osd *OpenSearchDomains) getAllActiveOpenSearchDomains() ([]types.DomainStatus, error) { +func getAllActiveOpenSearchDomains(ctx context.Context, client OpenSearchDomainsAPI) ([]types.DomainStatus, error) { allDomains := []*string{} - resp, err := osd.Client.ListDomainNames(osd.Context, &opensearch.ListDomainNamesInput{}) + resp, err := client.ListDomainNames(ctx, &opensearch.ListDomainNamesInput{}) if err != nil { logging.Errorf("Error getting all OpenSearch domains") return nil, errors.WithStackTrace(err) @@ -76,7 +105,7 @@ func (osd *OpenSearchDomains) getAllActiveOpenSearchDomains() ([]types.DomainSta } input := &opensearch.DescribeDomainsInput{DomainNames: aws.ToStringSlice(allDomains)} - describedDomains, describeErr := osd.Client.DescribeDomains(osd.Context, input) + describedDomains, describeErr := client.DescribeDomains(ctx, input) if describeErr != nil { logging.Errorf("Error describing Domains from input %s: ", input) return nil, errors.WithStackTrace(describeErr) @@ -91,8 +120,8 @@ func (osd *OpenSearchDomains) getAllActiveOpenSearchDomains() ([]types.DomainSta return filteredDomains, nil } -// Tag an OpenSearch Domain identified by the given ARN when it's first seen by cloud-nuke -func (osd *OpenSearchDomains) setFirstSeenTag(domainARN *string, timestamp time.Time) error { +// setOpenSearchFirstSeenTag tags an OpenSearch Domain identified by the given ARN when it's first seen by cloud-nuke +func setOpenSearchFirstSeenTag(ctx context.Context, client OpenSearchDomainsAPI, domainARN *string, timestamp time.Time) error { logging.Debugf("Tagging the OpenSearch Domain with ARN %s with first seen timestamp", aws.ToString(domainARN)) firstSeenTime := util.FormatTimestamp(timestamp) @@ -106,7 +135,7 @@ func (osd *OpenSearchDomains) setFirstSeenTag(domainARN *string, timestamp time. }, } - _, err := osd.Client.AddTags(osd.Context, input) + _, err := client.AddTags(ctx, input) if err != nil { return errors.WithStackTrace(err) } @@ -114,12 +143,12 @@ func (osd *OpenSearchDomains) setFirstSeenTag(domainARN *string, timestamp time. return nil } -// getFirstSeenTag gets the `cloud-nuke-first-seen` tag value for a given OpenSearch Domain -func (osd *OpenSearchDomains) getFirstSeenTag(domainARN *string) (*time.Time, error) { +// getOpenSearchFirstSeenTag gets the `cloud-nuke-first-seen` tag value for a given OpenSearch Domain +func getOpenSearchFirstSeenTag(ctx context.Context, client OpenSearchDomainsAPI, domainARN *string) (*time.Time, error) { var firstSeenTime *time.Time input := &opensearch.ListTagsInput{ARN: domainARN} - domainTags, err := osd.Client.ListTags(osd.Context, input) + domainTags, err := client.ListTags(ctx, input) if err != nil { logging.Errorf("Error getting the tags for OpenSearch Domain with ARN %s", aws.ToString(domainARN)) return firstSeenTime, errors.WithStackTrace(err) @@ -140,55 +169,25 @@ func (osd *OpenSearchDomains) getFirstSeenTag(domainARN *string) (*time.Time, er return firstSeenTime, nil } -// nukeAll nukes the given list of OpenSearch domains concurrently. Note that the opensearch API -// does not support bulk delete, so this routine will spawn a goroutine for each domain that needs to be nuked so that -// they can be issued concurrently. 
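Because OpenSearch domains expose no creation timestamp, listOpenSearchDomains relies on the cloud-nuke-first-seen tag written on the first run and read back on later runs. The sketch below illustrates the idea with a plain map lookup and time parsing; the exact timestamp layout produced by util.FormatTimestamp is not shown in this diff, so RFC 3339 here is an assumption.

package main

import (
	"fmt"
	"time"
)

// firstSeenTagKey mirrors the cloud-nuke-first-seen tag named in the diff.
const firstSeenTagKey = "cloud-nuke-first-seen"

// firstSeenFrom reads the first-seen timestamp out of a resource's tags,
// returning nil when the tag has not been written yet.
func firstSeenFrom(tags map[string]string) (*time.Time, error) {
	raw, ok := tags[firstSeenTagKey]
	if !ok {
		return nil, nil
	}
	t, err := time.Parse(time.RFC3339, raw) // layout is an assumption for this sketch
	if err != nil {
		return nil, err
	}
	return &t, nil
}

func main() {
	tags := map[string]string{}

	// First run: no tag yet, so the caller would write one with the current time.
	if seen, _ := firstSeenFrom(tags); seen == nil {
		tags[firstSeenTagKey] = time.Now().UTC().Format(time.RFC3339)
	}

	// Later run: the tag is present and can drive age-based filtering.
	seen, _ := firstSeenFrom(tags)
	fmt.Println("first seen:", seen.Format(time.RFC3339))
	fmt.Println("older than 24h:", time.Since(*seen) > 24*time.Hour)
}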
-func (osd *OpenSearchDomains) nukeAll(identifiers []*string) error { - if len(identifiers) == 0 { - logging.Debugf("No OpenSearch Domains to nuke in region %s", osd.Region) - return nil - } - - // NOTE: we don't need to do pagination here, because the caller handles the pagination to this function, - // based on OpenSearchDomains.MaxBatchSize, however, we add a guard here to warn users when the batching fails and has a - // chance of throttling AWS. Since we concurrently make one call for each identifier, we pick 100 for the limit here - // because many APIs in AWS have a limit of 100 requests per second. - if len(identifiers) > 100 { - logging.Errorf("Nuking too many OpenSearch Domains at once (100): halting to avoid hitting AWS API rate limiting") - return TooManyOpenSearchDomainsErr{} - } - - logging.Debugf("Deleting OpenSearch Domains in region %s", osd.Region) - wg := new(sync.WaitGroup) - wg.Add(len(identifiers)) - errChans := make([]chan error, len(identifiers)) - for i, domainName := range identifiers { - errChans[i] = make(chan error, 1) - go osd.deleteAsync(wg, errChans[i], domainName) - } - wg.Wait() - - // Collect all the errors from the async delete calls into a single error struct. - var allErrs *multierror.Error - for _, errChan := range errChans { - if err := <-errChan; err != nil { - allErrs = multierror.Append(allErrs, err) - logging.Errorf("[Failed] %s", err) - } - } - finalErr := allErrs.ErrorOrNil() - if finalErr != nil { - return errors.WithStackTrace(finalErr) +// deleteOpenSearchDomain deletes a single OpenSearch domain. +func deleteOpenSearchDomain(ctx context.Context, client OpenSearchDomainsAPI, domainName *string) error { + input := &opensearch.DeleteDomainInput{DomainName: domainName} + _, err := client.DeleteDomain(ctx, input) + if err != nil { + return errors.WithStackTrace(err) } + return nil +} - // Now wait until the OpenSearch Domains are deleted - err := retry.DoWithRetry( +// waitForOpenSearchDomainsDeleted waits for all OpenSearch domains to be fully deleted. +func waitForOpenSearchDomainsDeleted(ctx context.Context, client OpenSearchDomainsAPI, names []string) error { + return retry.DoWithRetry( logging.Logger.WithTime(time.Now()), "Waiting for all OpenSearch Domains to be deleted.", // Wait a maximum of 5 minutes: 10 seconds in between, up to 30 times 30, 10*time.Second, func() error { - resp, err := osd.Client.DescribeDomains(osd.Context, &opensearch.DescribeDomainsInput{DomainNames: aws.ToStringSlice(identifiers)}) + resp, err := client.DescribeDomains(ctx, &opensearch.DescribeDomainsInput{DomainNames: names}) if err != nil { return errors.WithStackTrace(retry.FatalError{Underlying: err}) } @@ -198,32 +197,6 @@ func (osd *OpenSearchDomains) nukeAll(identifiers []*string) error { return fmt.Errorf("Not all OpenSearch domains are deleted.") }, ) - if err != nil { - return errors.WithStackTrace(err) - } - for _, domainName := range identifiers { - logging.Debugf("[OK] OpenSearch Domain %s was deleted in %s", aws.ToString(domainName), osd.Region) - } - return nil -} - -// deleteAsync deletes the provided OpenSearch Domain asynchronously in a goroutine, using wait groups -// for concurrency control and a return channel for errors. 
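The hand-rolled WaitGroup/error-channel fan-out removed above is replaced by composing deleteOpenSearchDomain with waitForOpenSearchDomainsDeleted through resource.ConcurrentDeleteThenWaitAll. That combinator's implementation is not shown here, so the following is only one plausible shape for it: delete concurrently, gather per-item errors, and, if everything was accepted, run a single wait over the whole batch.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
)

// concurrentDeleteThenWaitAll is an illustrative guess at the helper used for
// OpenSearch: one goroutine per identifier, errors collected under a mutex,
// then a single wait over the whole batch once every delete was issued.
func concurrentDeleteThenWaitAll[T any](
	deleteOne func(ctx context.Context, client T, id *string) error,
	waitAll func(ctx context.Context, client T, ids []string) error,
) func(ctx context.Context, client T, ids []*string) error {
	return func(ctx context.Context, client T, ids []*string) error {
		var (
			mu   sync.Mutex
			errs []error
			wg   sync.WaitGroup
		)
		for _, id := range ids {
			wg.Add(1)
			go func(id *string) {
				defer wg.Done()
				if err := deleteOne(ctx, client, id); err != nil {
					mu.Lock()
					errs = append(errs, err)
					mu.Unlock()
				}
			}(id)
		}
		wg.Wait()
		if err := errors.Join(errs...); err != nil {
			return err
		}
		names := make([]string, 0, len(ids))
		for _, id := range ids {
			names = append(names, *id)
		}
		return waitAll(ctx, client, names)
	}
}

func main() {
	del := func(ctx context.Context, client struct{}, id *string) error { return nil }
	wait := func(ctx context.Context, client struct{}, ids []string) error {
		fmt.Println("waiting for:", ids)
		return nil
	}
	nuker := concurrentDeleteThenWaitAll[struct{}](del, wait)
	a, b := "domain-a", "domain-b"
	fmt.Println(nuker(context.Background(), struct{}{}, []*string{&a, &b}))
}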
-func (osd *OpenSearchDomains) deleteAsync(wg *sync.WaitGroup, errChan chan error, domainName *string) { - defer wg.Done() - - input := &opensearch.DeleteDomainInput{DomainName: domainName} - _, err := osd.Client.DeleteDomain(osd.Context, input) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(domainName), - ResourceType: "OpenSearch Domain", - Error: err, - } - report.Record(e) - - errChan <- err } // Custom errors diff --git a/aws/resources/opensearch_test.go b/aws/resources/opensearch_test.go index 641be1fd..614f1cdf 100644 --- a/aws/resources/opensearch_test.go +++ b/aws/resources/opensearch_test.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/opensearch/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/stretchr/testify/require" ) @@ -50,40 +51,38 @@ func TestOpenSearch_GetAll(t *testing.T) { testName1 := "test-domain1" testName2 := "test-domain2" now := time.Now() - osd := OpenSearchDomains{ - Client: mockedOpenSearch{ - ListDomainNamesOutput: opensearch.ListDomainNamesOutput{ - DomainNames: []types.DomainInfo{ - {DomainName: aws.String(testName1)}, - {DomainName: aws.String(testName2)}, - }, + mockClient := mockedOpenSearch{ + ListDomainNamesOutput: opensearch.ListDomainNamesOutput{ + DomainNames: []types.DomainInfo{ + {DomainName: aws.String(testName1)}, + {DomainName: aws.String(testName2)}, }, + }, - ListTagsOutput: opensearch.ListTagsOutput{ - TagList: []types.Tag{ - { - Key: aws.String(firstSeenTagKey), - Value: aws.String(util.FormatTimestamp(now)), - }, - { - Key: aws.String(firstSeenTagKey), - Value: aws.String(util.FormatTimestamp(now.Add(1))), - }, + ListTagsOutput: opensearch.ListTagsOutput{ + TagList: []types.Tag{ + { + Key: aws.String(firstSeenTagKey), + Value: aws.String(util.FormatTimestamp(now)), + }, + { + Key: aws.String(firstSeenTagKey), + Value: aws.String(util.FormatTimestamp(now.Add(1))), }, }, + }, - DescribeDomainsOutput: opensearch.DescribeDomainsOutput{ - DomainStatusList: []types.DomainStatus{ - { - DomainName: aws.String(testName1), - Created: aws.Bool(true), - Deleted: aws.Bool(false), - }, - { - DomainName: aws.String(testName2), - Created: aws.Bool(true), - Deleted: aws.Bool(false), - }, + DescribeDomainsOutput: opensearch.DescribeDomainsOutput{ + DomainStatusList: []types.DomainStatus{ + { + DomainName: aws.String(testName1), + Created: aws.Bool(true), + Deleted: aws.Bool(false), + }, + { + DomainName: aws.String(testName2), + Created: aws.Bool(true), + Deleted: aws.Bool(false), }, }, }, @@ -120,26 +119,36 @@ func TestOpenSearch_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := osd.getAll(tc.ctx, config.Config{ - OpenSearchDomain: tc.configObj, - }) + names, err := listOpenSearchDomains(tc.ctx, mockClient, resource.Scope{Region: "us-east-1"}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } } -func TestOpenSearch_NukeAll(t *testing.T) { +func TestOpenSearch_Nuke(t *testing.T) { + t.Parallel() + + mockClient := mockedOpenSearch{ + DeleteDomainOutput: opensearch.DeleteDomainOutput{}, + DescribeDomainsOutput: opensearch.DescribeDomainsOutput{}, + } + + // Test the individual delete function + err := deleteOpenSearchDomain(context.Background(), mockClient, aws.String("test-domain")) + require.NoError(t, err) +} +func TestOpenSearch_WaitForDeleted(t *testing.T) { t.Parallel() - osd := 
OpenSearchDomains{ - Client: mockedOpenSearch{ - DeleteDomainOutput: opensearch.DeleteDomainOutput{}, - DescribeDomainsOutput: opensearch.DescribeDomainsOutput{}, + // Mock returns empty list, meaning domains are deleted + mockClient := mockedOpenSearch{ + DescribeDomainsOutput: opensearch.DescribeDomainsOutput{ + DomainStatusList: []types.DomainStatus{}, }, } - err := osd.nukeAll([]*string{aws.String("test")}) + err := waitForOpenSearchDomainsDeleted(context.Background(), mockClient, []string{"test-domain"}) require.NoError(t, err) } diff --git a/aws/resources/opensearch_types.go b/aws/resources/opensearch_types.go deleted file mode 100644 index c9090dfe..00000000 --- a/aws/resources/opensearch_types.go +++ /dev/null @@ -1,71 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/opensearch" - - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type OpenSearchDomainsAPI interface { - AddTags(ctx context.Context, params *opensearch.AddTagsInput, optFns ...func(*opensearch.Options)) (*opensearch.AddTagsOutput, error) - DeleteDomain(ctx context.Context, params *opensearch.DeleteDomainInput, optFns ...func(*opensearch.Options)) (*opensearch.DeleteDomainOutput, error) - DescribeDomains(ctx context.Context, params *opensearch.DescribeDomainsInput, optFns ...func(*opensearch.Options)) (*opensearch.DescribeDomainsOutput, error) - ListDomainNames(ctx context.Context, params *opensearch.ListDomainNamesInput, optFns ...func(*opensearch.Options)) (*opensearch.ListDomainNamesOutput, error) - ListTags(ctx context.Context, params *opensearch.ListTagsInput, optFns ...func(*opensearch.Options)) (*opensearch.ListTagsOutput, error) -} - -// OpenSearchDomains represents all OpenSearch domains found in a region -type OpenSearchDomains struct { - BaseAwsResource - Client OpenSearchDomainsAPI - Region string - DomainNames []string -} - -func (osd *OpenSearchDomains) Init(cfg aws.Config) { - osd.Client = opensearch.NewFromConfig(cfg) -} - -// ResourceName is the simple name of the aws resource -func (osd *OpenSearchDomains) ResourceName() string { - return "opensearchdomain" -} - -// ResourceIdentifiers the collected OpenSearch Domains -func (osd *OpenSearchDomains) ResourceIdentifiers() []string { - return osd.DomainNames -} - -// MaxBatchSize returns the number of resources that should be nuked at a time. A small number is used to ensure AWS -// doesn't throttle. OpenSearch Domains do not support bulk delete, so we will be deleting this many in parallel -// using go routines. We conservatively pick 10 here, both to limit overloading the runtime and to avoid AWS throttling -// with many API calls. 
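The deleted MaxBatchSize comment just above captures why these domains are nuked in small batches: OpenSearch has no bulk delete, every identifier becomes its own API call, and a conservative limit of 10 keeps both the runtime and AWS throttling in check. One generic way to enforce that kind of cap, shown purely as an illustration rather than cloud-nuke code, is a buffered-channel semaphore:

package main

import (
	"fmt"
	"sync"
)

// forEachLimited runs fn for every item, but never more than limit at a time,
// using a buffered channel as a counting semaphore.
func forEachLimited(items []string, limit int, fn func(item string)) {
	sem := make(chan struct{}, limit)
	var wg sync.WaitGroup
	for _, item := range items {
		wg.Add(1)
		sem <- struct{}{} // blocks while `limit` items are already in flight
		go func(item string) {
			defer wg.Done()
			defer func() { <-sem }()
			fn(item)
		}(item)
	}
	wg.Wait()
}

func main() {
	domains := []string{"d1", "d2", "d3", "d4", "d5"}
	forEachLimited(domains, 2, func(d string) {
		fmt.Println("deleting", d) // stand-in for a DeleteDomain call
	})
}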
-func (osd *OpenSearchDomains) MaxBatchSize() int { - return 10 -} - -func (osd *OpenSearchDomains) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.OpenSearchDomain -} - -func (osd *OpenSearchDomains) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := osd.getAll(c, configObj) - if err != nil { - return nil, err - } - - osd.DomainNames = aws.ToStringSlice(identifiers) - return osd.DomainNames, nil -} - -// Nuke nukes all OpenSearch domain resources -func (osd *OpenSearchDomains) Nuke(ctx context.Context, identifiers []string) error { - if err := osd.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - return nil -} diff --git a/aws/resources/rds.go b/aws/resources/rds.go index 1ba7b1c3..e7054a2b 100644 --- a/aws/resources/rds.go +++ b/aws/resources/rds.go @@ -2,26 +2,57 @@ package resources import ( "context" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/gruntwork-io/go-commons/errors" ) -func (di *DBInstances) getAll(ctx context.Context, configObj config.Config) ([]*string, error) { - result, err := di.Client.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{}) +// DBInstancesAPI defines the interface for RDS DB Instance operations. +type DBInstancesAPI interface { + ModifyDBInstance(ctx context.Context, params *rds.ModifyDBInstanceInput, optFns ...func(*rds.Options)) (*rds.ModifyDBInstanceOutput, error) + DeleteDBInstance(ctx context.Context, params *rds.DeleteDBInstanceInput, optFns ...func(*rds.Options)) (*rds.DeleteDBInstanceOutput, error) + DescribeDBInstances(ctx context.Context, params *rds.DescribeDBInstancesInput, optFns ...func(*rds.Options)) (*rds.DescribeDBInstancesOutput, error) +} + +// NewDBInstances creates a new DBInstances resource using the generic resource pattern. +func NewDBInstances() AwsResource { + return NewAwsResource(&resource.Resource[DBInstancesAPI]{ + ResourceTypeName: "rds", + BatchSize: 49, + InitClient: func(r *resource.Resource[DBInstancesAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for RDS client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = rds.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.DBInstances.ResourceType + }, + Lister: listDBInstances, + Nuker: deleteDBInstances, + }) +} + +// listDBInstances retrieves all RDS DB instances that match the config filters. 
+func listDBInstances(ctx context.Context, client DBInstancesAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + result, err := client.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{}) if err != nil { return nil, errors.WithStackTrace(err) } var names []*string - for _, database := range result.DBInstances { - if configObj.DBInstances.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Time: database.InstanceCreateTime, Name: database.DBInstanceIdentifier, Tags: util.ConvertRDSTypeTagsToMap(database.TagList), @@ -33,18 +64,19 @@ func (di *DBInstances) getAll(ctx context.Context, configObj config.Config) ([]* return names, nil } -func (di *DBInstances) nukeAll(names []*string) error { +// deleteDBInstances deletes all RDS DB instances. +func deleteDBInstances(ctx context.Context, client DBInstancesAPI, scope resource.Scope, resourceType string, names []*string) error { if len(names) == 0 { - logging.Debugf("No RDS DB Instance to nuke in region %s", di.Region) + logging.Debugf("No RDS DB Instance to nuke in region %s", scope.Region) return nil } - logging.Debugf("Deleting all RDS Instances in region %s", di.Region) + logging.Debugf("Deleting all RDS Instances in region %s", scope.Region) deletedNames := []*string{} for _, name := range names { // Check if instance is part of a cluster before trying to disable deletion protection - describeResp, err := di.Client.DescribeDBInstances(di.Context, &rds.DescribeDBInstancesInput{ + describeResp, err := client.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{ DBInstanceIdentifier: name, }) if err != nil { @@ -53,7 +85,7 @@ func (di *DBInstances) nukeAll(names []*string) error { } // Only disable deletion protection if instance is not part of a cluster if len(describeResp.DBInstances) > 0 && describeResp.DBInstances[0].DBClusterIdentifier == nil { - _, modifyErr := di.Client.ModifyDBInstance(context.TODO(), &rds.ModifyDBInstanceInput{ + _, modifyErr := client.ModifyDBInstance(ctx, &rds.ModifyDBInstanceInput{ DBInstanceIdentifier: name, DeletionProtection: aws.Bool(false), ApplyImmediately: aws.Bool(true), @@ -70,7 +102,7 @@ func (di *DBInstances) nukeAll(names []*string) error { SkipFinalSnapshot: aws.Bool(true), } - _, err = di.Client.DeleteDBInstance(di.Context, params) + _, err = client.DeleteDBInstance(ctx, params) if err != nil { logging.Errorf("[Failed] %s: %s", *name, err) @@ -82,10 +114,10 @@ func (di *DBInstances) nukeAll(names []*string) error { if len(deletedNames) > 0 { for _, name := range deletedNames { - waiter := rds.NewDBInstanceDeletedWaiter(di.Client) - err := waiter.Wait(di.Context, &rds.DescribeDBInstancesInput{ + waiter := rds.NewDBInstanceDeletedWaiter(client) + err := waiter.Wait(ctx, &rds.DescribeDBInstancesInput{ DBInstanceIdentifier: name, - }, di.Timeout) + }, 5*time.Minute) // Record status of this resource e := report.Entry{ @@ -102,6 +134,15 @@ func (di *DBInstances) nukeAll(names []*string) error { } } - logging.Debugf("[OK] %d RDS DB Instance(s) deleted in %s", len(deletedNames), di.Region) + logging.Debugf("[OK] %d RDS DB Instance(s) deleted in %s", len(deletedNames), scope.Region) return nil } + +// RdsDeleteError represents an error when deleting RDS resources. 
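This hunk does not carry an updated test for the RDS instance lister. If one were written in the same style as the Lambda and MSK tests earlier in this change, it could look roughly like the sketch below; the mock struct name and test data are illustrative.

package resources

import (
	"context"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/rds"
	"github.com/aws/aws-sdk-go-v2/service/rds/types"
	"github.com/gruntwork-io/cloud-nuke/config"
	"github.com/gruntwork-io/cloud-nuke/resource"
	"github.com/stretchr/testify/require"
)

// mockRDSClient is an illustrative mock of DBInstancesAPI, mirroring the mock
// style used for the other resources in this change.
type mockRDSClient struct {
	DescribeDBInstancesOutput rds.DescribeDBInstancesOutput
}

func (m *mockRDSClient) DescribeDBInstances(ctx context.Context, params *rds.DescribeDBInstancesInput, optFns ...func(*rds.Options)) (*rds.DescribeDBInstancesOutput, error) {
	return &m.DescribeDBInstancesOutput, nil
}

func (m *mockRDSClient) ModifyDBInstance(ctx context.Context, params *rds.ModifyDBInstanceInput, optFns ...func(*rds.Options)) (*rds.ModifyDBInstanceOutput, error) {
	return &rds.ModifyDBInstanceOutput{}, nil
}

func (m *mockRDSClient) DeleteDBInstance(ctx context.Context, params *rds.DeleteDBInstanceInput, optFns ...func(*rds.Options)) (*rds.DeleteDBInstanceOutput, error) {
	return &rds.DeleteDBInstanceOutput{}, nil
}

func TestListDBInstances_Sketch(t *testing.T) {
	t.Parallel()

	now := time.Now()
	mock := &mockRDSClient{
		DescribeDBInstancesOutput: rds.DescribeDBInstancesOutput{
			DBInstances: []types.DBInstance{
				{DBInstanceIdentifier: aws.String("test-db-1"), InstanceCreateTime: aws.Time(now)},
				{DBInstanceIdentifier: aws.String("test-db-2"), InstanceCreateTime: aws.Time(now)},
			},
		},
	}

	// An empty ResourceType filter should include every instance returned by the mock.
	names, err := listDBInstances(context.Background(), mock, resource.Scope{Region: "us-east-1"}, config.ResourceType{})
	require.NoError(t, err)
	require.Equal(t, []string{"test-db-1", "test-db-2"}, aws.ToStringSlice(names))
}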
+type RdsDeleteError struct { + name string +} + +func (e RdsDeleteError) Error() string { + return "RDS DB Instance:" + e.name + "was not deleted" +} diff --git a/aws/resources/rds_cluster.go b/aws/resources/rds_cluster.go index 023443a1..7c03a151 100644 --- a/aws/resources/rds_cluster.go +++ b/aws/resources/rds_cluster.go @@ -5,30 +5,84 @@ import ( goerr "errors" "time" - "github.com/gruntwork-io/cloud-nuke/util" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" - "github.com/gruntwork-io/go-commons/errors" + "github.com/gruntwork-io/cloud-nuke/resource" + "github.com/gruntwork-io/cloud-nuke/util" ) -func (instance *DBClusters) waitUntilRdsClusterDeleted(input *rds.DescribeDBClustersInput) error { - waitTimeout := instance.Timeout +// DBClustersAPI defines the interface for RDS DB Cluster operations. +type DBClustersAPI interface { + DeleteDBCluster(ctx context.Context, params *rds.DeleteDBClusterInput, optFns ...func(*rds.Options)) (*rds.DeleteDBClusterOutput, error) + DescribeDBClusters(ctx context.Context, params *rds.DescribeDBClustersInput, optFns ...func(*rds.Options)) (*rds.DescribeDBClustersOutput, error) + ModifyDBCluster(ctx context.Context, params *rds.ModifyDBClusterInput, optFns ...func(*rds.Options)) (*rds.ModifyDBClusterOutput, error) +} + +// NewDBClusters creates a new RDS DB Clusters resource using the generic resource pattern. +func NewDBClusters() AwsResource { + return NewAwsResource(&resource.Resource[DBClustersAPI]{ + ResourceTypeName: "rds-cluster", + BatchSize: 49, + InitClient: func(r *resource.Resource[DBClustersAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for RDS client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = rds.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.DBClusters.ResourceType + }, + Lister: listDBClusters, + Nuker: resource.SequentialDeleter(resource.DeleteThenWait(deleteDBCluster, waitUntilRdsClusterDeleted)), + }) +} + +// listDBClusters retrieves all RDS DB Clusters that match the config filters. +func listDBClusters(ctx context.Context, client DBClustersAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + var names []*string + paginator := rds.NewDescribeDBClustersPaginator(client, &rds.DescribeDBClustersInput{}) + + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return nil, err + } + + for _, database := range page.DBClusters { + if cfg.ShouldInclude(config.ResourceValue{ + Name: database.DBClusterIdentifier, + Time: database.ClusterCreateTime, + Tags: util.ConvertRDSTypeTagsToMap(database.TagList), + }) { + names = append(names, database.DBClusterIdentifier) + } + } + } + + return names, nil +} + +// waitUntilRdsClusterDeleted waits until the RDS cluster is deleted. 
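NewDBClusters above builds its Nuker by composing two helpers, resource.SequentialDeleter and resource.DeleteThenWait, instead of the hand-rolled delete loop that the removed nukeAll contained. Their implementations are not shown in this diff; judging only by how they are called here and for Redshift and subnet groups below, they presumably look roughly like this sketch, with report recording and error-handling details assumed:

// Sketch of the assumed combinators; not the actual resource package source.
// (Imports of aws, logging and report are omitted in this sketch.)

// DeleteOne deletes (or waits on) a single identifier.
type DeleteOne[ClientT any] func(ctx context.Context, client ClientT, id *string) error

// DeleteThenWait chains a delete step with a wait-until-gone step.
func DeleteThenWait[ClientT any](del, wait DeleteOne[ClientT]) DeleteOne[ClientT] {
	return func(ctx context.Context, client ClientT, id *string) error {
		if err := del(ctx, client, id); err != nil {
			return err
		}
		return wait(ctx, client, id)
	}
}

// SequentialDeleter lifts a per-identifier delete into a Nuker that walks the
// batch one identifier at a time, recording the outcome of each.
func SequentialDeleter[ClientT any](del DeleteOne[ClientT]) func(ctx context.Context, client ClientT, scope Scope, resourceTypeName string, ids []*string) error {
	return func(ctx context.Context, client ClientT, scope Scope, resourceTypeName string, ids []*string) error {
		for _, id := range ids {
			err := del(ctx, client, id)
			report.Record(report.Entry{
				Identifier:   aws.ToString(id),
				ResourceType: resourceTypeName,
				Error:        err,
			})
			if err != nil {
				logging.Errorf("[Failed] %s: %s", aws.ToString(id), err)
			}
		}
		return nil
	}
}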
+func waitUntilRdsClusterDeleted(ctx context.Context, client DBClustersAPI, clusterIdentifier *string) error { + waitTimeout := DefaultWaitTimeout const retryInterval = 10 * time.Second maxRetries := int(waitTimeout / retryInterval) for i := 0; i < maxRetries; i++ { - _, err := instance.Client.DescribeDBClusters(instance.Context, input) + _, err := client.DescribeDBClusters(ctx, &rds.DescribeDBClustersInput{ + DBClusterIdentifier: clusterIdentifier, + }) if err != nil { var notFoundErr *types.DBClusterNotFoundFault if goerr.As(err, &notFoundErr) { return nil } - return err } @@ -36,85 +90,26 @@ func (instance *DBClusters) waitUntilRdsClusterDeleted(input *rds.DescribeDBClus logging.Debug("Waiting for RDS Cluster to be deleted") } - return RdsDeleteError{name: *input.DBClusterIdentifier} + return RdsDeleteError{name: aws.ToString(clusterIdentifier)} } -func (instance *DBClusters) getAll(c context.Context, configObj config.Config) ([]*string, error) { - result, err := instance.Client.DescribeDBClusters(c, &rds.DescribeDBClustersInput{}) +// deleteDBCluster deletes a single RDS DB Cluster after disabling deletion protection. +func deleteDBCluster(ctx context.Context, client DBClustersAPI, name *string) error { + // Disable deletion protection before attempting to delete the cluster + _, err := client.ModifyDBCluster(ctx, &rds.ModifyDBClusterInput{ + DBClusterIdentifier: name, + DeletionProtection: aws.Bool(false), + ApplyImmediately: aws.Bool(true), + }) if err != nil { - return nil, errors.WithStackTrace(err) - } - - var names []*string - for _, database := range result.DBClusters { - if configObj.DBClusters.ShouldInclude(config.ResourceValue{ - Name: database.DBClusterIdentifier, - Time: database.ClusterCreateTime, - Tags: util.ConvertRDSTypeTagsToMap(database.TagList), - }) { - names = append(names, database.DBClusterIdentifier) - } - } - - return names, nil -} - -func (instance *DBClusters) nukeAll(names []*string) error { - if len(names) == 0 { - logging.Debugf("No RDS DB Cluster to nuke in region %s", instance.Region) - return nil - } - - logging.Debugf("Deleting all RDS Clusters in region %s", instance.Region) - deletedNames := []*string{} - - for _, name := range names { - // Disable deletion protection before attempting to delete the cluster - _, err := instance.Client.ModifyDBCluster(instance.Context, &rds.ModifyDBClusterInput{ - DBClusterIdentifier: name, - DeletionProtection: aws.Bool(false), - ApplyImmediately: aws.Bool(true), - }) - if err != nil { - logging.Warnf("[Failed] to disable deletion protection for cluster %s: %s", *name, err) - } - - params := &rds.DeleteDBClusterInput{ - DBClusterIdentifier: name, - SkipFinalSnapshot: aws.Bool(true), - } - - _, err = instance.Client.DeleteDBCluster(instance.Context, params) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(name), - ResourceType: "RDS Cluster", - Error: err, - } - report.Record(e) - - if err != nil { - logging.Debugf("[Failed] %s: %s", *name, err) - } else { - deletedNames = append(deletedNames, name) - logging.Debugf("Deleted RDS DB Cluster: %s", aws.ToString(name)) - } + logging.Warnf("[Failed] to disable deletion protection for cluster %s: %s", *name, err) } - if len(deletedNames) > 0 { - for _, name := range deletedNames { - - err := instance.waitUntilRdsClusterDeleted(&rds.DescribeDBClustersInput{ - DBClusterIdentifier: name, - }) - if err != nil { - logging.Errorf("[Failed] %s", err) - return errors.WithStackTrace(err) - } - } + params := &rds.DeleteDBClusterInput{ +
DBClusterIdentifier: name, + SkipFinalSnapshot: aws.Bool(true), } - logging.Debugf("[OK] %d RDS DB Cluster(s) nuked in %s", len(deletedNames), instance.Region) - return nil + _, err = client.DeleteDBCluster(ctx, params) + return err } diff --git a/aws/resources/rds_cluster_test.go b/aws/resources/rds_cluster_test.go index 75e0c82e..d8d3b458 100644 --- a/aws/resources/rds_cluster_test.go +++ b/aws/resources/rds_cluster_test.go @@ -10,107 +10,119 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -type mockedDBClusters struct { - DBClustersAPI +type mockDBClustersClient struct { DescribeDBClustersOutput rds.DescribeDBClustersOutput DescribeDBClustersError error DeleteDBClusterOutput rds.DeleteDBClusterOutput ModifyDBClusterOutput rds.ModifyDBClusterOutput } -func (m mockedDBClusters) waitUntilRdsClusterDeleted(*rds.DescribeDBClustersInput) error { - return nil -} - -func (m mockedDBClusters) DeleteDBCluster(ctx context.Context, params *rds.DeleteDBClusterInput, optFns ...func(*rds.Options)) (*rds.DeleteDBClusterOutput, error) { +func (m *mockDBClustersClient) DeleteDBCluster(ctx context.Context, params *rds.DeleteDBClusterInput, optFns ...func(*rds.Options)) (*rds.DeleteDBClusterOutput, error) { return &m.DeleteDBClusterOutput, nil } -func (m mockedDBClusters) DescribeDBClusters(ctx context.Context, params *rds.DescribeDBClustersInput, optFns ...func(*rds.Options)) (*rds.DescribeDBClustersOutput, error) { +func (m *mockDBClustersClient) DescribeDBClusters(ctx context.Context, params *rds.DescribeDBClustersInput, optFns ...func(*rds.Options)) (*rds.DescribeDBClustersOutput, error) { return &m.DescribeDBClustersOutput, m.DescribeDBClustersError } -func (m mockedDBClusters) ModifyDBCluster(ctx context.Context, params *rds.ModifyDBClusterInput, optFns ...func(*rds.Options)) (*rds.ModifyDBClusterOutput, error) { +func (m *mockDBClustersClient) ModifyDBCluster(ctx context.Context, params *rds.ModifyDBClusterInput, optFns ...func(*rds.Options)) (*rds.ModifyDBClusterOutput, error) { return &m.ModifyDBClusterOutput, nil } -func TestRDSClusterGetAll(t *testing.T) { +func TestDBClusters_ResourceName(t *testing.T) { + r := NewDBClusters() + assert.Equal(t, "rds-cluster", r.ResourceName()) +} + +func TestDBClusters_MaxBatchSize(t *testing.T) { + r := NewDBClusters() + assert.Equal(t, 49, r.MaxBatchSize()) +} + +func TestListDBClusters(t *testing.T) { t.Parallel() - // Test data setup testName := "test-db-cluster" testProtectedName := "test-protected-cluster" now := time.Now() - dbCluster := DBClusters{ - Client: mockedDBClusters{ - DescribeDBClustersOutput: rds.DescribeDBClustersOutput{ - DBClusters: []types.DBCluster{ - { - DBClusterIdentifier: &testName, - ClusterCreateTime: &now, - DeletionProtection: aws.Bool(false), - }, - { - DBClusterIdentifier: &testProtectedName, - ClusterCreateTime: &now, - DeletionProtection: aws.Bool(true), - }, + + mock := &mockDBClustersClient{ + DescribeDBClustersOutput: rds.DescribeDBClustersOutput{ + DBClusters: []types.DBCluster{ + { + DBClusterIdentifier: &testName, + ClusterCreateTime: &now, + DeletionProtection: aws.Bool(false), + }, + { + DBClusterIdentifier: &testProtectedName, + ClusterCreateTime: &now, + DeletionProtection: aws.Bool(true), }, }, }, } // Test Case 1: Empty config - should include both protected and unprotected clusters - // 
Deletion protection is automatically disabled during deletion, so we include these clusters - clusters, err := dbCluster.getAll(context.Background(), config.Config{DBClusters: config.AWSProtectectableResourceType{}}) - assert.NoError(t, err) + clusters, err := listDBClusters(context.Background(), mock, resource.Scope{}, config.ResourceType{}) + require.NoError(t, err) assert.Contains(t, aws.ToStringSlice(clusters), strings.ToLower(testName)) assert.Contains(t, aws.ToStringSlice(clusters), strings.ToLower(testProtectedName)) +} - // Test Case 2: With IncludeDeletionProtected flag - behavior is now the same as Test Case 1 - // since deletion-protected clusters are always included - clusters, err = dbCluster.getAll(context.Background(), config.Config{ - DBClusters: config.AWSProtectectableResourceType{ - IncludeDeletionProtected: true, - }, - }) - assert.NoError(t, err) - assert.Contains(t, aws.ToStringSlice(clusters), strings.ToLower(testName)) - assert.Contains(t, aws.ToStringSlice(clusters), strings.ToLower(testProtectedName)) +func TestListDBClusters_WithTimeFilter(t *testing.T) { + t.Parallel() + + testName := "test-db-cluster" + testProtectedName := "test-protected-cluster" + now := time.Now() - // Test Case 3: Time-based exclusion - should exclude all clusters created after specified time - clusters, err = dbCluster.getAll(context.Background(), config.Config{ - DBClusters: config.AWSProtectectableResourceType{ - ResourceType: config.ResourceType{ - ExcludeRule: config.FilterRule{ - TimeAfter: aws.Time(now.Add(-1)), + mock := &mockDBClustersClient{ + DescribeDBClustersOutput: rds.DescribeDBClustersOutput{ + DBClusters: []types.DBCluster{ + { + DBClusterIdentifier: &testName, + ClusterCreateTime: &now, + DeletionProtection: aws.Bool(false), + }, + { + DBClusterIdentifier: &testProtectedName, + ClusterCreateTime: &now, + DeletionProtection: aws.Bool(true), }, }, }, - }) - assert.NoError(t, err) + } + + // Time-based exclusion - should exclude all clusters created after specified time + cfg := config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now.Add(-1)), + }, + } + + clusters, err := listDBClusters(context.Background(), mock, resource.Scope{}, cfg) + require.NoError(t, err) assert.NotContains(t, aws.ToStringSlice(clusters), strings.ToLower(testName)) assert.NotContains(t, aws.ToStringSlice(clusters), strings.ToLower(testProtectedName)) } -func TestRDSClusterNukeAll(t *testing.T) { +func TestDeleteDBCluster(t *testing.T) { t.Parallel() testName := "test-db-cluster" - dbCluster := DBClusters{ - Client: mockedDBClusters{ - DescribeDBClustersOutput: rds.DescribeDBClustersOutput{}, - DescribeDBClustersError: &types.DBClusterNotFoundFault{}, - ModifyDBClusterOutput: rds.ModifyDBClusterOutput{}, - DeleteDBClusterOutput: rds.DeleteDBClusterOutput{}, - }, + mock := &mockDBClustersClient{ + DescribeDBClustersOutput: rds.DescribeDBClustersOutput{}, + DescribeDBClustersError: &types.DBClusterNotFoundFault{}, + ModifyDBClusterOutput: rds.ModifyDBClusterOutput{}, + DeleteDBClusterOutput: rds.DeleteDBClusterOutput{}, } - dbCluster.Context = context.Background() - dbCluster.Timeout = DefaultWaitTimeout - err := dbCluster.nukeAll([]*string{&testName}) - assert.NoError(t, err) + err := deleteDBCluster(context.Background(), mock, aws.String(testName)) + require.NoError(t, err) } diff --git a/aws/resources/rds_cluster_types.go b/aws/resources/rds_cluster_types.go deleted file mode 100644 index 9caff6db..00000000 --- a/aws/resources/rds_cluster_types.go +++ /dev/null @@ -1,63 +0,0 
@@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/rds" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type DBClustersAPI interface { - DeleteDBCluster(ctx context.Context, params *rds.DeleteDBClusterInput, optFns ...func(*rds.Options)) (*rds.DeleteDBClusterOutput, error) - DescribeDBClusters(ctx context.Context, params *rds.DescribeDBClustersInput, optFns ...func(*rds.Options)) (*rds.DescribeDBClustersOutput, error) - ModifyDBCluster(ctx context.Context, params *rds.ModifyDBClusterInput, optFns ...func(*rds.Options)) (*rds.ModifyDBClusterOutput, error) -} -type DBClusters struct { - BaseAwsResource - Client DBClustersAPI - Region string - InstanceNames []string -} - -func (instance *DBClusters) Init(cfg aws.Config) { - instance.Client = rds.NewFromConfig(cfg) -} - -func (instance *DBClusters) ResourceName() string { - return "rds-cluster" -} - -// ResourceIdentifiers - The instance names of the rds db instances -func (instance *DBClusters) ResourceIdentifiers() []string { - return instance.InstanceNames -} - -func (instance *DBClusters) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (instance *DBClusters) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.DBClusters.ResourceType -} - -func (instance *DBClusters) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := instance.getAll(c, configObj) - if err != nil { - return nil, err - } - - instance.InstanceNames = aws.ToStringSlice(identifiers) - return instance.InstanceNames, nil -} - -// Nuke - nuke 'em all!!! -func (instance *DBClusters) Nuke(ctx context.Context, identifiers []string) error { - if err := instance.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/rds_parameter_group.go b/aws/resources/rds_parameter_group.go index de22a6fa..0863cc33 100644 --- a/aws/resources/rds_parameter_group.go +++ b/aws/resources/rds_parameter_group.go @@ -8,21 +8,50 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" - "github.com/gruntwork-io/go-commons/errors" + "github.com/gruntwork-io/cloud-nuke/resource" ) -func (pg *RdsParameterGroup) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// RdsParameterGroupAPI defines the interface for RDS Parameter Group operations. +type RdsParameterGroupAPI interface { + DeleteDBParameterGroup(ctx context.Context, params *rds.DeleteDBParameterGroupInput, optFns ...func(*rds.Options)) (*rds.DeleteDBParameterGroupOutput, error) + DescribeDBParameterGroups(ctx context.Context, params *rds.DescribeDBParameterGroupsInput, optFns ...func(*rds.Options)) (*rds.DescribeDBParameterGroupsOutput, error) +} + +// NewRdsParameterGroup creates a new RDS Parameter Group resource using the generic resource pattern. 
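Each *_types.go file deleted in this patch implemented the same boilerplate by hand: Init, ResourceName, ResourceIdentifiers, MaxBatchSize, GetAndSetResourceConfig, GetAndSetIdentifiers and Nuke. NewAwsResource, which every new constructor calls, is not shown here; presumably it supplies those methods once on top of resource.Resource. A rough sketch of what such an adapter has to do, inferred from the removed methods (all names other than the interface methods are invented for illustration):

// genericAwsResource is a hypothetical adapter; the real NewAwsResource may differ.
type genericAwsResource[ClientT any] struct {
	r           *resource.Resource[ClientT]
	identifiers []string
}

func NewAwsResource[ClientT any](r *resource.Resource[ClientT]) AwsResource {
	return &genericAwsResource[ClientT]{r: r}
}

func (g *genericAwsResource[ClientT]) Init(cfg aws.Config)           { g.r.InitClient(g.r, cfg) }
func (g *genericAwsResource[ClientT]) ResourceName() string          { return g.r.ResourceTypeName }
func (g *genericAwsResource[ClientT]) MaxBatchSize() int             { return g.r.BatchSize }
func (g *genericAwsResource[ClientT]) ResourceIdentifiers() []string { return g.identifiers }

func (g *genericAwsResource[ClientT]) GetAndSetResourceConfig(c config.Config) config.ResourceType {
	return g.r.ConfigGetter(c)
}

func (g *genericAwsResource[ClientT]) GetAndSetIdentifiers(ctx context.Context, c config.Config) ([]string, error) {
	ids, err := g.r.Lister(ctx, g.r.Client, g.r.Scope, g.r.ConfigGetter(c))
	if err != nil {
		return nil, err
	}
	g.identifiers = aws.ToStringSlice(ids)
	return g.identifiers, nil
}

func (g *genericAwsResource[ClientT]) Nuke(ctx context.Context, identifiers []string) error {
	return g.r.Nuker(ctx, g.r.Client, g.r.Scope, g.r.ResourceTypeName, aws.StringSlice(identifiers))
}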
+func NewRdsParameterGroup() AwsResource { + return NewAwsResource(&resource.Resource[RdsParameterGroupAPI]{ + ResourceTypeName: "rds-parameter-group", + // Tentative batch size to ensure AWS doesn't throttle + BatchSize: 49, + InitClient: func(r *resource.Resource[RdsParameterGroupAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for RDS client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = rds.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.RdsParameterGroup + }, + Lister: listRdsParameterGroups, + Nuker: resource.SimpleBatchDeleter(deleteRdsParameterGroup), + }) +} + +// listRdsParameterGroups retrieves all RDS parameter groups that match the config filters. +func listRdsParameterGroups(ctx context.Context, client RdsParameterGroupAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { var names []*string // Initialize the paginator - paginator := rds.NewDescribeDBParameterGroupsPaginator(pg.Client, &rds.DescribeDBParameterGroupsInput{}) + paginator := rds.NewDescribeDBParameterGroupsPaginator(client, &rds.DescribeDBParameterGroupsInput{}) // Iterate through the pages for paginator.HasMorePages() { - page, err := paginator.NextPage(c) + page, err := paginator.NextPage(ctx) if err != nil { - return nil, errors.WithStackTrace(err) + return nil, err } // Process each parameter group on the page @@ -34,7 +63,7 @@ func (pg *RdsParameterGroup) getAll(c context.Context, configObj config.Config) continue } - if configObj.RdsParameterGroup.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Name: parameterGroup.DBParameterGroupName, }) { names = append(names, parameterGroup.DBParameterGroupName) @@ -45,39 +74,10 @@ func (pg *RdsParameterGroup) getAll(c context.Context, configObj config.Config) return names, nil } -func (pg *RdsParameterGroup) nukeAll(identifiers []*string) error { - if len(identifiers) == 0 { - logging.Debugf("No DB parameter groups in region %s", pg.Region) - return nil - } - - logging.Debugf("Deleting all DB parameter groups in region %s", pg.Region) - var deleted []*string - - for _, identifier := range identifiers { - logging.Debugf("[RDS Parameter Group] Deleting %s in region %s", *identifier, pg.Region) - - _, err := pg.Client.DeleteDBParameterGroup( - pg.Context, - &rds.DeleteDBParameterGroupInput{ - DBParameterGroupName: identifier, - }) - if err != nil { - logging.Errorf("[RDS Parameter Group] Error deleting RDS Parameter Group %s: %s", *identifier, err) - } else { - deleted = append(deleted, identifier) - logging.Debugf("[RDS Parameter Group] Deleted RDS Parameter Group %s", *identifier) - } - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(identifier), - ResourceType: pg.ResourceName(), - Error: err, - } - report.Record(e) - } - - logging.Debugf("[OK] %d RDS DB parameter group(s) nuked in %s", len(deleted), pg.Region) - return nil +// deleteRdsParameterGroup deletes a single RDS parameter group. 
+func deleteRdsParameterGroup(ctx context.Context, client RdsParameterGroupAPI, identifier *string) error { + _, err := client.DeleteDBParameterGroup(ctx, &rds.DeleteDBParameterGroupInput{ + DBParameterGroupName: identifier, + }) + return err } diff --git a/aws/resources/rds_parameter_group_test.go b/aws/resources/rds_parameter_group_test.go index a1dbea1f..f0a80c05 100644 --- a/aws/resources/rds_parameter_group_test.go +++ b/aws/resources/rds_parameter_group_test.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -27,22 +28,30 @@ func (m mockedRdsDBParameterGroup) DeleteDBParameterGroup(ctx context.Context, p return &m.DeleteDBParameterGroupOutput, nil } -func TestRDSparameterGroupGetAll(t *testing.T) { +func TestRdsParameterGroup_ResourceName(t *testing.T) { + r := NewRdsParameterGroup() + require.Equal(t, "rds-parameter-group", r.ResourceName()) +} + +func TestRdsParameterGroup_MaxBatchSize(t *testing.T) { + r := NewRdsParameterGroup() + require.Equal(t, 49, r.MaxBatchSize()) +} + +func TestListRdsParameterGroups(t *testing.T) { t.Parallel() testName01 := "test-db-paramater-group-01" testName02 := "test-db-paramater-group-02" - pg := RdsParameterGroup{ - Client: mockedRdsDBParameterGroup{ - DescribeDBParameterGroupsOutput: rds.DescribeDBParameterGroupsOutput{ - DBParameterGroups: []types.DBParameterGroup{ - { - DBParameterGroupName: aws.String(testName01), - }, - { - DBParameterGroupName: aws.String(testName02), - }, + mock := mockedRdsDBParameterGroup{ + DescribeDBParameterGroupsOutput: rds.DescribeDBParameterGroupsOutput{ + DBParameterGroups: []types.DBParameterGroup{ + { + DBParameterGroupName: aws.String(testName01), + }, + { + DBParameterGroupName: aws.String(testName02), }, }, }, @@ -68,9 +77,7 @@ func TestRDSparameterGroupGetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := pg.getAll(context.Background(), config.Config{ - RdsParameterGroup: tc.configObj, - }) + names, err := listRdsParameterGroups(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, len(tc.expected), len(names)) @@ -81,15 +88,13 @@ func TestRDSparameterGroupGetAll(t *testing.T) { } } -func TestRDSParameterGroupNukeAll(t *testing.T) { +func TestDeleteRdsParameterGroup(t *testing.T) { t.Parallel() testName := "test-db-parameter-group" - dbCluster := RdsParameterGroup{ - Client: mockedRdsDBParameterGroup{ - DeleteDBParameterGroupOutput: rds.DeleteDBParameterGroupOutput{}, - }, + mock := mockedRdsDBParameterGroup{ + DeleteDBParameterGroupOutput: rds.DeleteDBParameterGroupOutput{}, } - err := dbCluster.nukeAll([]*string{&testName}) + err := deleteRdsParameterGroup(context.Background(), mock, &testName) assert.NoError(t, err) } diff --git a/aws/resources/rds_parameter_group_types.go b/aws/resources/rds_parameter_group_types.go deleted file mode 100644 index 215fcb4d..00000000 --- a/aws/resources/rds_parameter_group_types.go +++ /dev/null @@ -1,62 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/rds" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type RdsParameterGroupAPI interface { - DeleteDBParameterGroup(ctx context.Context, params 
*rds.DeleteDBParameterGroupInput, optFns ...func(*rds.Options)) (*rds.DeleteDBParameterGroupOutput, error) - DescribeDBParameterGroups(ctx context.Context, params *rds.DescribeDBParameterGroupsInput, optFns ...func(*rds.Options)) (*rds.DescribeDBParameterGroupsOutput, error) -} -type RdsParameterGroup struct { - BaseAwsResource - Client RdsParameterGroupAPI - Region string - GroupNames []string -} - -func (pg *RdsParameterGroup) Init(cfg aws.Config) { - pg.Client = rds.NewFromConfig(cfg) -} - -func (pg *RdsParameterGroup) ResourceName() string { - return "rds-parameter-group" -} - -// ResourceIdentifiers - The names of the rds parameter group -func (pg *RdsParameterGroup) ResourceIdentifiers() []string { - return pg.GroupNames -} - -func (pg *RdsParameterGroup) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (pg *RdsParameterGroup) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.RdsParameterGroup -} - -func (pg *RdsParameterGroup) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := pg.getAll(c, configObj) - if err != nil { - return nil, err - } - - pg.GroupNames = aws.ToStringSlice(identifiers) - return pg.GroupNames, nil -} - -// Nuke - nuke 'em all!!! -func (pg *RdsParameterGroup) Nuke(ctx context.Context, identifiers []string) error { - if err := pg.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/rds_subnet_group.go b/aws/resources/rds_subnet_group.go index b8577d61..76dbd020 100644 --- a/aws/resources/rds_subnet_group.go +++ b/aws/resources/rds_subnet_group.go @@ -6,48 +6,57 @@ import ( "fmt" "time" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/rds/types" + "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/go-commons/errors" ) -func (dsg *DBSubnetGroups) waitUntilRdsDbSubnetGroupDeleted(name *string) error { - // wait up to 15 minutes - for i := 0; i < 90; i++ { - _, err := dsg.Client.DescribeDBSubnetGroups( - dsg.Context, &rds.DescribeDBSubnetGroupsInput{DBSubnetGroupName: name}) - if err != nil { - var notFoundErr *types.DBSubnetGroupNotFoundFault - if goerr.As(err, ¬FoundErr) { - return nil - } - return err - } - - time.Sleep(10 * time.Second) - logging.Debug("Waiting for RDS Cluster to be deleted") - } +// DBSubnetGroupsAPI defines the interface for RDS DB Subnet Group operations. +type DBSubnetGroupsAPI interface { + DescribeDBSubnetGroups(ctx context.Context, params *rds.DescribeDBSubnetGroupsInput, optFns ...func(*rds.Options)) (*rds.DescribeDBSubnetGroupsOutput, error) + DeleteDBSubnetGroup(ctx context.Context, params *rds.DeleteDBSubnetGroupInput, optFns ...func(*rds.Options)) (*rds.DeleteDBSubnetGroupOutput, error) + ListTagsForResource(ctx context.Context, params *rds.ListTagsForResourceInput, optFns ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error) +} - return RdsDeleteError{name: *name} +// NewDBSubnetGroups creates a new DBSubnetGroups resource using the generic resource pattern. 
+func NewDBSubnetGroups() AwsResource { + return NewAwsResource(&resource.Resource[DBSubnetGroupsAPI]{ + ResourceTypeName: "rds-subnet-group", + BatchSize: 49, + InitClient: func(r *resource.Resource[DBSubnetGroupsAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for RDS client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = rds.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.DBSubnetGroups + }, + Lister: listDBSubnetGroups, + Nuker: resource.SequentialDeleter(deleteDBSubnetGroup), + }) } -func (dsg *DBSubnetGroups) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// listDBSubnetGroups retrieves all RDS DB Subnet Groups that match the config filters. +func listDBSubnetGroups(ctx context.Context, client DBSubnetGroupsAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { var names []*string - paginator := rds.NewDescribeDBSubnetGroupsPaginator(dsg.Client, &rds.DescribeDBSubnetGroupsInput{}) + paginator := rds.NewDescribeDBSubnetGroupsPaginator(client, &rds.DescribeDBSubnetGroupsInput{}) for paginator.HasMorePages() { - page, err := paginator.NextPage(c) + page, err := paginator.NextPage(ctx) if err != nil { return nil, errors.WithStackTrace(err) } for _, subnetGroup := range page.DBSubnetGroups { - tagsRes, err := dsg.Client.ListTagsForResource(c, &rds.ListTagsForResourceInput{ + tagsRes, err := client.ListTagsForResource(ctx, &rds.ListTagsForResourceInput{ ResourceName: subnetGroup.DBSubnetGroupArn, }) if err != nil { @@ -61,7 +70,7 @@ func (dsg *DBSubnetGroups) getAll(c context.Context, configObj config.Config) ([ for _, v := range tagsRes.TagList { rv.Tags[*v.Key] = *v.Value } - if configObj.DBSubnetGroups.ShouldInclude(rv) { + if cfg.ShouldInclude(rv) { names = append(names, subnetGroup.DBSubnetGroupName) } } @@ -70,47 +79,34 @@ func (dsg *DBSubnetGroups) getAll(c context.Context, configObj config.Config) ([ return names, nil } -func (dsg *DBSubnetGroups) nukeAll(names []*string) error { - if len(names) == 0 { - logging.Debugf("No DB Subnet groups in region %s", dsg.Region) - return nil +// deleteDBSubnetGroup deletes a single RDS DB Subnet Group. +func deleteDBSubnetGroup(ctx context.Context, client DBSubnetGroupsAPI, name *string) error { + _, err := client.DeleteDBSubnetGroup(ctx, &rds.DeleteDBSubnetGroupInput{ + DBSubnetGroupName: name, + }) + if err != nil { + return err } - logging.Debugf("Deleting all DB Subnet groups in region %s", dsg.Region) - deletedNames := []*string{} - - for _, name := range names { - _, err := dsg.Client.DeleteDBSubnetGroup(dsg.Context, &rds.DeleteDBSubnetGroupInput{ - DBSubnetGroupName: name, - }) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(name), - ResourceType: "RDS DB Subnet Group", - Error: err, - } - report.Record(e) + return waitUntilRdsDbSubnetGroupDeleted(ctx, client, name) +} +// waitUntilRdsDbSubnetGroupDeleted waits for an RDS DB Subnet Group to be deleted. 
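One stylistic difference worth noting: for clusters and Redshift the wait step is composed with resource.DeleteThenWait, while deleteDBSubnetGroup above calls the waiter inline. If the combinator behaves as assumed, the equivalent composed wiring would look like the sketch below (the *Only function name is invented for illustration and is not part of this patch):

// Hypothetical split of the subnet-group delete into delete + wait steps.
func deleteDBSubnetGroupOnly(ctx context.Context, client DBSubnetGroupsAPI, name *string) error {
	_, err := client.DeleteDBSubnetGroup(ctx, &rds.DeleteDBSubnetGroupInput{
		DBSubnetGroupName: name,
	})
	return err
}

// ...which would let NewDBSubnetGroups mirror the cluster wiring:
//
//	Nuker: resource.SequentialDeleter(
//		resource.DeleteThenWait(deleteDBSubnetGroupOnly, waitUntilRdsDbSubnetGroupDeleted),
//	),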
+func waitUntilRdsDbSubnetGroupDeleted(ctx context.Context, client DBSubnetGroupsAPI, name *string) error { + // wait up to 15 minutes + for i := 0; i < 90; i++ { + _, err := client.DescribeDBSubnetGroups(ctx, &rds.DescribeDBSubnetGroupsInput{DBSubnetGroupName: name}) if err != nil { - logging.Debugf("[Failed] %s: %s", *name, err) - } else { - deletedNames = append(deletedNames, name) - logging.Debugf("Deleted RDS DB subnet group: %s", aws.ToString(name)) - } - } - - if len(deletedNames) > 0 { - for _, name := range deletedNames { - - err := dsg.waitUntilRdsDbSubnetGroupDeleted(name) - if err != nil { - logging.Errorf("[Failed] %s", err) - return errors.WithStackTrace(err) + var notFoundErr *types.DBSubnetGroupNotFoundFault + if goerr.As(err, &notFoundErr) { + return nil } + return err } + + time.Sleep(10 * time.Second) + logging.Debug("Waiting for RDS DB Subnet Group to be deleted") } - logging.Debugf("[OK] %d RDS DB subnet group(s) nuked in %s", len(deletedNames), dsg.Region) - return nil + return RdsDeleteError{name: *name} } diff --git a/aws/resources/rds_subnet_group_test.go b/aws/resources/rds_subnet_group_test.go index 35fe420c..e26c4a56 100644 --- a/aws/resources/rds_subnet_group_test.go +++ b/aws/resources/rds_subnet_group_test.go @@ -9,24 +9,25 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) -type mockedDBSubnetGroups struct { - DBSubnetGroupsAPI +type mockDBSubnetGroupsClient struct { DescribeDBSubnetGroupsOutput rds.DescribeDBSubnetGroupsOutput DescribeDBSubnetGroupError error DeleteDBSubnetGroupOutput rds.DeleteDBSubnetGroupOutput } -func (m mockedDBSubnetGroups) DescribeDBSubnetGroups(ctx context.Context, params *rds.DescribeDBSubnetGroupsInput, optFns ...func(*rds.Options)) (*rds.DescribeDBSubnetGroupsOutput, error) { +func (m *mockDBSubnetGroupsClient) DescribeDBSubnetGroups(ctx context.Context, params *rds.DescribeDBSubnetGroupsInput, optFns ...func(*rds.Options)) (*rds.DescribeDBSubnetGroupsOutput, error) { return &m.DescribeDBSubnetGroupsOutput, m.DescribeDBSubnetGroupError } -func (m mockedDBSubnetGroups) DeleteDBSubnetGroup(ctx context.Context, params *rds.DeleteDBSubnetGroupInput, optFns ...func(*rds.Options)) (*rds.DeleteDBSubnetGroupOutput, error) { +func (m *mockDBSubnetGroupsClient) DeleteDBSubnetGroup(ctx context.Context, params *rds.DeleteDBSubnetGroupInput, optFns ...func(*rds.Options)) (*rds.DeleteDBSubnetGroupOutput, error) { return &m.DeleteDBSubnetGroupOutput, nil } -func (m mockedDBSubnetGroups) ListTagsForResource(ctx context.Context, params *rds.ListTagsForResourceInput, optFns ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error) { + +func (m *mockDBSubnetGroupsClient) ListTagsForResource(ctx context.Context, params *rds.ListTagsForResourceInput, optFns ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error) { return &rds.ListTagsForResourceOutput{ TagList: []types.Tag{ { @@ -39,24 +40,22 @@ func (m mockedDBSubnetGroups) ListTagsForResource(ctx context.Context, params *r var dbSubnetGroupNotFoundError = &types.DBSubnetGroupNotFoundFault{} -func TestDBSubnetGroups_GetAll(t *testing.T) { - t.Parallel() testName1 := "test-db-subnet-group1" testName2 := "test-db-subnet-group2" - dsg := DBSubnetGroups{ - Client: mockedDBSubnetGroups{ - DescribeDBSubnetGroupsOutput: rds.DescribeDBSubnetGroupsOutput{ - DBSubnetGroups:
[]types.DBSubnetGroup{ - { - DBSubnetGroupName: aws.String(testName1), - DBSubnetGroupArn: aws.String("arn:" + testName1), - }, - { - DBSubnetGroupName: aws.String(testName2), - DBSubnetGroupArn: aws.String("arn:" + testName2), - }, + + mock := &mockDBSubnetGroupsClient{ + DescribeDBSubnetGroupsOutput: rds.DescribeDBSubnetGroupsOutput{ + DBSubnetGroups: []types.DBSubnetGroup{ + { + DBSubnetGroupName: aws.String(testName1), + DBSubnetGroupArn: aws.String("arn:" + testName1), + }, + { + DBSubnetGroupName: aws.String(testName2), + DBSubnetGroupArn: aws.String("arn:" + testName2), }, }, }, @@ -93,27 +92,21 @@ func TestDBSubnetGroups_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := dsg.getAll(context.Background(), config.Config{ - DBSubnetGroups: tc.configObj, - }) + names, err := listDBSubnetGroups(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } - } -func TestDBSubnetGroups_NukeAll(t *testing.T) { - +func TestDeleteDBSubnetGroup(t *testing.T) { t.Parallel() - dsg := DBSubnetGroups{ - Client: mockedDBSubnetGroups{ - DeleteDBSubnetGroupOutput: rds.DeleteDBSubnetGroupOutput{}, - DescribeDBSubnetGroupError: dbSubnetGroupNotFoundError, - }, + mock := &mockDBSubnetGroupsClient{ + DeleteDBSubnetGroupOutput: rds.DeleteDBSubnetGroupOutput{}, + DescribeDBSubnetGroupError: dbSubnetGroupNotFoundError, } - err := dsg.nukeAll([]*string{aws.String("test")}) + err := deleteDBSubnetGroup(context.Background(), mock, aws.String("test")) require.NoError(t, err) } diff --git a/aws/resources/rds_subnet_group_types.go b/aws/resources/rds_subnet_group_types.go deleted file mode 100644 index a23a5281..00000000 --- a/aws/resources/rds_subnet_group_types.go +++ /dev/null @@ -1,64 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/rds" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type DBSubnetGroupsAPI interface { - DescribeDBSubnetGroups(ctx context.Context, params *rds.DescribeDBSubnetGroupsInput, optFns ...func(*rds.Options)) (*rds.DescribeDBSubnetGroupsOutput, error) - DeleteDBSubnetGroup(ctx context.Context, params *rds.DeleteDBSubnetGroupInput, optFns ...func(*rds.Options)) (*rds.DeleteDBSubnetGroupOutput, error) - ListTagsForResource(ctx context.Context, params *rds.ListTagsForResourceInput, optFns ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error) -} - -type DBSubnetGroups struct { - BaseAwsResource - Client DBSubnetGroupsAPI - Region string - InstanceNames []string -} - -func (dsg *DBSubnetGroups) Init(cfg aws.Config) { - dsg.Client = rds.NewFromConfig(cfg) -} - -func (dsg *DBSubnetGroups) ResourceName() string { - return "rds-subnet-group" -} - -// ResourceIdentifiers - The instance names of the rds db instances -func (dsg *DBSubnetGroups) ResourceIdentifiers() []string { - return dsg.InstanceNames -} - -func (dsg *DBSubnetGroups) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (dsg *DBSubnetGroups) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.DBSubnetGroups -} - -func (dsg *DBSubnetGroups) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := dsg.getAll(c, configObj) - if err != nil { - return nil, err - } - - dsg.InstanceNames = aws.ToStringSlice(identifiers) - return 
dsg.InstanceNames, nil -} - -// Nuke - nuke 'em all!!! -func (dsg *DBSubnetGroups) Nuke(ctx context.Context, identifiers []string) error { - if err := dsg.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/rds_test.go b/aws/resources/rds_test.go index 820fe68b..58e29d1d 100644 --- a/aws/resources/rds_test.go +++ b/aws/resources/rds_test.go @@ -10,21 +10,21 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type mockedDBInstance struct { +type mockDBInstanceClient struct { t *testing.T - RDSAPI DescribeDBInstancesOutput rds.DescribeDBInstancesOutput DeleteDBInstanceOutput rds.DeleteDBInstanceOutput ModifyCallExpected bool InstancesDeleted map[string]bool } -func (m mockedDBInstance) ModifyDBInstance(ctx context.Context, params *rds.ModifyDBInstanceInput, optFns ...func(*rds.Options)) (*rds.ModifyDBInstanceOutput, error) { +func (m *mockDBInstanceClient) ModifyDBInstance(ctx context.Context, params *rds.ModifyDBInstanceInput, optFns ...func(*rds.Options)) (*rds.ModifyDBInstanceOutput, error) { if !m.ModifyCallExpected { assert.Fail(m.t, "ModifyDBInstance should not be called for cluster member instances") } @@ -35,7 +35,7 @@ func (m mockedDBInstance) ModifyDBInstance(ctx context.Context, params *rds.Modi return nil, nil } -func (m mockedDBInstance) DescribeDBInstances(ctx context.Context, params *rds.DescribeDBInstancesInput, optFns ...func(*rds.Options)) (*rds.DescribeDBInstancesOutput, error) { +func (m *mockDBInstanceClient) DescribeDBInstances(ctx context.Context, params *rds.DescribeDBInstancesInput, optFns ...func(*rds.Options)) (*rds.DescribeDBInstancesOutput, error) { // If specific instance is requested and it's been deleted, return empty result if params.DBInstanceIdentifier != nil { if m.InstancesDeleted != nil && m.InstancesDeleted[*params.DBInstanceIdentifier] { @@ -45,7 +45,7 @@ func (m mockedDBInstance) DescribeDBInstances(ctx context.Context, params *rds.D return &m.DescribeDBInstancesOutput, nil } -func (m mockedDBInstance) DeleteDBInstance(ctx context.Context, params *rds.DeleteDBInstanceInput, optFns ...func(*rds.Options)) (*rds.DeleteDBInstanceOutput, error) { +func (m *mockDBInstanceClient) DeleteDBInstance(ctx context.Context, params *rds.DeleteDBInstanceInput, optFns ...func(*rds.Options)) (*rds.DeleteDBInstanceOutput, error) { // Mark instance as deleted for waiter if m.InstancesDeleted != nil && params.DBInstanceIdentifier != nil { m.InstancesDeleted[*params.DBInstanceIdentifier] = true @@ -53,141 +53,114 @@ func (m mockedDBInstance) DeleteDBInstance(ctx context.Context, params *rds.Dele return &m.DeleteDBInstanceOutput, nil } -func (m mockedDBInstance) WaitForOutput(ctx context.Context, params *rds.DescribeDBInstancesInput, maxWaitDur time.Duration, optFns ...func(*rds.Options)) (*rds.DescribeDBInstancesOutput, error) { - return nil, nil -} - -func TestDBInstances_GetAll(t *testing.T) { - +func TestListDBInstances(t *testing.T) { t.Parallel() - testName1 := "test-db-instance1" - testName2 := "test-db-instance2" - testName3 := "test-db-instance3" testIdentifier1 := "test-identifier1" testIdentifier2 := "test-identifier2" testIdentifier3 := "test-identifier3" now := time.Now() - di := DBInstances{ - Client: mockedDBInstance{ - t: t, - DescribeDBInstancesOutput: 
rds.DescribeDBInstancesOutput{ - DBInstances: []types.DBInstance{ - { - DBInstanceIdentifier: aws.String(testIdentifier1), - DBName: aws.String(testName1), - InstanceCreateTime: aws.Time(now), - }, - { - DBInstanceIdentifier: aws.String(testIdentifier2), - DBName: aws.String(testName2), - InstanceCreateTime: aws.Time(now.Add(1)), - }, - { - DBInstanceIdentifier: aws.String(testIdentifier3), - DBName: aws.String(testName3), - InstanceCreateTime: aws.Time(now.Add(1)), - DeletionProtection: aws.Bool(true), - }, + + mock := &mockDBInstanceClient{ + t: t, + DescribeDBInstancesOutput: rds.DescribeDBInstancesOutput{ + DBInstances: []types.DBInstance{ + { + DBInstanceIdentifier: aws.String(testIdentifier1), + DBName: aws.String("test-db-instance1"), + InstanceCreateTime: aws.Time(now), + }, + { + DBInstanceIdentifier: aws.String(testIdentifier2), + DBName: aws.String("test-db-instance2"), + InstanceCreateTime: aws.Time(now.Add(1)), + }, + { + DBInstanceIdentifier: aws.String(testIdentifier3), + DBName: aws.String("test-db-instance3"), + InstanceCreateTime: aws.Time(now.Add(1)), + DeletionProtection: aws.Bool(true), }, }, }, } tests := map[string]struct { - configObj config.AWSProtectectableResourceType + configObj config.ResourceType expected []string }{ "emptyFilter": { - configObj: config.AWSProtectectableResourceType{}, + configObj: config.ResourceType{}, expected: []string{testIdentifier1, testIdentifier2, testIdentifier3}, }, "nameExclusionFilter": { - configObj: config.AWSProtectectableResourceType{ResourceType: config.ResourceType{ + configObj: config.ResourceType{ ExcludeRule: config.FilterRule{ NamesRegExp: []config.Expression{{ - RE: *regexp.MustCompile("^" + testName1 + "$"), + RE: *regexp.MustCompile("^test-identifier1$"), }}, }, - }}, - expected: []string{testIdentifier1, testIdentifier2, testIdentifier3}, + }, + expected: []string{testIdentifier2, testIdentifier3}, }, "timeAfterExclusionFilter": { - configObj: config.AWSProtectectableResourceType{ - ResourceType: config.ResourceType{ - ExcludeRule: config.FilterRule{ - TimeAfter: aws.Time(now), - }}}, - expected: []string{testIdentifier1}, - }, - "includeDeletionProtection": { - configObj: config.AWSProtectectableResourceType{ - IncludeDeletionProtected: true, - ResourceType: config.ResourceType{}, + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now), + }, }, - expected: []string{testIdentifier1, testIdentifier2, testIdentifier3}, + expected: []string{testIdentifier1}, }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := di.getAll(context.Background(), config.Config{ - DBInstances: tc.configObj, - }) + names, err := listDBInstances(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } } -func TestDBInstances_NukeAll(t *testing.T) { - +func TestDeleteDBInstances(t *testing.T) { t.Parallel() t.Run("standalone instance", func(t *testing.T) { - di := DBInstances{ - Client: mockedDBInstance{ - t: t, - DescribeDBInstancesOutput: rds.DescribeDBInstancesOutput{ - DBInstances: []types.DBInstance{ - { - DBInstanceIdentifier: aws.String("test-standalone"), - DBClusterIdentifier: nil, // Not part of a cluster - }, + mock := &mockDBInstanceClient{ + t: t, + DescribeDBInstancesOutput: rds.DescribeDBInstancesOutput{ + DBInstances: []types.DBInstance{ + { + DBInstanceIdentifier: aws.String("test-standalone"), + DBClusterIdentifier: nil, // Not part of a cluster }, }, - 
DeleteDBInstanceOutput: rds.DeleteDBInstanceOutput{}, - ModifyCallExpected: true, // Should call ModifyDBInstance - InstancesDeleted: make(map[string]bool), // Track deleted instances }, + DeleteDBInstanceOutput: rds.DeleteDBInstanceOutput{}, + ModifyCallExpected: true, // Should call ModifyDBInstance + InstancesDeleted: make(map[string]bool), // Track deleted instances } - di.Context = context.Background() - di.Timeout = DefaultWaitTimeout - err := di.nukeAll([]*string{aws.String("test-standalone")}) + err := deleteDBInstances(context.Background(), mock, resource.Scope{Region: "us-east-1"}, "rds", []*string{aws.String("test-standalone")}) require.NoError(t, err) }) t.Run("cluster member instance", func(t *testing.T) { - di := DBInstances{ - Client: mockedDBInstance{ - t: t, - DescribeDBInstancesOutput: rds.DescribeDBInstancesOutput{ - DBInstances: []types.DBInstance{ - { - DBInstanceIdentifier: aws.String("test-cluster-member"), - DBClusterIdentifier: aws.String("my-aurora-cluster"), // Part of a cluster - }, + mock := &mockDBInstanceClient{ + t: t, + DescribeDBInstancesOutput: rds.DescribeDBInstancesOutput{ + DBInstances: []types.DBInstance{ + { + DBInstanceIdentifier: aws.String("test-cluster-member"), + DBClusterIdentifier: aws.String("my-aurora-cluster"), // Part of a cluster }, }, - DeleteDBInstanceOutput: rds.DeleteDBInstanceOutput{}, - ModifyCallExpected: false, // Should NOT call ModifyDBInstance - InstancesDeleted: make(map[string]bool), // Track deleted instances }, + DeleteDBInstanceOutput: rds.DeleteDBInstanceOutput{}, + ModifyCallExpected: false, // Should NOT call ModifyDBInstance + InstancesDeleted: make(map[string]bool), // Track deleted instances } - di.Context = context.Background() - di.Timeout = DefaultWaitTimeout - err := di.nukeAll([]*string{aws.String("test-cluster-member")}) + err := deleteDBInstances(context.Background(), mock, resource.Scope{Region: "us-east-1"}, "rds", []*string{aws.String("test-cluster-member")}) require.NoError(t, err) }) } diff --git a/aws/resources/rds_types.go b/aws/resources/rds_types.go deleted file mode 100644 index e48b68f0..00000000 --- a/aws/resources/rds_types.go +++ /dev/null @@ -1,71 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/rds" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type RDSAPI interface { - ModifyDBInstance(ctx context.Context, params *rds.ModifyDBInstanceInput, optFns ...func(*rds.Options)) (*rds.ModifyDBInstanceOutput, error) - DeleteDBInstance(ctx context.Context, params *rds.DeleteDBInstanceInput, optFns ...func(*rds.Options)) (*rds.DeleteDBInstanceOutput, error) - DescribeDBInstances(ctx context.Context, params *rds.DescribeDBInstancesInput, optFns ...func(*rds.Options)) (*rds.DescribeDBInstancesOutput, error) -} -type DBInstances struct { - BaseAwsResource - Client RDSAPI - Region string - InstanceNames []string -} - -func (di *DBInstances) Init(cfg aws.Config) { - di.Client = rds.NewFromConfig(cfg) -} - -func (di *DBInstances) ResourceName() string { - return "rds" -} - -// ResourceIdentifiers - The instance names of the rds db instances -func (di *DBInstances) ResourceIdentifiers() []string { - return di.InstanceNames -} - -func (di *DBInstances) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (di *DBInstances) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.DBInstances.ResourceType 
-} - -func (di *DBInstances) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := di.getAll(c, configObj) - if err != nil { - return nil, err - } - - di.InstanceNames = aws.ToStringSlice(identifiers) - return di.InstanceNames, nil -} - -// Nuke - nuke 'em all!!! -func (di *DBInstances) Nuke(ctx context.Context, identifiers []string) error { - if err := di.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} - -type RdsDeleteError struct { - name string -} - -func (e RdsDeleteError) Error() string { - return "RDS DB Instance:" + e.name + "was not deleted" -} diff --git a/aws/resources/redshift.go b/aws/resources/redshift.go index 611b2280..bf92e2bd 100644 --- a/aws/resources/redshift.go +++ b/aws/resources/redshift.go @@ -2,22 +2,49 @@ package resources import ( "context" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/go-commons/errors" ) -func (rc *RedshiftClusters) getAll(ctx context.Context, configObj config.Config) ([]*string, error) { - var clusterIds []*string +// RedshiftClustersAPI defines the interface for Redshift Cluster operations. +type RedshiftClustersAPI interface { + DescribeClusters(ctx context.Context, params *redshift.DescribeClustersInput, optFns ...func(*redshift.Options)) (*redshift.DescribeClustersOutput, error) + DeleteCluster(ctx context.Context, params *redshift.DeleteClusterInput, optFns ...func(*redshift.Options)) (*redshift.DeleteClusterOutput, error) +} + +// NewRedshiftClusters creates a new RedshiftClusters resource using the generic resource pattern. +func NewRedshiftClusters() AwsResource { + return NewAwsResource(&resource.Resource[RedshiftClustersAPI]{ + ResourceTypeName: "redshift", + BatchSize: 49, + InitClient: func(r *resource.Resource[RedshiftClustersAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for Redshift client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = redshift.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.Redshift + }, + Lister: listRedshiftClusters, + Nuker: resource.SequentialDeleter(resource.DeleteThenWait(deleteRedshiftCluster, waitForRedshiftClusterDeleted)), + }) +} - // Initialize the paginator with any optional settings. - paginator := redshift.NewDescribeClustersPaginator(rc.Client, &redshift.DescribeClustersInput{}) +// listRedshiftClusters retrieves all Redshift clusters that match the config filters. +func listRedshiftClusters(ctx context.Context, client RedshiftClustersAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + var clusterIds []*string - // Use the paginator to go through each page of clusters. + paginator := redshift.NewDescribeClustersPaginator(client, &redshift.DescribeClustersInput{}) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -25,9 +52,8 @@ func (rc *RedshiftClusters) getAll(ctx context.Context, configObj config.Config) return nil, errors.WithStackTrace(err) } - // Process each cluster in the current page. 
for _, cluster := range output.Clusters { - if configObj.Redshift.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Time: cluster.ClusterCreateTime, Name: cluster.ClusterIdentifier, }) { @@ -39,46 +65,19 @@ func (rc *RedshiftClusters) getAll(ctx context.Context, configObj config.Config) return clusterIds, nil } -func (rc *RedshiftClusters) nukeAll(identifiers []*string) error { - if len(identifiers) == 0 { - logging.Debugf("No Redshift Clusters to nuke in region %s", rc.Region) - return nil - } - logging.Debugf("Deleting all Redshift Clusters in region %s", rc.Region) - deletedIds := []*string{} - for _, id := range identifiers { - _, err := rc.Client.DeleteCluster(rc.Context, &redshift.DeleteClusterInput{ - ClusterIdentifier: id, - SkipFinalClusterSnapshot: aws.Bool(true), - }) - if err != nil { - logging.Errorf("[Failed] %s: %s", *id, err) - } else { - deletedIds = append(deletedIds, id) - logging.Debugf("Deleted Redshift Cluster: %s", aws.ToString(id)) - } - } - - if len(deletedIds) > 0 { - for _, id := range deletedIds { - waiter := redshift.NewClusterDeletedWaiter(rc.Client) - err := waiter.Wait(rc.Context, &redshift.DescribeClustersInput{ - ClusterIdentifier: id, - }, rc.Timeout) +// deleteRedshiftCluster deletes a single Redshift cluster. +func deleteRedshiftCluster(ctx context.Context, client RedshiftClustersAPI, id *string) error { + _, err := client.DeleteCluster(ctx, &redshift.DeleteClusterInput{ + ClusterIdentifier: id, + SkipFinalClusterSnapshot: aws.Bool(true), + }) + return err +} - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(id), - ResourceType: "Redshift Cluster", - Error: err, - } - report.Record(e) - if err != nil { - logging.Errorf("[Failed] %s", err) - return errors.WithStackTrace(err) - } - } - } - logging.Debugf("[OK] %d Redshift Cluster(s) deleted in %s", len(deletedIds), rc.Region) - return nil +// waitForRedshiftClusterDeleted waits for a Redshift cluster to be deleted. 
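The updated test below exercises deleteRedshiftCluster and waitForRedshiftClusterDeleted separately. If DeleteThenWait composes them as assumed (delete first, then wait until gone), the combined path could be covered with the same mock in one shot; a sketch follows, where the combinator's exact signature is an assumption rather than something this diff shows:

func TestRedshiftCluster_DeleteThenWait_Sketch(t *testing.T) {
	t.Parallel()

	mock := mockedRedshift{
		DeleteClusterOutput: redshift.DeleteClusterOutput{},
		// ClusterNotFound makes the deleted-waiter treat the cluster as already gone.
		DescribeClusterError: &smithy.GenericAPIError{Code: "ClusterNotFound"},
	}

	// Assumed combinator from the resource package: delete, then wait until gone.
	nukeOne := resource.DeleteThenWait(deleteRedshiftCluster, waitForRedshiftClusterDeleted)
	require.NoError(t, nukeOne(context.Background(), mock, aws.String("test")))
}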
+func waitForRedshiftClusterDeleted(ctx context.Context, client RedshiftClustersAPI, id *string) error { + waiter := redshift.NewClusterDeletedWaiter(client) + return waiter.Wait(ctx, &redshift.DescribeClustersInput{ + ClusterIdentifier: id, + }, 5*time.Minute) } diff --git a/aws/resources/redshift_test.go b/aws/resources/redshift_test.go index e276eca6..e0f89151 100644 --- a/aws/resources/redshift_test.go +++ b/aws/resources/redshift_test.go @@ -11,12 +11,11 @@ import ( "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/smithy-go" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) type mockedRedshift struct { - RedshiftClustersAPI - DeleteClusterOutput redshift.DeleteClusterOutput DescribeClustersOutput redshift.DescribeClustersOutput DescribeClusterError error @@ -30,28 +29,22 @@ func (m mockedRedshift) DeleteCluster(ctx context.Context, input *redshift.Delet return &m.DeleteClusterOutput, nil } -func (m mockedRedshift) WaitForOutput(ctx context.Context, params *redshift.DescribeClustersInput, maxWaitDur time.Duration, optFns ...func(*redshift.Options)) (*redshift.DescribeClustersOutput, error) { - return nil, nil -} func TestRedshiftCluster_GetAll(t *testing.T) { - t.Parallel() now := time.Now() testName1 := "test-cluster1" testName2 := "test-cluster2" - rc := RedshiftClusters{ - Client: mockedRedshift{ - DescribeClustersOutput: redshift.DescribeClustersOutput{ - Clusters: []types.Cluster{ - { - ClusterIdentifier: aws.String(testName1), - ClusterCreateTime: aws.Time(now), - }, - { - ClusterIdentifier: aws.String(testName2), - ClusterCreateTime: aws.Time(now.Add(1)), - }, + mock := mockedRedshift{ + DescribeClustersOutput: redshift.DescribeClustersOutput{ + Clusters: []types.Cluster{ + { + ClusterIdentifier: aws.String(testName1), + ClusterCreateTime: aws.Time(now), + }, + { + ClusterIdentifier: aws.String(testName2), + ClusterCreateTime: aws.Time(now.Add(1)), }, }, }, @@ -84,9 +77,7 @@ func TestRedshiftCluster_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := rc.getAll(context.Background(), config.Config{ - Redshift: tc.configObj, - }) + names, err := listRedshiftClusters(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) @@ -94,20 +85,20 @@ func TestRedshiftCluster_GetAll(t *testing.T) { } func TestRedshiftCluster_NukeAll(t *testing.T) { - t.Parallel() - rc := RedshiftClusters{ - Client: mockedRedshift{ - DeleteClusterOutput: redshift.DeleteClusterOutput{}, - DescribeClusterError: &smithy.GenericAPIError{ - Code: "ClusterNotFound", - }, + mock := mockedRedshift{ + DeleteClusterOutput: redshift.DeleteClusterOutput{}, + DescribeClusterError: &smithy.GenericAPIError{ + Code: "ClusterNotFound", }, } - rc.Context = context.Background() - rc.Timeout = DefaultWaitTimeout - err := rc.nukeAll([]*string{aws.String("test")}) + // Test the delete function + err := deleteRedshiftCluster(context.Background(), mock, aws.String("test")) + require.NoError(t, err) + + // Test the wait function (returns nil because ClusterNotFound means it's deleted) + err = waitForRedshiftClusterDeleted(context.Background(), mock, aws.String("test")) require.NoError(t, err) } diff --git a/aws/resources/redshift_types.go b/aws/resources/redshift_types.go deleted file mode 100644 index e4b95ccb..00000000 --- a/aws/resources/redshift_types.go +++ /dev/null @@ -1,62 +0,0 @@ -package 
resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/redshift" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type RedshiftClustersAPI interface { - DescribeClusters(ctx context.Context, params *redshift.DescribeClustersInput, optFns ...func(*redshift.Options)) (*redshift.DescribeClustersOutput, error) - DeleteCluster(ctx context.Context, params *redshift.DeleteClusterInput, optFns ...func(*redshift.Options)) (*redshift.DeleteClusterOutput, error) -} -type RedshiftClusters struct { - BaseAwsResource - Client RedshiftClustersAPI - Region string - ClusterIdentifiers []string -} - -func (rc *RedshiftClusters) Init(cfg aws.Config) { - rc.Client = redshift.NewFromConfig(cfg) -} - -func (rc *RedshiftClusters) ResourceName() string { - return "redshift" -} - -// ResourceIdentifiers - The instance names of the rds db instances -func (rc *RedshiftClusters) ResourceIdentifiers() []string { - return rc.ClusterIdentifiers -} - -func (rc *RedshiftClusters) MaxBatchSize() int { - // Tentative batch size to ensure AWS doesn't throttle - return 49 -} - -func (rc *RedshiftClusters) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.Redshift -} - -func (rc *RedshiftClusters) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := rc.getAll(c, configObj) - if err != nil { - return nil, err - } - - rc.ClusterIdentifiers = aws.ToStringSlice(identifiers) - return rc.ClusterIdentifiers, nil -} - -// Nuke - nuke 'em all!!! -func (rc *RedshiftClusters) Nuke(ctx context.Context, identifiers []string) error { - if err := rc.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/security_hub.go b/aws/resources/security_hub.go index bddfb885..5d874be6 100644 --- a/aws/resources/security_hub.go +++ b/aws/resources/security_hub.go @@ -9,14 +9,49 @@ import ( "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/gruntwork-io/go-commons/errors" ) -func (sh *SecurityHub) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// SecurityHubAPI defines the interface for Security Hub operations. 
+type SecurityHubAPI interface { + DescribeHub(ctx context.Context, params *securityhub.DescribeHubInput, optFns ...func(*securityhub.Options)) (*securityhub.DescribeHubOutput, error) + ListMembers(ctx context.Context, params *securityhub.ListMembersInput, optFns ...func(*securityhub.Options)) (*securityhub.ListMembersOutput, error) + DisassociateMembers(ctx context.Context, params *securityhub.DisassociateMembersInput, optFns ...func(*securityhub.Options)) (*securityhub.DisassociateMembersOutput, error) + DeleteMembers(ctx context.Context, params *securityhub.DeleteMembersInput, optFns ...func(*securityhub.Options)) (*securityhub.DeleteMembersOutput, error) + GetAdministratorAccount(ctx context.Context, params *securityhub.GetAdministratorAccountInput, optFns ...func(*securityhub.Options)) (*securityhub.GetAdministratorAccountOutput, error) + DisassociateFromAdministratorAccount(ctx context.Context, params *securityhub.DisassociateFromAdministratorAccountInput, optFns ...func(*securityhub.Options)) (*securityhub.DisassociateFromAdministratorAccountOutput, error) + DisableSecurityHub(ctx context.Context, params *securityhub.DisableSecurityHubInput, optFns ...func(*securityhub.Options)) (*securityhub.DisableSecurityHubOutput, error) +} + +// NewSecurityHub creates a new SecurityHub resource using the generic resource pattern. +func NewSecurityHub() AwsResource { + return NewAwsResource(&resource.Resource[SecurityHubAPI]{ + ResourceTypeName: "security-hub", + BatchSize: 5, + InitClient: func(r *resource.Resource[SecurityHubAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for SecurityHub client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = securityhub.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.SecurityHub + }, + Lister: listSecurityHubs, + Nuker: deleteSecurityHubs, + }) +} + +// listSecurityHubs retrieves all Security Hubs that match the config filters. 
+func listSecurityHubs(ctx context.Context, client SecurityHubAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { var securityHubArns []*string - output, err := sh.Client.DescribeHub(sh.Context, &securityhub.DescribeHubInput{}) + output, err := client.DescribeHub(ctx, &securityhub.DescribeHubInput{}) if err != nil { // If Security Hub is not enabled when we call DescribeHub, we get back an error @@ -28,14 +63,14 @@ func (sh *SecurityHub) getAll(c context.Context, configObj config.Config) ([]*st return nil, errors.WithStackTrace(err) } - if shouldIncludeHub(output, configObj) { + if shouldIncludeSecurityHub(output, cfg) { securityHubArns = append(securityHubArns, output.HubArn) } return securityHubArns, nil } -func shouldIncludeHub(hub *securityhub.DescribeHubOutput, configObj config.Config) bool { +func shouldIncludeSecurityHub(hub *securityhub.DescribeHubOutput, cfg config.ResourceType) bool { subscribedAt, err := util.ParseTimestamp(hub.SubscribedAt) if err != nil { logging.Debugf( @@ -43,14 +78,14 @@ func shouldIncludeHub(hub *securityhub.DescribeHubOutput, configObj config.Confi return false } - return configObj.SecurityHub.ShouldInclude(config.ResourceValue{Time: subscribedAt}) + return cfg.ShouldInclude(config.ResourceValue{Time: subscribedAt}) } -func (sh *SecurityHub) getAllSecurityHubMembers() ([]*string, error) { +func getAllSecurityHubMembers(ctx context.Context, client SecurityHubAPI) ([]*string, error) { var hubMemberAccountIds []*string // OnlyAssociated=false input parameter includes "pending" invite members - members, err := sh.Client.ListMembers(sh.Context, &securityhub.ListMembersInput{OnlyAssociated: aws.Bool(false)}) + members, err := client.ListMembers(ctx, &securityhub.ListMembersInput{OnlyAssociated: aws.Bool(false)}) if err != nil { return nil, errors.WithStackTrace(err) } @@ -59,7 +94,7 @@ func (sh *SecurityHub) getAllSecurityHubMembers() ([]*string, error) { } for aws.ToString(members.NextToken) != "" { - members, err = sh.Client.ListMembers(sh.Context, &securityhub.ListMembersInput{NextToken: members.NextToken}) + members, err = client.ListMembers(ctx, &securityhub.ListMembersInput{NextToken: members.NextToken}) if err != nil { return nil, errors.WithStackTrace(err) } @@ -71,17 +106,16 @@ func (sh *SecurityHub) getAllSecurityHubMembers() ([]*string, error) { return hubMemberAccountIds, nil } -func (sh *SecurityHub) removeMembersFromHub(accountIds []*string) error { - +func removeMembersFromSecurityHub(ctx context.Context, client SecurityHubAPI, accountIds []*string) error { // Member accounts must first be disassociated - _, err := sh.Client.DisassociateMembers(sh.Context, &securityhub.DisassociateMembersInput{AccountIds: aws.ToStringSlice(accountIds)}) + _, err := client.DisassociateMembers(ctx, &securityhub.DisassociateMembersInput{AccountIds: aws.ToStringSlice(accountIds)}) if err != nil { return err } logging.Debugf("%d member accounts disassociated", len(accountIds)) // Once disassociated, member accounts can be deleted - _, err = sh.Client.DeleteMembers(sh.Context, &securityhub.DeleteMembersInput{AccountIds: aws.ToStringSlice(accountIds)}) + _, err = client.DeleteMembers(ctx, &securityhub.DeleteMembersInput{AccountIds: aws.ToStringSlice(accountIds)}) if err != nil { return err } @@ -90,22 +124,23 @@ func (sh *SecurityHub) removeMembersFromHub(accountIds []*string) error { return nil } -func (sh *SecurityHub) nukeAll(securityHubArns []string) error { - if len(securityHubArns) == 0 { - logging.Debugf("No security hub resources to nuke in 
region %s", sh.Region) +// deleteSecurityHubs is a custom nuker for Security Hub resources. +func deleteSecurityHubs(ctx context.Context, client SecurityHubAPI, scope resource.Scope, resourceType string, identifiers []*string) error { + if len(identifiers) == 0 { + logging.Debugf("No security hub resources to nuke in %s", scope) return nil } // Check for any member accounts in security hub // Security Hub cannot be disabled with active member accounts - memberAccountIds, err := sh.getAllSecurityHubMembers() + memberAccountIds, err := getAllSecurityHubMembers(ctx, client) if err != nil { return err } // Remove any member accounts if they exist if len(memberAccountIds) > 0 { - err = sh.removeMembersFromHub(memberAccountIds) + err = removeMembersFromSecurityHub(ctx, client, memberAccountIds) if err != nil { logging.Errorf("[Failed] Failed to disassociate members from security hub") } @@ -113,34 +148,34 @@ func (sh *SecurityHub) nukeAll(securityHubArns []string) error { // Check for an administrator account // Security hub cannot be disabled with an active administrator account - adminAccount, err := sh.Client.GetAdministratorAccount(sh.Context, &securityhub.GetAdministratorAccountInput{}) + adminAccount, err := client.GetAdministratorAccount(ctx, &securityhub.GetAdministratorAccountInput{}) if err != nil { logging.Errorf("[Failed] Failed to check for administrator account") } // Disassociate administrator account if it exists if adminAccount.Administrator != nil { - _, err := sh.Client.DisassociateFromAdministratorAccount(sh.Context, &securityhub.DisassociateFromAdministratorAccountInput{}) + _, err = client.DisassociateFromAdministratorAccount(ctx, &securityhub.DisassociateFromAdministratorAccountInput{}) if err != nil { logging.Errorf("[Failed] Failed to disassociate from administrator account") } } // Disable security hub - _, err = sh.Client.DisableSecurityHub(sh.Context, &securityhub.DisableSecurityHubInput{}) + _, err = client.DisableSecurityHub(ctx, &securityhub.DisableSecurityHubInput{}) if err != nil { logging.Errorf("[Failed] Failed to disable security hub.") e := report.Entry{ - Identifier: aws.ToString(&securityHubArns[0]), - ResourceType: "Security Hub", + Identifier: aws.ToString(identifiers[0]), + ResourceType: resourceType, Error: err, } report.Record(e) } else { - logging.Debugf("[OK] Security Hub %s disabled", securityHubArns[0]) + logging.Debugf("[OK] Security Hub %s disabled", aws.ToString(identifiers[0])) e := report.Entry{ - Identifier: aws.ToString(&securityHubArns[0]), - ResourceType: "Security Hub", + Identifier: aws.ToString(identifiers[0]), + ResourceType: resourceType, } report.Record(e) } diff --git a/aws/resources/security_hub_test.go b/aws/resources/security_hub_test.go index 24b2014f..02dc4e19 100644 --- a/aws/resources/security_hub_test.go +++ b/aws/resources/security_hub_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/securityhub" @@ -59,12 +60,10 @@ func TestSecurityHub_GetAll(t *testing.T) { now := time.Now() nowStr := now.Format(time.RFC3339) testArn := "test-arn" - sh := SecurityHub{ - Client: mockedSecurityHub{ - DescribeHubOutput: securityhub.DescribeHubOutput{ - SubscribedAt: &nowStr, - HubArn: aws.String(testArn), - }, + client := mockedSecurityHub{ + DescribeHubOutput: securityhub.DescribeHubOutput{ + SubscribedAt: &nowStr, + HubArn: aws.String(testArn), }, } @@ -86,9 +85,7 @@ func 
TestSecurityHub_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := sh.getAll(context.Background(), config.Config{ - SecurityHub: tc.configObj, - }) + names, err := listSecurityHubs(context.Background(), client, resource.Scope{Region: "us-east-1"}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) @@ -99,25 +96,23 @@ func TestSecurityHub_NukeAll(t *testing.T) { t.Parallel() - sh := SecurityHub{ - Client: mockedSecurityHub{ - ListMembersOutput: securityhub.ListMembersOutput{ - Members: []types.Member{{ - AccountId: aws.String("123456789012"), - }}, - }, - DisassociateMembersOutput: securityhub.DisassociateMembersOutput{}, - DeleteMembersOutput: securityhub.DeleteMembersOutput{}, - GetAdministratorAccountOutput: securityhub.GetAdministratorAccountOutput{ - Administrator: &types.Invitation{ - AccountId: aws.String("123456789012"), - }, + client := mockedSecurityHub{ + ListMembersOutput: securityhub.ListMembersOutput{ + Members: []types.Member{{ + AccountId: aws.String("123456789012"), + }}, + }, + DisassociateMembersOutput: securityhub.DisassociateMembersOutput{}, + DeleteMembersOutput: securityhub.DeleteMembersOutput{}, + GetAdministratorAccountOutput: securityhub.GetAdministratorAccountOutput{ + Administrator: &types.Invitation{ + AccountId: aws.String("123456789012"), }, - DisassociateFromAdministratorAccountOutput: securityhub.DisassociateFromAdministratorAccountOutput{}, - DisableSecurityHubOutput: securityhub.DisableSecurityHubOutput{}, }, + DisassociateFromAdministratorAccountOutput: securityhub.DisassociateFromAdministratorAccountOutput{}, + DisableSecurityHubOutput: securityhub.DisableSecurityHubOutput{}, } - err := sh.nukeAll([]string{"123456789012"}) + err := deleteSecurityHubs(context.Background(), client, resource.Scope{Region: "us-east-1"}, "security-hub", []*string{aws.String("123456789012")}) require.NoError(t, err) } diff --git a/aws/resources/security_hub_types.go b/aws/resources/security_hub_types.go deleted file mode 100644 index 7557021d..00000000 --- a/aws/resources/security_hub_types.go +++ /dev/null @@ -1,65 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/securityhub" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type SecurityHubAPI interface { - DescribeHub(ctx context.Context, params *securityhub.DescribeHubInput, optFns ...func(*securityhub.Options)) (*securityhub.DescribeHubOutput, error) - ListMembers(ctx context.Context, params *securityhub.ListMembersInput, optFns ...func(*securityhub.Options)) (*securityhub.ListMembersOutput, error) - DisassociateMembers(ctx context.Context, params *securityhub.DisassociateMembersInput, optFns ...func(*securityhub.Options)) (*securityhub.DisassociateMembersOutput, error) - DeleteMembers(ctx context.Context, params *securityhub.DeleteMembersInput, optFns ...func(*securityhub.Options)) (*securityhub.DeleteMembersOutput, error) - GetAdministratorAccount(ctx context.Context, params *securityhub.GetAdministratorAccountInput, optFns ...func(*securityhub.Options)) (*securityhub.GetAdministratorAccountOutput, error) - DisassociateFromAdministratorAccount(ctx context.Context, params *securityhub.DisassociateFromAdministratorAccountInput, optFns ...func(*securityhub.Options)) (*securityhub.DisassociateFromAdministratorAccountOutput, error) - DisableSecurityHub(ctx context.Context, params 
*securityhub.DisableSecurityHubInput, optFns ...func(*securityhub.Options)) (*securityhub.DisableSecurityHubOutput, error) -} - -type SecurityHub struct { - BaseAwsResource - Client SecurityHubAPI - Region string - HubArns []string -} - -func (sh *SecurityHub) Init(cfg aws.Config) { - sh.Client = securityhub.NewFromConfig(cfg) -} - -func (sh *SecurityHub) ResourceName() string { - return "security-hub" -} - -func (sh *SecurityHub) ResourceIdentifiers() []string { - return sh.HubArns -} - -func (sh *SecurityHub) MaxBatchSize() int { - return 5 -} - -func (sh *SecurityHub) GetAndSetResourceConfig(configObj config.Config) config.ResourceType { - return configObj.SecurityHub -} - -func (sh *SecurityHub) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := sh.getAll(c, configObj) - if err != nil { - return nil, err - } - - sh.HubArns = aws.ToStringSlice(identifiers) - return sh.HubArns, nil -} - -func (sh *SecurityHub) Nuke(ctx context.Context, identifiers []string) error { - if err := sh.nukeAll(identifiers); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/tgw_vpc_attachment.go b/aws/resources/tgw_vpc_attachment.go index 2b07799c..30b1cb88 100644 --- a/aws/resources/tgw_vpc_attachment.go +++ b/aws/resources/tgw_vpc_attachment.go @@ -10,25 +10,53 @@ import ( goerror "github.com/go-errors/errors" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/go-commons/errors" ) -// Returns a formatted string of TransitGatewayVpcAttachment IDs -func (tgw *TransitGatewaysVpcAttachment) getAll(ctx context.Context, configObj config.Config) ([]*string, error) { +// TransitGatewaysVpcAttachmentAPI defines the interface for TransitGateway VPC Attachment operations. +type TransitGatewaysVpcAttachmentAPI interface { + DeleteTransitGatewayVpcAttachment(ctx context.Context, params *ec2.DeleteTransitGatewayVpcAttachmentInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayVpcAttachmentOutput, error) + DescribeTransitGatewayVpcAttachments(ctx context.Context, params *ec2.DescribeTransitGatewayVpcAttachmentsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTransitGatewayVpcAttachmentsOutput, error) +} + +// NewTransitGatewaysVpcAttachment creates a new TransitGatewaysVpcAttachment resource using the generic resource pattern. +func NewTransitGatewaysVpcAttachment() AwsResource { + return NewAwsResource(&resource.Resource[TransitGatewaysVpcAttachmentAPI]{ + ResourceTypeName: "transit-gateway-attachment", + BatchSize: maxBatchSize, + InitClient: func(r *resource.Resource[TransitGatewaysVpcAttachmentAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for EC2 client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = ec2.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.TransitGatewaysVpcAttachment + }, + Lister: listTransitGatewaysVpcAttachments, + Nuker: resource.SequentialDeleteThenWaitAll(deleteTransitGatewayVpcAttachment, waitForTransitGatewayAttachmentsToBeDeleted), + }) +} + +// listTransitGatewaysVpcAttachments retrieves all Transit Gateway VPC Attachments that match the config filters. 
+func listTransitGatewaysVpcAttachments(ctx context.Context, client TransitGatewaysVpcAttachmentAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { var identifiers []*string params := &ec2.DescribeTransitGatewayVpcAttachmentsInput{} hasMorePages := true for hasMorePages { - result, err := tgw.Client.DescribeTransitGatewayVpcAttachments(ctx, params) + result, err := client.DescribeTransitGatewayVpcAttachments(ctx, params) if err != nil { logging.Debugf("[Transit Gateway] Failed to list transit gateway VPC attachments: %s", err) return nil, errors.WithStackTrace(err) } for _, tgwVpcAttachment := range result.TransitGatewayVpcAttachments { - if configObj.TransitGatewaysVpcAttachment.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Time: tgwVpcAttachment.CreationTime, }) && tgwVpcAttachment.State != "deleted" && tgwVpcAttachment.State != "deleting" { identifiers = append(identifiers, tgwVpcAttachment.TransitGatewayAttachmentId) @@ -42,51 +70,25 @@ func (tgw *TransitGatewaysVpcAttachment) getAll(ctx context.Context, configObj c return identifiers, nil } -// Delete all TransitGatewayVpcAttachments -func (tgw *TransitGatewaysVpcAttachment) nukeAll(ids []*string) error { - if len(ids) == 0 { - logging.Debugf("No Transit Gateway Vpc Attachments to nuke in region %s", tgw.Region) - return nil - } - - logging.Debugf("Deleting all Transit Gateway Vpc Attachments in region %s", tgw.Region) - var deletedIds []*string - - for _, id := range ids { - param := &ec2.DeleteTransitGatewayVpcAttachmentInput{ - TransitGatewayAttachmentId: id, - } - - _, err := tgw.Client.DeleteTransitGatewayVpcAttachment(tgw.Context, param) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(id), - ResourceType: tgw.ResourceName(), - Error: err, - } - report.Record(e) - - if err != nil { - logging.Debugf("[Failed] %s", err) - } else { - deletedIds = append(deletedIds, id) - logging.Debugf("Deleted Transit Gateway Vpc Attachment: %s", *id) - } +// deleteTransitGatewayVpcAttachment deletes a single Transit Gateway VPC Attachment. +func deleteTransitGatewayVpcAttachment(ctx context.Context, client TransitGatewaysVpcAttachmentAPI, id *string) error { + param := &ec2.DeleteTransitGatewayVpcAttachmentInput{ + TransitGatewayAttachmentId: id, } - if waiterr := waitForTransitGatewayAttachmentToBeDeleted(*tgw); waiterr != nil { - return errors.WithStackTrace(waiterr) + _, err := client.DeleteTransitGatewayVpcAttachment(ctx, param) + if err != nil { + return errors.WithStackTrace(err) } - logging.Debugf("[OK] %d Transit Gateway Vpc Attachment(s) deleted in %s", len(deletedIds), tgw.Region) return nil } -func waitForTransitGatewayAttachmentToBeDeleted(tgw TransitGatewaysVpcAttachment) error { +// waitForTransitGatewayAttachmentsToBeDeleted waits for all Transit Gateway attachments to be deleted. 
+func waitForTransitGatewayAttachmentsToBeDeleted(ctx context.Context, client TransitGatewaysVpcAttachmentAPI, ids []string) error { for i := 0; i < 30; i++ { - gateways, err := tgw.Client.DescribeTransitGatewayVpcAttachments( - tgw.Context, &ec2.DescribeTransitGatewayVpcAttachmentsInput{ - TransitGatewayAttachmentIds: tgw.Ids, + gateways, err := client.DescribeTransitGatewayVpcAttachments( + ctx, &ec2.DescribeTransitGatewayVpcAttachmentsInput{ + TransitGatewayAttachmentIds: ids, Filters: []types.Filter{ { Name: aws.String("state"), diff --git a/aws/resources/tgw_vpc_attachment_test.go b/aws/resources/tgw_vpc_attachment_test.go index 538f659d..ff1a8362 100644 --- a/aws/resources/tgw_vpc_attachment_test.go +++ b/aws/resources/tgw_vpc_attachment_test.go @@ -9,47 +9,47 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) -type mockedTransitGatewayVpcAttachment struct { +type mockTransitGatewayVpcAttachmentClient struct { DescribeTransitGatewayVpcAttachmentsOutput ec2.DescribeTransitGatewayVpcAttachmentsOutput DeleteTransitGatewayVpcAttachmentOutput ec2.DeleteTransitGatewayVpcAttachmentOutput } -func (m mockedTransitGatewayVpcAttachment) DescribeTransitGatewayVpcAttachments(ctx context.Context, params *ec2.DescribeTransitGatewayVpcAttachmentsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTransitGatewayVpcAttachmentsOutput, error) { +func (m *mockTransitGatewayVpcAttachmentClient) DescribeTransitGatewayVpcAttachments(ctx context.Context, params *ec2.DescribeTransitGatewayVpcAttachmentsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTransitGatewayVpcAttachmentsOutput, error) { return &m.DescribeTransitGatewayVpcAttachmentsOutput, nil } -func (m mockedTransitGatewayVpcAttachment) DeleteTransitGatewayVpcAttachment(ctx context.Context, params *ec2.DeleteTransitGatewayVpcAttachmentInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayVpcAttachmentOutput, error) { +func (m *mockTransitGatewayVpcAttachmentClient) DeleteTransitGatewayVpcAttachment(ctx context.Context, params *ec2.DeleteTransitGatewayVpcAttachmentInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayVpcAttachmentOutput, error) { return &m.DeleteTransitGatewayVpcAttachmentOutput, nil } -func TestTransitGatewayVpcAttachments_GetAll(t *testing.T) { - +func TestListTransitGatewayVpcAttachments(t *testing.T) { t.Parallel() now := time.Now() attachment1 := "attachment1" attachment2 := "attachment2" - tgw := TransitGatewaysVpcAttachment{ - Client: mockedTransitGatewayVpcAttachment{ - DescribeTransitGatewayVpcAttachmentsOutput: ec2.DescribeTransitGatewayVpcAttachmentsOutput{ - TransitGatewayVpcAttachments: []types.TransitGatewayVpcAttachment{ - { - TransitGatewayAttachmentId: aws.String(attachment1), - CreationTime: aws.Time(now), - State: types.TransitGatewayAttachmentStateAvailable, - }, - { - TransitGatewayAttachmentId: aws.String(attachment2), - CreationTime: aws.Time(now.Add(1 * time.Hour)), - State: types.TransitGatewayAttachmentStateDeleting, - }, + + mock := &mockTransitGatewayVpcAttachmentClient{ + DescribeTransitGatewayVpcAttachmentsOutput: ec2.DescribeTransitGatewayVpcAttachmentsOutput{ + TransitGatewayVpcAttachments: []types.TransitGatewayVpcAttachment{ + { + TransitGatewayAttachmentId: aws.String(attachment1), + CreationTime: aws.Time(now), + State: types.TransitGatewayAttachmentStateAvailable, + }, + { + 
TransitGatewayAttachmentId: aws.String(attachment2), + CreationTime: aws.Time(now.Add(1 * time.Hour)), + State: types.TransitGatewayAttachmentStateDeleting, }, }, }, } + tests := map[string]struct { configObj config.ResourceType expected []string @@ -68,25 +68,20 @@ func TestTransitGatewayVpcAttachments_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := tgw.getAll(context.Background(), config.Config{ - TransitGatewaysVpcAttachment: tc.configObj, - }) + names, err := listTransitGatewaysVpcAttachments(context.Background(), mock, resource.Scope{}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) } } -func TestTransitGatewayVpcAttachments_NukeAll(t *testing.T) { - +func TestDeleteTransitGatewayVpcAttachment(t *testing.T) { t.Parallel() - tgw := TransitGatewaysVpcAttachment{ - Client: mockedTransitGatewayVpcAttachment{ - DeleteTransitGatewayVpcAttachmentOutput: ec2.DeleteTransitGatewayVpcAttachmentOutput{}, - }, + mock := &mockTransitGatewayVpcAttachmentClient{ + DeleteTransitGatewayVpcAttachmentOutput: ec2.DeleteTransitGatewayVpcAttachmentOutput{}, } - err := tgw.nukeAll([]*string{aws.String("test-attachment")}) + err := deleteTransitGatewayVpcAttachment(context.Background(), mock, aws.String("test-attachment")) require.NoError(t, err) } diff --git a/aws/resources/tgw_vpc_attachment_types.go b/aws/resources/tgw_vpc_attachment_types.go deleted file mode 100644 index 65df8f60..00000000 --- a/aws/resources/tgw_vpc_attachment_types.go +++ /dev/null @@ -1,61 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type TGWVpcAPI interface { - DeleteTransitGatewayVpcAttachment(ctx context.Context, params *ec2.DeleteTransitGatewayVpcAttachmentInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayVpcAttachmentOutput, error) - DescribeTransitGatewayVpcAttachments(ctx context.Context, params *ec2.DescribeTransitGatewayVpcAttachmentsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTransitGatewayVpcAttachmentsOutput, error) -} - -// TransitGatewaysVpcAttachment represents all transit gateway VPC attachments. -type TransitGatewaysVpcAttachment struct { - BaseAwsResource - Client TGWVpcAPI - Region string - Ids []string -} - -func (tgw *TransitGatewaysVpcAttachment) Init(cfg aws.Config) { - tgw.Client = ec2.NewFromConfig(cfg) -} - -// ResourceName - the simple name of the aws resource -func (tgw *TransitGatewaysVpcAttachment) ResourceName() string { - return "transit-gateway-attachment" -} - -// MaxBatchSize - Tentative batch size to ensure AWS doesn't throttle -func (tgw *TransitGatewaysVpcAttachment) MaxBatchSize() int { - return maxBatchSize -} - -// ResourceIdentifiers - The Ids of the transit gateways -func (tgw *TransitGatewaysVpcAttachment) ResourceIdentifiers() []string { - return tgw.Ids -} - -func (tgw *TransitGatewaysVpcAttachment) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - identifiers, err := tgw.getAll(c, configObj) - if err != nil { - return nil, err - } - - tgw.Ids = aws.ToStringSlice(identifiers) - return tgw.Ids, nil -} - -// Nuke - nuke 'em all!!! 
-func (tgw *TransitGatewaysVpcAttachment) Nuke(ctx context.Context, identifiers []string) error { - if err := tgw.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/transit_gateway.go b/aws/resources/transit_gateway.go index 8f9d6d30..b8fc15c2 100644 --- a/aws/resources/transit_gateway.go +++ b/aws/resources/transit_gateway.go @@ -11,6 +11,7 @@ import ( "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" "github.com/gruntwork-io/go-commons/errors" ) @@ -21,63 +22,141 @@ const ( TransitGatewayAttachmentTypeConnect = "connect" ) -// [Note 1] : NOTE on the Apporach used:-Using the `dry run` approach on verifying the nuking permission in case of a scoped IAM role. +// TransitGatewaysAPI defines the interface for Transit Gateway operations. +type TransitGatewaysAPI interface { + DescribeTransitGateways(ctx context.Context, params *ec2.DescribeTransitGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTransitGatewaysOutput, error) + DeleteTransitGateway(ctx context.Context, params *ec2.DeleteTransitGatewayInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayOutput, error) + DescribeTransitGatewayAttachments(ctx context.Context, params *ec2.DescribeTransitGatewayAttachmentsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTransitGatewayAttachmentsOutput, error) + DeleteTransitGatewayPeeringAttachment(ctx context.Context, params *ec2.DeleteTransitGatewayPeeringAttachmentInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayPeeringAttachmentOutput, error) + DeleteTransitGatewayVpcAttachment(ctx context.Context, params *ec2.DeleteTransitGatewayVpcAttachmentInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayVpcAttachmentOutput, error) + DeleteTransitGatewayConnect(ctx context.Context, params *ec2.DeleteTransitGatewayConnectInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayConnectOutput, error) +} + +// NewTransitGateways creates a new TransitGateways resource using the generic resource pattern. +func NewTransitGateways() AwsResource { + return NewAwsResource(&resource.Resource[TransitGatewaysAPI]{ + ResourceTypeName: "transit-gateway", + BatchSize: maxBatchSize, + InitClient: func(r *resource.Resource[TransitGatewaysAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for EC2 client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = ec2.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.TransitGateway + }, + Lister: listTransitGateways, + Nuker: deleteTransitGateways, + }) +} + +// [Note 1] : NOTE on the Approach used:-Using the `dry run` approach on verifying the nuking permission in case of a scoped IAM role. // IAM:simulateCustomPolicy : could also be used but the IAM role itself needs permission for simulateCustomPolicy method -//else this would not get the desired result. Also in case of multiple t-gateway, if only some has permssion to be nuked, +// else this would not get the desired result. 
Also in case of multiple t-gateway, if only some have permission to be nuked, // the t-gateway resource ids needs to be passed individually inside the IAM:simulateCustomPolicy to get the desired result, -// else all would result in `Implicit-deny` as response- this might increase the time complexity.Using dry run to avoid this. - -// Returns a formatted string of TransitGateway IDs -func (tgw *TransitGateways) getAll(c context.Context, configObj config.Config) ([]*string, error) { +// else all would result in `Implicit-deny` as response- this might increase the time complexity. Using dry run to avoid this. - result, err := tgw.Client.DescribeTransitGateways(tgw.Context, &ec2.DescribeTransitGatewaysInput{}) +// listTransitGateways returns the IDs of Transit Gateways that match the config filters +func listTransitGateways(ctx context.Context, client TransitGatewaysAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + result, err := client.DescribeTransitGateways(ctx, &ec2.DescribeTransitGatewaysInput{}) if err != nil { logging.Debugf("[DescribeTransitGateways Failed] %s", err) return nil, errors.WithStackTrace(err) } - currentOwner := c.Value(util.AccountIdKey) + currentOwner := ctx.Value(util.AccountIdKey) var ids []*string for _, transitGateway := range result.TransitGateways { - hostNameTagValue := util.GetEC2ResourceNameTagValue(transitGateway.Tags) + // Skip deleted/deleting transit gateways + if transitGateway.State == types.TransitGatewayStateDeleted || transitGateway.State == types.TransitGatewayStateDeleting { + continue + } + + // Skip if owned by a different account + if currentOwner != nil && transitGateway.OwnerId != nil && currentOwner != aws.ToString(transitGateway.OwnerId) { + logging.Debugf("[Skipping] Transit Gateway %s owned by different account", aws.ToString(transitGateway.TransitGatewayId)) + continue + } - if configObj.TransitGateway.ShouldInclude(config.ResourceValue{ + hostNameTagValue := util.GetEC2ResourceNameTagValue(transitGateway.Tags) + if !cfg.ShouldInclude(config.ResourceValue{ Time: transitGateway.CreationTime, Name: hostNameTagValue, - }) && - transitGateway.State != types.TransitGatewayStateDeleted && transitGateway.State != types.TransitGatewayStateDeleting { - ids = append(ids, transitGateway.TransitGatewayId) + }) { + continue } - if currentOwner != nil && transitGateway.OwnerId != nil && currentOwner != aws.ToString(transitGateway.OwnerId) { - tgw.SetNukableStatus(*transitGateway.TransitGatewayId, util.ErrDifferentOwner) + // Verify permission using dry-run before including in the list + if err := verifyTransitGatewayNukePermission(ctx, client, transitGateway.TransitGatewayId); err != nil { + logging.Debugf("[Skipping] Transit Gateway %s: no permission to nuke: %v", aws.ToString(transitGateway.TransitGatewayId), err) continue } + + ids = append(ids, transitGateway.TransitGatewayId) } - // Check and verfiy the list of allowed nuke actions - // VerifyNukablePermissions is used to iterate over a list of Transit Gateway IDs (ids) and execute a provided function (func(id *string) error). - // The function, attempts to delete a Transit Gateway with the specified ID in a dry-run mode (checking permissions without actually performing the delete operation). The result of this operation (error or success) is then captured.
- // See more at [Note 1] - tgw.VerifyNukablePermissions(ids, func(id *string) error { - params := &ec2.DeleteTransitGatewayInput{ - TransitGatewayId: id, - DryRun: aws.Bool(true), // dry run set as true , checks permission without actualy making the request + return ids, nil +} + +// verifyTransitGatewayNukePermission performs a dry-run delete to check permissions. +// Returns nil if the user has permission, otherwise returns an error. +func verifyTransitGatewayNukePermission(ctx context.Context, client TransitGatewaysAPI, id *string) error { + params := &ec2.DeleteTransitGatewayInput{ + TransitGatewayId: id, + DryRun: aws.Bool(true), // dry run set as true, checks permission without actually making the request + } + _, err := client.DeleteTransitGateway(ctx, params) + if err != nil { + return util.TransformAWSError(err) + } + return nil +} + +// deleteTransitGateways deletes all TransitGateways. +// All resources passed to this function have already been verified for permissions in the lister. +func deleteTransitGateways(ctx context.Context, client TransitGatewaysAPI, scope resource.Scope, resourceType string, ids []*string) error { + if len(ids) == 0 { + logging.Debugf("No Transit Gateways to nuke in region %s", scope.Region) + return nil + } + + logging.Debugf("Deleting all Transit Gateways in region %s", scope.Region) + var deletedIds []*string + + for _, id := range ids { + err := nukeTransitGateway(ctx, client, id) + + // Record status of this resource + e := report.Entry{ + Identifier: aws.ToString(id), + ResourceType: "Transit Gateway", + Error: err, } - _, err := tgw.Client.DeleteTransitGateway(tgw.Context, params) - return err - }) + report.Record(e) - return ids, nil + if err != nil { + logging.Debugf("[Failed] %s", err) + } else { + deletedIds = append(deletedIds, id) + logging.Debugf("Deleted Transit Gateway: %s", *id) + } + } + + logging.Debugf("[OK] %d Transit Gateway(s) deleted in %s", len(deletedIds), scope.Region) + return nil } -func (tgw *TransitGateways) nuke(id *string) error { +func nukeTransitGateway(ctx context.Context, client TransitGatewaysAPI, id *string) error { // check the transit gateway has attachments and nuke them before - if err := tgw.nukeAttachments(id); err != nil { + if err := nukeTransitGatewayAttachments(ctx, client, id); err != nil { return errors.WithStackTrace(err) } - if _, err := tgw.Client.DeleteTransitGateway(tgw.Context, &ec2.DeleteTransitGatewayInput{ + if _, err := client.DeleteTransitGateway(ctx, &ec2.DeleteTransitGatewayInput{ TransitGatewayId: id, }); err != nil { return errors.WithStackTrace(err) @@ -86,9 +165,9 @@ func (tgw *TransitGateways) nuke(id *string) error { return nil } -func (tgw *TransitGateways) nukeAttachments(id *string) error { +func nukeTransitGatewayAttachments(ctx context.Context, client TransitGatewaysAPI, id *string) error { logging.Debugf("nuking transit gateway attachments for %v", aws.ToString(id)) - output, err := tgw.Client.DescribeTransitGatewayAttachments(tgw.Context, &ec2.DescribeTransitGatewayAttachmentsInput{ + output, err := client.DescribeTransitGatewayAttachments(ctx, &ec2.DescribeTransitGatewayAttachmentsInput{ Filters: []types.Filter{ { Name: aws.String("transit-gateway-id"), @@ -105,7 +184,7 @@ func (tgw *TransitGateways) nukeAttachments(id *string) error { }, }) if err != nil { - logging.Errorf("[Failed] unable to describe the transit gateway attachments for %v : %s", aws.ToString(id), err) + logging.Errorf("[Failed] unable to describe the transit gateway attachments for %v : %s", 
aws.ToString(id), err) return errors.WithStackTrace(err) } @@ -113,7 +192,7 @@ func (tgw *TransitGateways) nukeAttachments(id *string) error { for _, attachments := range output.TransitGatewayAttachments { var ( - err error + attachmentErr error attachmentType = attachments.ResourceType now = time.Now() ) @@ -122,29 +201,29 @@ func (tgw *TransitGateways) nukeAttachments(id *string) error { case TransitGatewayAttachmentTypePeering: logging.Debugf("[Execution] deleting the attachments of type %v for %v ", attachmentType, aws.ToString(id)) // Delete the Transit Gateway Peering Attachment - _, err = tgw.Client.DeleteTransitGatewayPeeringAttachment(tgw.Context, &ec2.DeleteTransitGatewayPeeringAttachmentInput{ + _, attachmentErr = client.DeleteTransitGatewayPeeringAttachment(ctx, &ec2.DeleteTransitGatewayPeeringAttachmentInput{ TransitGatewayAttachmentId: attachments.TransitGatewayAttachmentId, }) case TransitGatewayAttachmentTypeVPC: logging.Debugf("[Execution] deleting the attachments of type %v for %v ", attachmentType, aws.ToString(id)) // Delete the Transit Gateway VPC Attachment - _, err = tgw.Client.DeleteTransitGatewayVpcAttachment(tgw.Context, &ec2.DeleteTransitGatewayVpcAttachmentInput{ + _, attachmentErr = client.DeleteTransitGatewayVpcAttachment(ctx, &ec2.DeleteTransitGatewayVpcAttachmentInput{ TransitGatewayAttachmentId: attachments.TransitGatewayAttachmentId, }) case TransitGatewayAttachmentTypeConnect: logging.Debugf("[Execution] deleting the attachments of type %v for %v ", attachmentType, aws.ToString(id)) // Delete the Transit Gateway Connect Attachment - _, err = tgw.Client.DeleteTransitGatewayConnect(tgw.Context, &ec2.DeleteTransitGatewayConnectInput{ + _, attachmentErr = client.DeleteTransitGatewayConnect(ctx, &ec2.DeleteTransitGatewayConnectInput{ TransitGatewayAttachmentId: attachments.TransitGatewayAttachmentId, }) default: - err = fmt.Errorf("%v typed transit gateway attachment nuking not handled", attachmentType) + attachmentErr = fmt.Errorf("%v typed transit gateway attachment nuking not handled", attachmentType) } - if err != nil { - logging.Errorf("[Failed] unable to delete the transit gateway peernig attachment for %v : %s", aws.ToString(id), err) - return err + if attachmentErr != nil { + logging.Errorf("[Failed] unable to delete the transit gateway peering attachment for %v : %s", aws.ToString(id), attachmentErr) + return attachmentErr } - if err := tgw.WaitUntilTransitGatewayAttachmentDeleted(id, attachmentType); err != nil { + if err := waitUntilTransitGatewayAttachmentDeleted(ctx, client, id, attachmentType); err != nil { logging.Errorf("[Failed] unable to wait until nuking the transit gateway attachment with type %v for %v : %s", attachmentType, aws.ToString(id), err) return errors.WithStackTrace(err) } @@ -156,8 +235,8 @@ func (tgw *TransitGateways) nukeAttachments(id *string) error { return nil } -func (tgw *TransitGateways) WaitUntilTransitGatewayAttachmentDeleted(id *string, attachmentType types.TransitGatewayAttachmentResourceType) error { - timeoutCtx, cancel := context.WithTimeout(tgw.Context, 5*time.Minute) +func waitUntilTransitGatewayAttachmentDeleted(ctx context.Context, client TransitGatewaysAPI, id *string, attachmentType types.TransitGatewayAttachmentResourceType) error { + timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) defer cancel() ticker := time.NewTicker(3 * time.Second) @@ -168,7 +247,7 @@ func (tgw *TransitGateways) WaitUntilTransitGatewayAttachmentDeleted(id *string, case <-timeoutCtx.Done(): return fmt.Errorf("transit 
gateway attachments deletion check timed out after 5 minute") case <-ticker.C: - output, err := tgw.Client.DescribeTransitGatewayAttachments(tgw.Context, &ec2.DescribeTransitGatewayAttachmentsInput{ + output, err := client.DescribeTransitGatewayAttachments(ctx, &ec2.DescribeTransitGatewayAttachmentsInput{ Filters: []types.Filter{ { Name: aws.String("transit-gateway-id"), @@ -186,7 +265,7 @@ func (tgw *TransitGateways) WaitUntilTransitGatewayAttachmentDeleted(id *string, }, }) if err != nil { - logging.Debugf("transit gateway attachment(s) as type %v existance checking error : %v", attachmentType, err) + logging.Debugf("transit gateway attachment(s) as type %v existence checking error : %v", attachmentType, err) return errors.WithStackTrace(err) } @@ -197,43 +276,3 @@ func (tgw *TransitGateways) WaitUntilTransitGatewayAttachmentDeleted(id *string, } } } - -// Delete all TransitGateways -// it attempts to nuke only those resources for which the current IAM user has permission -func (tgw *TransitGateways) nukeAll(ids []*string) error { - if len(ids) == 0 { - logging.Debugf("No Transit Gateways to nuke in region %s", tgw.Region) - return nil - } - - logging.Debugf("Deleting all Transit Gateways in region %s", tgw.Region) - var deletedIds []*string - - for _, id := range ids { - //check the id has the permission to nuke, if not. continue the execution - if nukable, reason := tgw.IsNukable(*id); !nukable { - //not adding the report on final result hence not adding a record entry here - logging.Debugf("[Skipping] %s nuke because %v", *id, reason) - continue - } - err := tgw.nuke(id) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(id), - ResourceType: "Transit Gateway", - Error: err, - } - report.Record(e) - - if err != nil { - logging.Debugf("[Failed] %s", err) - } else { - deletedIds = append(deletedIds, id) - logging.Debugf("Deleted Transit Gateway: %s", *id) - } - } - - logging.Debugf("[OK] %d Transit Gateway(s) deleted in %s", len(deletedIds), tgw.Region) - return nil -} diff --git a/aws/resources/transit_gateway_test.go b/aws/resources/transit_gateway_test.go index c3de85b7..e1f2273a 100644 --- a/aws/resources/transit_gateway_test.go +++ b/aws/resources/transit_gateway_test.go @@ -9,11 +9,12 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/stretchr/testify/require" ) type mockedTransitGateway struct { - TransitGatewayAPI + TransitGatewaysAPI DescribeTransitGatewaysOutput ec2.DescribeTransitGatewaysOutput DeleteTransitGatewayOutput ec2.DeleteTransitGatewayOutput DescribeTransitGatewayAttachmentsOutput ec2.DescribeTransitGatewayAttachmentsOutput @@ -54,20 +55,18 @@ func TestTransitGateways_GetAll(t *testing.T) { now := time.Now() gatewayId1 := "gateway1" gatewayId2 := "gateway2" - tgw := TransitGateways{ - Client: mockedTransitGateway{ - DescribeTransitGatewaysOutput: ec2.DescribeTransitGatewaysOutput{ - TransitGateways: []types.TransitGateway{ - { - TransitGatewayId: aws.String(gatewayId1), - CreationTime: aws.Time(now), - State: "available", - }, - { - TransitGatewayId: aws.String(gatewayId2), - CreationTime: aws.Time(now.Add(1)), - State: "deleting", - }, + mockClient := mockedTransitGateway{ + DescribeTransitGatewaysOutput: ec2.DescribeTransitGatewaysOutput{ + TransitGateways: []types.TransitGateway{ + { + TransitGatewayId: aws.String(gatewayId1), + CreationTime: aws.Time(now), + State: 
"available", + }, + { + TransitGatewayId: aws.String(gatewayId2), + CreationTime: aws.Time(now.Add(1)), + State: "deleting", }, }, }, @@ -91,9 +90,7 @@ func TestTransitGateways_GetAll(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - names, err := tgw.getAll(context.Background(), config.Config{ - TransitGateway: tc.configObj, - }) + names, err := listTransitGateways(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, tc.configObj) require.NoError(t, err) require.Equal(t, tc.expected, aws.ToStringSlice(names)) }) @@ -105,12 +102,10 @@ func TestTransitGateways_NukeAll(t *testing.T) { t.Parallel() - tgw := TransitGateways{ - Client: mockedTransitGateway{ - DeleteTransitGatewayOutput: ec2.DeleteTransitGatewayOutput{}, - }, + mockClient := mockedTransitGateway{ + DeleteTransitGatewayOutput: ec2.DeleteTransitGatewayOutput{}, } - err := tgw.nukeAll([]*string{aws.String("test-gateway")}) + err := deleteTransitGateways(context.Background(), mockClient, resource.Scope{Region: "us-east-1"}, "transit-gateway", []*string{aws.String("test-gateway")}) require.NoError(t, err) } diff --git a/aws/resources/transit_gateway_types.go b/aws/resources/transit_gateway_types.go deleted file mode 100644 index 928a8b5c..00000000 --- a/aws/resources/transit_gateway_types.go +++ /dev/null @@ -1,66 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type TransitGatewayAPI interface { - DescribeTransitGateways(ctx context.Context, params *ec2.DescribeTransitGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTransitGatewaysOutput, error) - DeleteTransitGateway(ctx context.Context, params *ec2.DeleteTransitGatewayInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayOutput, error) - DescribeTransitGatewayAttachments(ctx context.Context, params *ec2.DescribeTransitGatewayAttachmentsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTransitGatewayAttachmentsOutput, error) - DeleteTransitGatewayPeeringAttachment(ctx context.Context, params *ec2.DeleteTransitGatewayPeeringAttachmentInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayPeeringAttachmentOutput, error) - DeleteTransitGatewayVpcAttachment(ctx context.Context, params *ec2.DeleteTransitGatewayVpcAttachmentInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayVpcAttachmentOutput, error) - DeleteTransitGatewayConnect(ctx context.Context, params *ec2.DeleteTransitGatewayConnectInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTransitGatewayConnectOutput, error) -} - -// TransitGateways - represents all transit gateways -type TransitGateways struct { - BaseAwsResource - Client TransitGatewayAPI - Region string - Ids []string -} - -func (tgw *TransitGateways) Init(cfg aws.Config) { - tgw.Client = ec2.NewFromConfig(cfg) - tgw.Nukables = make(map[string]error) -} - -// ResourceName - the simple name of the aws resource -func (tgw *TransitGateways) ResourceName() string { - return "transit-gateway" -} - -// MaxBatchSize - Tentative batch size to ensure AWS doesn't throttle -func (tgw *TransitGateways) MaxBatchSize() int { - return maxBatchSize -} - -// ResourceIdentifiers - The Ids of the transit gateways -func (tgw *TransitGateways) ResourceIdentifiers() []string { - return tgw.Ids -} - -func (tgw *TransitGateways) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) { - 
identifiers, err := tgw.getAll(c, configObj) - if err != nil { - return nil, err - } - - tgw.Ids = aws.ToStringSlice(identifiers) - return tgw.Ids, nil -} - -// Nuke - nuke 'em all!!! -func (tgw *TransitGateways) Nuke(ctx context.Context, identifiers []string) error { - if err := tgw.nukeAll(aws.StringSlice(identifiers)); err != nil { - return errors.WithStackTrace(err) - } - - return nil -} diff --git a/aws/resources/vpc_lattice_service_network.go b/aws/resources/vpc_lattice_service_network.go index e31b7d74..8d0017b9 100644 --- a/aws/resources/vpc_lattice_service_network.go +++ b/aws/resources/vpc_lattice_service_network.go @@ -9,20 +9,49 @@ import ( "github.com/aws/aws-sdk-go-v2/service/vpclattice" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/report" - "github.com/gruntwork-io/go-commons/errors" + "github.com/gruntwork-io/cloud-nuke/resource" ) -func (network *VPCLatticeServiceNetwork) getAll(_ context.Context, configObj config.Config) ([]*string, error) { - output, err := network.Client.ListServiceNetworks(network.Context, nil) +// VPCLatticeServiceNetworkAPI defines the interface for VPC Lattice Service Network operations. +type VPCLatticeServiceNetworkAPI interface { + ListServiceNetworks(ctx context.Context, params *vpclattice.ListServiceNetworksInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListServiceNetworksOutput, error) + DeleteServiceNetwork(ctx context.Context, params *vpclattice.DeleteServiceNetworkInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteServiceNetworkOutput, error) + ListServiceNetworkServiceAssociations(ctx context.Context, params *vpclattice.ListServiceNetworkServiceAssociationsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListServiceNetworkServiceAssociationsOutput, error) + DeleteServiceNetworkServiceAssociation(ctx context.Context, params *vpclattice.DeleteServiceNetworkServiceAssociationInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteServiceNetworkServiceAssociationOutput, error) +} + +// NewVPCLatticeServiceNetwork creates a new VPC Lattice Service Network resource using the generic resource pattern. +func NewVPCLatticeServiceNetwork() AwsResource { + return NewAwsResource(&resource.Resource[VPCLatticeServiceNetworkAPI]{ + ResourceTypeName: "vpc-lattice-service-network", + BatchSize: maxBatchSize, + InitClient: func(r *resource.Resource[VPCLatticeServiceNetworkAPI], cfg any) { + awsCfg, ok := cfg.(aws.Config) + if !ok { + logging.Debugf("Invalid config type for VPC Lattice client: expected aws.Config") + return + } + r.Scope.Region = awsCfg.Region + r.Client = vpclattice.NewFromConfig(awsCfg) + }, + ConfigGetter: func(c config.Config) config.ResourceType { + return c.VPCLatticeServiceNetwork + }, + Lister: listVPCLatticeServiceNetworks, + Nuker: resource.SequentialDeleter(deleteVPCLatticeServiceNetwork), + }) +} + +// listVPCLatticeServiceNetworks retrieves all VPC Lattice Service Networks that match the config filters. 
+func listVPCLatticeServiceNetworks(ctx context.Context, client VPCLatticeServiceNetworkAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) { + output, err := client.ListServiceNetworks(ctx, nil) if err != nil { - return nil, errors.WithStackTrace(err) + return nil, err } var ids []*string for _, item := range output.Items { - - if configObj.VPCLatticeServiceNetwork.ShouldInclude(config.ResourceValue{ + if cfg.ShouldInclude(config.ResourceValue{ Name: item.Name, Time: item.CreatedAt, }) { @@ -33,88 +62,31 @@ func (network *VPCLatticeServiceNetwork) getAll(_ context.Context, configObj con return ids, nil } -func (network *VPCLatticeServiceNetwork) nukeServiceAssociations(id *string) error { - // list service associations - associations, err := network.Client.ListServiceNetworkServiceAssociations(network.Context, &vpclattice.ListServiceNetworkServiceAssociationsInput{ +// nukeServiceAssociations deletes all service associations for a service network. +func nukeServiceAssociations(ctx context.Context, client VPCLatticeServiceNetworkAPI, id *string) error { + associations, err := client.ListServiceNetworkServiceAssociations(ctx, &vpclattice.ListServiceNetworkServiceAssociationsInput{ ServiceNetworkIdentifier: id, }) if err != nil { - return errors.WithStackTrace(err) + return err } for _, item := range associations.Items { - // list service associations - _, err := network.Client.DeleteServiceNetworkServiceAssociation(network.Context, &vpclattice.DeleteServiceNetworkServiceAssociationInput{ + _, err := client.DeleteServiceNetworkServiceAssociation(ctx, &vpclattice.DeleteServiceNetworkServiceAssociationInput{ ServiceNetworkServiceAssociationIdentifier: item.Id, }) if err != nil { - return errors.WithStackTrace(err) - } - - } - return nil -} - -func (network *VPCLatticeServiceNetwork) nukeAll(identifiers []*string) error { - if len(identifiers) == 0 { - logging.Debugf("No %s to nuke in region %s", network.ResourceServiceName(), network.Region) - return nil - - } - - logging.Debugf("Deleting all %s in region %s", network.ResourceServiceName(), network.Region) - - deletedCount := 0 - for _, id := range identifiers { - - err := network.nuke(id) - - // Record status of this resource - e := report.Entry{ - Identifier: aws.ToString(id), - ResourceType: network.ResourceServiceName(), - Error: err, - } - report.Record(e) - - if err != nil { - logging.Debugf("[Failed] %s", err) - } else { - deletedCount++ - logging.Debugf("Deleted %s: %s", network.ResourceServiceName(), aws.ToString(id)) + return err } } - - logging.Debugf("[OK] %d %s(s) terminated in %s", deletedCount, network.ResourceServiceName(), network.Region) - return nil -} - -func (network *VPCLatticeServiceNetwork) nukeServiceNetwork(id *string) error { - _, err := network.Client.DeleteServiceNetwork(network.Context, &vpclattice.DeleteServiceNetworkInput{ - ServiceNetworkIdentifier: id, - }) - return err -} - -func (network *VPCLatticeServiceNetwork) nuke(id *string) error { - if err := network.nukeServiceAssociations(id); err != nil { - return err - } - - if err := network.waitUntilAllServiceAssociationDeleted(id); err != nil { - return err - } - if err := network.nukeServiceNetwork(id); err != nil { - return err - } - return nil } -func (network *VPCLatticeServiceNetwork) waitUntilAllServiceAssociationDeleted(id *string) error { +// waitUntilAllServiceAssociationDeleted waits until all service associations are deleted. 
+func waitUntilAllServiceAssociationDeleted(ctx context.Context, client VPCLatticeServiceNetworkAPI, id *string) error { for i := 0; i < 10; i++ { - output, err := network.Client.ListServiceNetworkServiceAssociations(network.Context, &vpclattice.ListServiceNetworkServiceAssociationsInput{ + output, err := client.ListServiceNetworkServiceAssociations(ctx, &vpclattice.ListServiceNetworkServiceAssociationsInput{ ServiceNetworkIdentifier: id, }) @@ -129,5 +101,23 @@ func (network *VPCLatticeServiceNetwork) waitUntilAllServiceAssociationDeleted(i } return fmt.Errorf("timed out waiting for service associations to be successfully deleted") +} + +// deleteVPCLatticeServiceNetwork deletes a single VPC Lattice Service Network after removing associations. +func deleteVPCLatticeServiceNetwork(ctx context.Context, client VPCLatticeServiceNetworkAPI, id *string) error { + // First delete all service associations + if err := nukeServiceAssociations(ctx, client, id); err != nil { + return err + } + // Wait for all associations to be deleted + if err := waitUntilAllServiceAssociationDeleted(ctx, client, id); err != nil { + return err + } + + // Finally delete the service network + _, err := client.DeleteServiceNetwork(ctx, &vpclattice.DeleteServiceNetworkInput{ + ServiceNetworkIdentifier: id, + }) + return err } diff --git a/aws/resources/vpc_lattice_service_network_test.go b/aws/resources/vpc_lattice_service_network_test.go index 44675305..3c5e7310 100644 --- a/aws/resources/vpc_lattice_service_network_test.go +++ b/aws/resources/vpc_lattice_service_network_test.go @@ -1,4 +1,4 @@ -package resources_test +package resources import ( "context" @@ -9,108 +9,154 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/vpclattice" "github.com/aws/aws-sdk-go-v2/service/vpclattice/types" - "github.com/gruntwork-io/cloud-nuke/aws/resources" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/resource" "github.com/gruntwork-io/cloud-nuke/util" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type mockedVPCLatticeServiceNetwork struct { - resources.VPCLatticeServiceNetworkAPI - DeleteServiceNetworkOutput vpclattice.DeleteServiceNetworkOutput - ListServiceNetworksOutput vpclattice.ListServiceNetworksOutput - +type mockVPCLatticeServiceNetworkClient struct { + ListServiceNetworksOutput vpclattice.ListServiceNetworksOutput + DeleteServiceNetworkOutput vpclattice.DeleteServiceNetworkOutput ListServiceNetworkServiceAssociationsOutput vpclattice.ListServiceNetworkServiceAssociationsOutput DeleteServiceNetworkServiceAssociationOutput vpclattice.DeleteServiceNetworkServiceAssociationOutput } -func (m mockedVPCLatticeServiceNetwork) ListServiceNetworks(ctx context.Context, params *vpclattice.ListServiceNetworksInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListServiceNetworksOutput, error) { +func (m *mockVPCLatticeServiceNetworkClient) ListServiceNetworks(ctx context.Context, params *vpclattice.ListServiceNetworksInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListServiceNetworksOutput, error) { return &m.ListServiceNetworksOutput, nil } -func (m mockedVPCLatticeServiceNetwork) DeleteServiceNetwork(ctx context.Context, params *vpclattice.DeleteServiceNetworkInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteServiceNetworkOutput, error) { +func (m *mockVPCLatticeServiceNetworkClient) DeleteServiceNetwork(ctx context.Context, params *vpclattice.DeleteServiceNetworkInput, optFns ...func(*vpclattice.Options)) 
(*vpclattice.DeleteServiceNetworkOutput, error) { return &m.DeleteServiceNetworkOutput, nil } -func (m mockedVPCLatticeServiceNetwork) ListServiceNetworkServiceAssociations(ctx context.Context, params *vpclattice.ListServiceNetworkServiceAssociationsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListServiceNetworkServiceAssociationsOutput, error) { +func (m *mockVPCLatticeServiceNetworkClient) ListServiceNetworkServiceAssociations(ctx context.Context, params *vpclattice.ListServiceNetworkServiceAssociationsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListServiceNetworkServiceAssociationsOutput, error) { return &m.ListServiceNetworkServiceAssociationsOutput, nil } -func (m mockedVPCLatticeServiceNetwork) DeleteServiceNetworkServiceAssociation(ctx context.Context, params *vpclattice.DeleteServiceNetworkServiceAssociationInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteServiceNetworkServiceAssociationOutput, error) { + +func (m *mockVPCLatticeServiceNetworkClient) DeleteServiceNetworkServiceAssociation(ctx context.Context, params *vpclattice.DeleteServiceNetworkServiceAssociationInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteServiceNetworkServiceAssociationOutput, error) { return &m.DeleteServiceNetworkServiceAssociationOutput, nil } -func TestVPCLatticeServiceNetwork_GetAll(t *testing.T) { +func TestVPCLatticeServiceNetwork_ResourceName(t *testing.T) { + r := NewVPCLatticeServiceNetwork() + assert.Equal(t, "vpc-lattice-service-network", r.ResourceName()) +} + +func TestVPCLatticeServiceNetwork_MaxBatchSize(t *testing.T) { + r := NewVPCLatticeServiceNetwork() + assert.Equal(t, 49, r.MaxBatchSize()) +} +func TestListVPCLatticeServiceNetworks(t *testing.T) { t.Parallel() - var ( - id1 = "aws-nuke-test-" + util.UniqueID() - id2 = "aws-nuke-test-" + util.UniqueID() - now = time.Now() - ) - - obj := resources.VPCLatticeServiceNetwork{ - Client: mockedVPCLatticeServiceNetwork{ - ListServiceNetworksOutput: vpclattice.ListServiceNetworksOutput{ - Items: []types.ServiceNetworkSummary{ - { - Arn: aws.String(id1), - Name: aws.String(id1), - CreatedAt: aws.Time(now), - }, { - Arn: aws.String(id2), - Name: aws.String(id2), - CreatedAt: aws.Time(now.Add(1 * time.Hour)), - }, + id1 := "aws-nuke-test-" + util.UniqueID() + id2 := "aws-nuke-test-" + util.UniqueID() + now := time.Now() + + mock := &mockVPCLatticeServiceNetworkClient{ + ListServiceNetworksOutput: vpclattice.ListServiceNetworksOutput{ + Items: []types.ServiceNetworkSummary{ + { + Arn: aws.String(id1), + Name: aws.String(id1), + CreatedAt: aws.Time(now), + }, + { + Arn: aws.String(id2), + Name: aws.String(id2), + CreatedAt: aws.Time(now.Add(1 * time.Hour)), }, }, }, } - tests := map[string]struct { - configObj config.ResourceType - expected []string - }{ - "emptyFilter": { - configObj: config.ResourceType{}, - expected: []string{id1, id2}, - }, - "nameExclusionFilter": { - configObj: config.ResourceType{ - ExcludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{{ - RE: *regexp.MustCompile(id2), - }}}, + ids, err := listVPCLatticeServiceNetworks(context.Background(), mock, resource.Scope{}, config.ResourceType{}) + require.NoError(t, err) + require.ElementsMatch(t, []string{id1, id2}, aws.ToStringSlice(ids)) +} + +func TestListVPCLatticeServiceNetworks_WithNameExclusionFilter(t *testing.T) { + t.Parallel() + + id1 := "aws-nuke-test-" + util.UniqueID() + id2 := "aws-nuke-test-" + util.UniqueID() + now := time.Now() + + mock := &mockVPCLatticeServiceNetworkClient{ + 
+		ListServiceNetworksOutput: vpclattice.ListServiceNetworksOutput{
+			Items: []types.ServiceNetworkSummary{
+				{
+					Arn:       aws.String(id1),
+					Name:      aws.String(id1),
+					CreatedAt: aws.Time(now),
+				},
+				{
+					Arn:       aws.String(id2),
+					Name:      aws.String(id2),
+					CreatedAt: aws.Time(now.Add(1 * time.Hour)),
+				},
 			},
-			expected: []string{id1},
 		},
-		"timeAfterExclusionFilter": {
-			configObj: config.ResourceType{
-				ExcludeRule: config.FilterRule{
-					TimeAfter: aws.Time(now),
-				}},
-			expected: []string{id1},
+	}
+
+	cfg := config.ResourceType{
+		ExcludeRule: config.FilterRule{
+			NamesRegExp: []config.Expression{{RE: *regexp.MustCompile(id2)}},
 		},
 	}
 
-	for name, tc := range tests {
-		t.Run(name, func(t *testing.T) {
-			names, err := obj.GetAndSetIdentifiers(context.Background(), config.Config{
-				VPCLatticeServiceNetwork: tc.configObj,
-			})
-			require.NoError(t, err)
-			require.Equal(t, tc.expected, names)
-		})
+
+	ids, err := listVPCLatticeServiceNetworks(context.Background(), mock, resource.Scope{}, cfg)
+	require.NoError(t, err)
+	require.Equal(t, []string{id1}, aws.ToStringSlice(ids))
+}
+
+func TestListVPCLatticeServiceNetworks_WithTimeAfterExclusionFilter(t *testing.T) {
+	t.Parallel()
+
+	id1 := "aws-nuke-test-" + util.UniqueID()
+	id2 := "aws-nuke-test-" + util.UniqueID()
+	now := time.Now()
+
+	mock := &mockVPCLatticeServiceNetworkClient{
+		ListServiceNetworksOutput: vpclattice.ListServiceNetworksOutput{
+			Items: []types.ServiceNetworkSummary{
+				{
+					Arn:       aws.String(id1),
+					Name:      aws.String(id1),
+					CreatedAt: aws.Time(now),
+				},
+				{
+					Arn:       aws.String(id2),
+					Name:      aws.String(id2),
+					CreatedAt: aws.Time(now.Add(1 * time.Hour)),
+				},
+			},
+		},
+	}
+
+	cfg := config.ResourceType{
+		ExcludeRule: config.FilterRule{
+			TimeAfter: aws.Time(now),
+		},
 	}
+
+	ids, err := listVPCLatticeServiceNetworks(context.Background(), mock, resource.Scope{}, cfg)
+	require.NoError(t, err)
+	require.Equal(t, []string{id1}, aws.ToStringSlice(ids))
 }
 
-func TestVPCLatticeServiceNetwork__NukeAll(t *testing.T) {
+func TestDeleteVPCLatticeServiceNetwork(t *testing.T) {
 	t.Parallel()
 
-	obj := resources.VPCLatticeServiceNetwork{
-		Client: mockedVPCLatticeServiceNetwork{
-			ListServiceNetworksOutput: vpclattice.ListServiceNetworksOutput{},
+	mock := &mockVPCLatticeServiceNetworkClient{
+		ListServiceNetworkServiceAssociationsOutput: vpclattice.ListServiceNetworkServiceAssociationsOutput{
+			Items: []types.ServiceNetworkServiceAssociationSummary{},
 		},
 	}
-	err := obj.Nuke(context.TODO(), []string{"test"})
+
+	err := deleteVPCLatticeServiceNetwork(context.Background(), mock, aws.String("test-arn"))
 	require.NoError(t, err)
 }
diff --git a/aws/resources/vpc_lattice_service_network_types.go b/aws/resources/vpc_lattice_service_network_types.go
deleted file mode 100644
index 1bac6d1e..00000000
--- a/aws/resources/vpc_lattice_service_network_types.go
+++ /dev/null
@@ -1,69 +0,0 @@
-package resources
-
-import (
-	"context"
-
-	"github.com/aws/aws-sdk-go-v2/aws"
-	"github.com/aws/aws-sdk-go-v2/service/vpclattice"
-	"github.com/gruntwork-io/cloud-nuke/config"
-	"github.com/gruntwork-io/go-commons/errors"
-)
-
-type VPCLatticeServiceNetworkAPI interface {
-	ListServiceNetworks(ctx context.Context, params *vpclattice.ListServiceNetworksInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListServiceNetworksOutput, error)
-	DeleteServiceNetwork(ctx context.Context, params *vpclattice.DeleteServiceNetworkInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteServiceNetworkOutput, error)
-	ListServiceNetworkServiceAssociations(ctx context.Context, params *vpclattice.ListServiceNetworkServiceAssociationsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListServiceNetworkServiceAssociationsOutput, error)
-	DeleteServiceNetworkServiceAssociation(ctx context.Context, params *vpclattice.DeleteServiceNetworkServiceAssociationInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteServiceNetworkServiceAssociationOutput, error)
-}
-
-type VPCLatticeServiceNetwork struct {
-	BaseAwsResource
-	Client VPCLatticeServiceNetworkAPI
-	Region string
-	ARNs   []string
-}
-
-func (sch *VPCLatticeServiceNetwork) Init(cfg aws.Config) {
-	sch.Client = vpclattice.NewFromConfig(cfg)
-}
-
-// ResourceName - the simple name of the aws resource
-func (n *VPCLatticeServiceNetwork) ResourceName() string {
-	return "vpc-lattice-service-network"
-}
-
-// ResourceIdentifiers - the arns of the aws certificate manager certificates
-func (n *VPCLatticeServiceNetwork) ResourceIdentifiers() []string {
-	return n.ARNs
-}
-
-func (n *VPCLatticeServiceNetwork) GetAndSetResourceConfig(configObj config.Config) config.ResourceType {
-	return configObj.VPCLatticeServiceNetwork
-}
-
-func (n *VPCLatticeServiceNetwork) ResourceServiceName() string {
-	return "VPC Lattice Service Network"
-}
-
-func (n *VPCLatticeServiceNetwork) MaxBatchSize() int {
-	return maxBatchSize
-}
-
-func (n *VPCLatticeServiceNetwork) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
-	identifiers, err := n.getAll(c, configObj)
-	if err != nil {
-		return nil, err
-	}
-
-	n.ARNs = aws.ToStringSlice(identifiers)
-	return n.ARNs, nil
-}
-
-// Nuke - nuke 'em all!!!
-func (n *VPCLatticeServiceNetwork) Nuke(ctx context.Context, arns []string) error {
-	if err := n.nukeAll(aws.StringSlice(arns)); err != nil {
-		return errors.WithStackTrace(err)
-	}
-
-	return nil
-}
diff --git a/aws/resources/vpc_lattice_target_group.go b/aws/resources/vpc_lattice_target_group.go
index cd19c7fd..a0f4a93e 100644
--- a/aws/resources/vpc_lattice_target_group.go
+++ b/aws/resources/vpc_lattice_target_group.go
@@ -8,40 +8,67 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/vpclattice/types"
 	"github.com/gruntwork-io/cloud-nuke/config"
 	"github.com/gruntwork-io/cloud-nuke/logging"
-	"github.com/gruntwork-io/cloud-nuke/report"
-	"github.com/gruntwork-io/go-commons/errors"
+	"github.com/gruntwork-io/cloud-nuke/resource"
 )
 
-func (network *VPCLatticeTargetGroup) getAll(_ context.Context, configObj config.Config) ([]*string, error) {
-	output, err := network.Client.ListTargetGroups(network.Context, nil)
+// VPCLatticeTargetGroupAPI defines the interface for VPC Lattice Target Group operations.
+type VPCLatticeTargetGroupAPI interface {
+	ListTargetGroups(ctx context.Context, params *vpclattice.ListTargetGroupsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListTargetGroupsOutput, error)
+	ListTargets(ctx context.Context, params *vpclattice.ListTargetsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListTargetsOutput, error)
+	DeregisterTargets(ctx context.Context, params *vpclattice.DeregisterTargetsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeregisterTargetsOutput, error)
+	DeleteTargetGroup(ctx context.Context, params *vpclattice.DeleteTargetGroupInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteTargetGroupOutput, error)
+}
+
+// NewVPCLatticeTargetGroup creates a new VPC Lattice Target Group resource using the generic resource pattern.
+func NewVPCLatticeTargetGroup() AwsResource {
+	return NewAwsResource(&resource.Resource[VPCLatticeTargetGroupAPI]{
+		ResourceTypeName: "vpc-lattice-target-group",
+		BatchSize:        maxBatchSize,
+		InitClient: func(r *resource.Resource[VPCLatticeTargetGroupAPI], cfg any) {
+			awsCfg, ok := cfg.(aws.Config)
+			if !ok {
+				logging.Debugf("Invalid config type for VPC Lattice client: expected aws.Config")
+				return
+			}
+			r.Scope.Region = awsCfg.Region
+			r.Client = vpclattice.NewFromConfig(awsCfg)
+		},
+		ConfigGetter: func(c config.Config) config.ResourceType {
+			return c.VPCLatticeTargetGroup
+		},
+		Lister: listVPCLatticeTargetGroups,
+		Nuker:  resource.SequentialDeleter(deleteVPCLatticeTargetGroup),
+	})
+}
+
+// listVPCLatticeTargetGroups retrieves all VPC Lattice Target Groups that match the config filters.
+func listVPCLatticeTargetGroups(ctx context.Context, client VPCLatticeTargetGroupAPI, scope resource.Scope, cfg config.ResourceType) ([]*string, error) {
+	output, err := client.ListTargetGroups(ctx, nil)
 	if err != nil {
-		return nil, errors.WithStackTrace(err)
+		return nil, err
 	}
 
 	var ids []*string
 	for _, item := range output.Items {
-
-		if configObj.VPCLatticeTargetGroup.ShouldInclude(config.ResourceValue{
+		if cfg.ShouldInclude(config.ResourceValue{
 			Name: item.Name,
 			Time: item.CreatedAt,
 		}) {
 			ids = append(ids, item.Arn)
-			// also keep the complete info about the target groups as the target group assoiation needs to be nuked before removing it
-			network.TargetGroups[aws.ToString(item.Arn)] = &item
 		}
 	}
 
 	return ids, nil
 }
 
-func (network *VPCLatticeTargetGroup) nukeTargets(identifier *string) error {
-	// list the targets associated on the target group
-	output, err := network.Client.ListTargets(network.Context, &vpclattice.ListTargetsInput{
+// nukeVPCLatticeTargets deregisters all targets from a target group.
+func nukeVPCLatticeTargets(ctx context.Context, client VPCLatticeTargetGroupAPI, identifier *string) error {
+	output, err := client.ListTargets(ctx, &vpclattice.ListTargetsInput{
 		TargetGroupIdentifier: identifier,
 	})
 	if err != nil {
-		logging.Debugf("[ListTargetsWithContext Failed] %s", err)
-		return errors.WithStackTrace(err)
+		logging.Debugf("[ListTargets Failed] %s", err)
+		return err
 	}
 
 	var targets []types.Target
@@ -52,68 +79,29 @@ func (network *VPCLatticeTargetGroup) nukeTargets(identifier *string) error {
 	}
 
 	if len(targets) > 0 {
-		// before deleting the targets, we need to deregister the targets registered with it
-		_, err = network.Client.DeregisterTargets(network.Context, &vpclattice.DeregisterTargetsInput{
+		_, err = client.DeregisterTargets(ctx, &vpclattice.DeregisterTargetsInput{
 			TargetGroupIdentifier: identifier,
 			Targets:               targets,
 		})
 		if err != nil {
-			logging.Debugf("[DeregisterTargetsWithContext Failed] %s", err)
-			return errors.WithStackTrace(err)
+			logging.Debugf("[DeregisterTargets Failed] %s", err)
+			return err
 		}
 	}
 
 	return nil
 }
 
-func (network *VPCLatticeTargetGroup) nuke(identifier *string) error {
-
-	var err error
-	err = network.nukeTargets(identifier)
-	if err != nil {
-		return errors.WithStackTrace(err)
+// deleteVPCLatticeTargetGroup deletes a single VPC Lattice Target Group after deregistering its targets.
+func deleteVPCLatticeTargetGroup(ctx context.Context, client VPCLatticeTargetGroupAPI, identifier *string) error {
+	// First deregister all targets
+	if err := nukeVPCLatticeTargets(ctx, client, identifier); err != nil {
+		return err
 	}
 
-	// delete the target group
-	_, err = network.Client.DeleteTargetGroup(network.Context, &vpclattice.DeleteTargetGroupInput{
+	// Then delete the target group
+	_, err := client.DeleteTargetGroup(ctx, &vpclattice.DeleteTargetGroupInput{
 		TargetGroupIdentifier: identifier,
 	})
-	if err != nil {
-		return errors.WithStackTrace(err)
-	}
-
-	return nil
-}
-func (network *VPCLatticeTargetGroup) nukeAll(identifiers []*string) error {
-	if len(identifiers) == 0 {
-		logging.Debugf("No %s to nuke in region %s", network.ResourceServiceName(), network.Region)
-		return nil
-
-	}
-
-	logging.Debugf("Deleting all %s in region %s", network.ResourceServiceName(), network.Region)
-
-	deletedCount := 0
-	for _, id := range identifiers {
-
-		err := network.nuke(id)
-
-		// Record status of this resource
-		e := report.Entry{
-			Identifier:   aws.ToString(id),
-			ResourceType: network.ResourceServiceName(),
-			Error:        err,
-		}
-		report.Record(e)
-
-		if err != nil {
-			logging.Debugf("[Failed] %s", err)
-		} else {
-			deletedCount++
-			logging.Debugf("Deleted %s: %s", network.ResourceServiceName(), aws.ToString(id))
-		}
-	}
-
-	logging.Debugf("[OK] %d %s(s) terminated in %s", deletedCount, network.ResourceServiceName(), network.Region)
-	return nil
+	return err
 }
diff --git a/aws/resources/vpc_lattice_target_group_test.go b/aws/resources/vpc_lattice_target_group_test.go
index 98394b73..004b5f85 100644
--- a/aws/resources/vpc_lattice_target_group_test.go
+++ b/aws/resources/vpc_lattice_target_group_test.go
@@ -1 +1,109 @@
-package resources_test
+package resources
+
+import (
+	"context"
+	"regexp"
+	"testing"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/service/vpclattice"
+	"github.com/aws/aws-sdk-go-v2/service/vpclattice/types"
+	"github.com/gruntwork-io/cloud-nuke/config"
+	"github.com/gruntwork-io/cloud-nuke/resource"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+type mockVPCLatticeTargetGroupClient struct {
+	ListTargetGroupsOutput  vpclattice.ListTargetGroupsOutput
+	ListTargetsOutput       vpclattice.ListTargetsOutput
+	DeregisterTargetsOutput vpclattice.DeregisterTargetsOutput
+	DeleteTargetGroupOutput vpclattice.DeleteTargetGroupOutput
+}
+
+func (m *mockVPCLatticeTargetGroupClient) ListTargetGroups(ctx context.Context, params *vpclattice.ListTargetGroupsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListTargetGroupsOutput, error) {
+	return &m.ListTargetGroupsOutput, nil
+}
+
+func (m *mockVPCLatticeTargetGroupClient) ListTargets(ctx context.Context, params *vpclattice.ListTargetsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListTargetsOutput, error) {
+	return &m.ListTargetsOutput, nil
+}
+
+func (m *mockVPCLatticeTargetGroupClient) DeregisterTargets(ctx context.Context, params *vpclattice.DeregisterTargetsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeregisterTargetsOutput, error) {
+	return &m.DeregisterTargetsOutput, nil
+}
+
+func (m *mockVPCLatticeTargetGroupClient) DeleteTargetGroup(ctx context.Context, params *vpclattice.DeleteTargetGroupInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteTargetGroupOutput, error) {
+	return &m.DeleteTargetGroupOutput, nil
+}
+
+func TestVPCLatticeTargetGroup_ResourceName(t *testing.T) {
+	r := NewVPCLatticeTargetGroup()
assert.Equal(t, "vpc-lattice-target-group", r.ResourceName()) +} + +func TestVPCLatticeTargetGroup_MaxBatchSize(t *testing.T) { + r := NewVPCLatticeTargetGroup() + assert.Equal(t, 49, r.MaxBatchSize()) +} + +func TestListVPCLatticeTargetGroups(t *testing.T) { + t.Parallel() + + now := time.Now() + mock := &mockVPCLatticeTargetGroupClient{ + ListTargetGroupsOutput: vpclattice.ListTargetGroupsOutput{ + Items: []types.TargetGroupSummary{ + {Arn: aws.String("arn:aws:vpc-lattice:us-east-1:123456789012:targetgroup/tg-1"), Name: aws.String("tg-1"), CreatedAt: aws.Time(now)}, + {Arn: aws.String("arn:aws:vpc-lattice:us-east-1:123456789012:targetgroup/tg-2"), Name: aws.String("tg-2"), CreatedAt: aws.Time(now)}, + }, + }, + } + + ids, err := listVPCLatticeTargetGroups(context.Background(), mock, resource.Scope{}, config.ResourceType{}) + require.NoError(t, err) + require.ElementsMatch(t, []string{ + "arn:aws:vpc-lattice:us-east-1:123456789012:targetgroup/tg-1", + "arn:aws:vpc-lattice:us-east-1:123456789012:targetgroup/tg-2", + }, aws.ToStringSlice(ids)) +} + +func TestListVPCLatticeTargetGroups_WithFilter(t *testing.T) { + t.Parallel() + + now := time.Now() + mock := &mockVPCLatticeTargetGroupClient{ + ListTargetGroupsOutput: vpclattice.ListTargetGroupsOutput{ + Items: []types.TargetGroupSummary{ + {Arn: aws.String("arn:aws:vpc-lattice:us-east-1:123456789012:targetgroup/tg-1"), Name: aws.String("tg-1"), CreatedAt: aws.Time(now)}, + {Arn: aws.String("arn:aws:vpc-lattice:us-east-1:123456789012:targetgroup/skip-tg"), Name: aws.String("skip-tg"), CreatedAt: aws.Time(now)}, + }, + }, + } + + cfg := config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{RE: *regexp.MustCompile("skip-.*")}}, + }, + } + + ids, err := listVPCLatticeTargetGroups(context.Background(), mock, resource.Scope{}, cfg) + require.NoError(t, err) + require.Equal(t, []string{"arn:aws:vpc-lattice:us-east-1:123456789012:targetgroup/tg-1"}, aws.ToStringSlice(ids)) +} + +func TestDeleteVPCLatticeTargetGroup(t *testing.T) { + t.Parallel() + + mock := &mockVPCLatticeTargetGroupClient{ + ListTargetsOutput: vpclattice.ListTargetsOutput{ + Items: []types.TargetSummary{ + {Id: aws.String("target-1")}, + }, + }, + } + + err := deleteVPCLatticeTargetGroup(context.Background(), mock, aws.String("arn:aws:vpc-lattice:us-east-1:123456789012:targetgroup/tg-1")) + require.NoError(t, err) +} diff --git a/aws/resources/vpc_lattice_target_group_types.go b/aws/resources/vpc_lattice_target_group_types.go deleted file mode 100644 index 92fa7088..00000000 --- a/aws/resources/vpc_lattice_target_group_types.go +++ /dev/null @@ -1,72 +0,0 @@ -package resources - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/vpclattice" - "github.com/aws/aws-sdk-go-v2/service/vpclattice/types" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/go-commons/errors" -) - -type VPCLatticeAPI interface { - ListTargetGroups(ctx context.Context, params *vpclattice.ListTargetGroupsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListTargetGroupsOutput, error) - ListTargets(ctx context.Context, params *vpclattice.ListTargetsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.ListTargetsOutput, error) - DeregisterTargets(ctx context.Context, params *vpclattice.DeregisterTargetsInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeregisterTargetsOutput, error) - DeleteTargetGroup(ctx context.Context, params *vpclattice.DeleteTargetGroupInput, optFns 
-	DeleteTargetGroup(ctx context.Context, params *vpclattice.DeleteTargetGroupInput, optFns ...func(*vpclattice.Options)) (*vpclattice.DeleteTargetGroupOutput, error)
-}
-
-type VPCLatticeTargetGroup struct {
-	BaseAwsResource
-	Client       VPCLatticeAPI
-	Region       string
-	ARNs         []string
-	TargetGroups map[string]*types.TargetGroupSummary
-}
-
-func (sch *VPCLatticeTargetGroup) Init(cfg aws.Config) {
-	sch.Client = vpclattice.NewFromConfig(cfg)
-	sch.TargetGroups = make(map[string]*types.TargetGroupSummary, 0)
-}
-
-// ResourceName - the simple name of the aws resource
-func (n *VPCLatticeTargetGroup) ResourceName() string {
-	return "vpc-lattice-target-group"
-}
-
-// ResourceIdentifiers - the arns of the aws certificate manager certificates
-func (n *VPCLatticeTargetGroup) ResourceIdentifiers() []string {
-	return n.ARNs
-}
-
-func (n *VPCLatticeTargetGroup) ResourceServiceName() string {
-	return "VPC Lattice Target Group"
-}
-
-func (n *VPCLatticeTargetGroup) MaxBatchSize() int {
-	return maxBatchSize
-}
-
-func (n *VPCLatticeTargetGroup) GetAndSetResourceConfig(configObj config.Config) config.ResourceType {
-	return configObj.VPCLatticeTargetGroup
-}
-
-func (n *VPCLatticeTargetGroup) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
-	identifiers, err := n.getAll(c, configObj)
-	if err != nil {
-		return nil, err
-	}
-
-	n.ARNs = aws.ToStringSlice(identifiers)
-	return n.ARNs, nil
-}
-
-// Nuke - nuke 'em all!!!
-func (n *VPCLatticeTargetGroup) Nuke(ctx context.Context, arns []string) error {
-	if err := n.nukeAll(aws.StringSlice(arns)); err != nil {
-		return errors.WithStackTrace(err)
-	}
-
-	return nil
-}
diff --git a/commands/cli_test.go b/commands/cli_test.go
index 1cf82ac9..5d6ee4ee 100644
--- a/commands/cli_test.go
+++ b/commands/cli_test.go
@@ -106,18 +106,18 @@ func TestStructuredErrors(t *testing.T) {
 func TestListResourceTypes(t *testing.T) {
 	allAWSResourceTypes := aws.ListResourceTypes()
 	assert.Greater(t, len(allAWSResourceTypes), 0)
-	assert.Contains(t, allAWSResourceTypes, (&resources.EC2Instances{}).ResourceName())
+	assert.Contains(t, allAWSResourceTypes, resources.NewEC2Instances().ResourceName())
 }
 
 func TestIsValidResourceType(t *testing.T) {
 	allAWSResourceTypes := aws.ListResourceTypes()
-	ec2ResourceName := (*&resources.EC2Instances{}).ResourceName()
+	ec2ResourceName := resources.NewEC2Instances().ResourceName()
 	assert.Equal(t, aws.IsValidResourceType(ec2ResourceName, allAWSResourceTypes), true)
 	assert.Equal(t, aws.IsValidResourceType("xyz", allAWSResourceTypes), false)
 }
 
 func TestIsNukeable(t *testing.T) {
-	ec2ResourceName := (&resources.EC2Instances{}).ResourceName()
+	ec2ResourceName := resources.NewEC2Instances().ResourceName()
 	amiResourceName := resources.NewAMIs().ResourceName()
 	assert.Equal(t, aws.IsNukeable(ec2ResourceName, []string{ec2ResourceName}), true)
diff --git a/resource/batch_deleter.go b/resource/batch_deleter.go
index 535a0f48..cde93cd1 100644
--- a/resource/batch_deleter.go
+++ b/resource/batch_deleter.go
@@ -176,6 +176,24 @@ func BulkDeleter[C any](deleteFn BulkDeleteFunc[C]) NukerFunc[C] {
 	}
 }
 
+// DeleteThenWait combines a delete function with a wait function into a single DeleteFunc.
+// Use this with SequentialDeleter for resources that need to wait for deletion to complete.
+//
+// Example:
+//
+//	Nuker: resource.SequentialDeleter(resource.DeleteThenWait(
+//	    deleteCluster,
+//	    waitForClusterDeleted,
+//	))
+func DeleteThenWait[C any](deleteFn DeleteFunc[C], waitFn DeleteFunc[C]) DeleteFunc[C] {
+	return func(ctx context.Context, client C, id *string) error {
+		if err := deleteFn(ctx, client, id); err != nil {
+			return err
+		}
+		return waitFn(ctx, client, id)
+	}
+}
+
 // MultiStepDeleter creates a nuker that executes multiple steps per resource in sequence.
 // Use this for resources that require cleanup before deletion (e.g., detach policies, empty bucket).
 // Each resource is processed sequentially, but if any step fails for a resource, it moves to the next resource.
@@ -216,3 +234,170 @@ func MultiStepDeleter[C any](steps ...DeleteFunc[C]) NukerFunc[C] {
 		return allErrs.ErrorOrNil()
 	}
 }
+
+// WaitAllFunc is a function that waits for multiple resources to be deleted.
+// Used with SequentialDeleteThenWaitAll for batch waiting after all deletes complete.
+type WaitAllFunc[C any] func(ctx context.Context, client C, ids []string) error
+
+// SequentialDeleteThenWaitAll creates a nuker that:
+// 1. Deletes all resources sequentially, recording each result
+// 2. Waits for ALL successfully deleted resources to be confirmed deleted
+//
+// Use this for resources where the delete API returns immediately but the resource
+// takes time to be fully deleted, and the wait API can check multiple resources at once.
+//
+// Example:
+//
+//	Nuker: resource.SequentialDeleteThenWaitAll(
+//	    deleteASG,
+//	    waitForASGsDeleted, // Uses autoscaling.NewGroupNotExistsWaiter
+//	)
+func SequentialDeleteThenWaitAll[C any](deleteFn DeleteFunc[C], waitAllFn WaitAllFunc[C]) NukerFunc[C] {
+	return func(ctx context.Context, client C, scope Scope, resourceType string, identifiers []*string) error {
+		if logEmptyAndSkip(identifiers, resourceType, scope) {
+			return nil
+		}
+
+		if len(identifiers) > MaxBatchSizeLimit {
+			logging.Errorf("Nuking too many %s at once (%d): halting to avoid hitting rate limiting",
+				resourceType, len(identifiers))
+			return fmt.Errorf("too many %s requested at once (%d > %d limit)", resourceType, len(identifiers), MaxBatchSizeLimit)
+		}
+
+		logDeletionStart(len(identifiers), resourceType, scope)
+
+		var allErrs *multierror.Error
+		var deletedIds []string
+
+		// Phase 1: Delete all resources sequentially
+		for _, id := range identifiers {
+			idStr := aws.ToString(id)
+			err := deleteFn(ctx, client, id)
+
+			if err != nil {
+				logging.Errorf("[Failed] %s %s: %s", resourceType, idStr, err)
+				allErrs = multierror.Append(allErrs, fmt.Errorf("%s %s: %w", resourceType, idStr, err))
+				report.Record(report.Entry{
+					Identifier:   idStr,
+					ResourceType: resourceType,
+					Error:        err,
+				})
+			} else {
+				deletedIds = append(deletedIds, idStr)
+				logging.Debugf("[Deleted] %s: %s (waiting for confirmation)", resourceType, idStr)
+			}
+		}
+
+		// Phase 2: Wait for all successfully deleted resources
+		if len(deletedIds) > 0 {
+			waitErr := waitAllFn(ctx, client, deletedIds)
+			// Record results for all deleted resources
+			for _, idStr := range deletedIds {
+				report.Record(report.Entry{
+					Identifier:   idStr,
+					ResourceType: resourceType,
+					Error:        waitErr,
+				})
+				if waitErr != nil {
+					logging.Errorf("[Failed] %s %s wait: %s", resourceType, idStr, waitErr)
+				} else {
+					logging.Debugf("[OK] Deleted %s: %s", resourceType, idStr)
+				}
+			}
+			if waitErr != nil {
+				allErrs = multierror.Append(allErrs, fmt.Errorf("wait for %s deletion: %w", resourceType, waitErr))
+			}
+		}
+
+		return allErrs.ErrorOrNil()
+	}
+}
+
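For concreteness, the waitForASGsDeleted placeholder named in the example above could be backed by the SDK's built-in waiter, which can confirm several deleted groups in a single call. The following is an illustrative sketch only, not part of this change: the function name comes from the doc comment, and it assumes aws-sdk-go-v2's autoscaling package with *autoscaling.Client as the generic client type C and a fixed five-minute wait budget.

// Illustrative sketch (assumed wiring, not part of this diff): a WaitAllFunc for
// Auto Scaling Groups built on the SDK's GroupNotExistsWaiter.
package resources_sketch

import (
	"context"
	"time"

	"github.com/aws/aws-sdk-go-v2/service/autoscaling"
)

// waitForASGsDeleted waits until none of the named Auto Scaling Groups exist anymore.
func waitForASGsDeleted(ctx context.Context, client *autoscaling.Client, ids []string) error {
	waiter := autoscaling.NewGroupNotExistsWaiter(client)
	// One DescribeAutoScalingGroups-based wait covers every deleted group at once.
	return waiter.Wait(ctx, &autoscaling.DescribeAutoScalingGroupsInput{
		AutoScalingGroupNames: ids,
	}, 5*time.Minute)
}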
+// ConcurrentDeleteThenWaitAll creates a nuker that:
+// 1. Deletes all resources concurrently with controlled parallelism
+// 2. Waits for ALL successfully deleted resources to be confirmed deleted
+//
+// Use this for resources where concurrent deletion is safe and the wait API
+// can check multiple resources at once.
+//
+// Example:
+//
+//	Nuker: resource.ConcurrentDeleteThenWaitAll(
+//	    deleteOpenSearchDomain,
+//	    waitForOpenSearchDomainsDeleted,
+//	)
+func ConcurrentDeleteThenWaitAll[C any](deleteFn DeleteFunc[C], waitAllFn WaitAllFunc[C]) NukerFunc[C] {
+	return func(ctx context.Context, client C, scope Scope, resourceType string, identifiers []*string) error {
+		if logEmptyAndSkip(identifiers, resourceType, scope) {
+			return nil
+		}
+
+		if len(identifiers) > MaxBatchSizeLimit {
+			logging.Errorf("Nuking too many %s at once (%d): halting to avoid hitting rate limiting",
+				resourceType, len(identifiers))
+			return fmt.Errorf("too many %s requested at once (%d > %d limit)", resourceType, len(identifiers), MaxBatchSizeLimit)
+		}
+
+		logDeletionStart(len(identifiers), resourceType, scope)
+
+		// Phase 1: Delete all resources concurrently
+		sem := make(chan struct{}, DefaultMaxConcurrent)
+		var wg sync.WaitGroup
+		var mu sync.Mutex
+		var allErrs *multierror.Error
+		var deletedIds []string
+
+		for _, id := range identifiers {
+			wg.Add(1)
+			sem <- struct{}{}
+
+			go func(identifier *string) {
+				defer wg.Done()
+				defer func() { <-sem }()
+
+				idStr := aws.ToString(identifier)
+				err := deleteFn(ctx, client, identifier)
+
+				mu.Lock()
+				if err != nil {
+					logging.Errorf("[Failed] %s %s: %s", resourceType, idStr, err)
+					allErrs = multierror.Append(allErrs, fmt.Errorf("%s %s: %w", resourceType, idStr, err))
+					report.Record(report.Entry{
+						Identifier:   idStr,
+						ResourceType: resourceType,
+						Error:        err,
+					})
+				} else {
+					deletedIds = append(deletedIds, idStr)
+					logging.Debugf("[Deleted] %s: %s (waiting for confirmation)", resourceType, idStr)
+				}
+				mu.Unlock()
+			}(id)
+		}
+
+		wg.Wait()
+
+		// Phase 2: Wait for all successfully deleted resources
+		if len(deletedIds) > 0 {
+			waitErr := waitAllFn(ctx, client, deletedIds)
+			// Record results for all deleted resources
+			for _, idStr := range deletedIds {
+				report.Record(report.Entry{
+					Identifier:   idStr,
+					ResourceType: resourceType,
+					Error:        waitErr,
+				})
+				if waitErr != nil {
+					logging.Errorf("[Failed] %s %s wait: %s", resourceType, idStr, waitErr)
+				} else {
+					logging.Debugf("[OK] Deleted %s: %s", resourceType, idStr)
+				}
+			}
+			if waitErr != nil {
+				allErrs = multierror.Append(allErrs, fmt.Errorf("wait for %s deletion: %w", resourceType, waitErr))
+			}
+		}
+
+		return allErrs.ErrorOrNil()
+	}
+}
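To show how ConcurrentDeleteThenWaitAll slots into the generic resource pattern used elsewhere in this change (compare NewVPCLatticeTargetGroup above), here is a hedged sketch. The package name, the "opensearch-domain" type name, the c.OpenSearchDomain config field, and the listOpenSearchDomains, deleteOpenSearchDomain, and waitForOpenSearchDomainsDeleted helpers are placeholders for illustration; they are not the repository's actual OpenSearch wiring.

// Illustrative sketch (assumed wiring, not part of this diff): registering a resource
// whose Nuker deletes concurrently and then waits for every deletion to be confirmed.
package resources_sketch

import (
	"context"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/opensearch"
	"github.com/gruntwork-io/cloud-nuke/config"
	"github.com/gruntwork-io/cloud-nuke/logging"
	"github.com/gruntwork-io/cloud-nuke/resource"
)

// newOpenSearchDomainsSketch mirrors the factory shape shown in this change.
func newOpenSearchDomainsSketch() AwsResource {
	return NewAwsResource(&resource.Resource[*opensearch.Client]{
		ResourceTypeName: "opensearch-domain", // placeholder type name
		BatchSize:        10,
		InitClient: func(r *resource.Resource[*opensearch.Client], cfg any) {
			awsCfg, ok := cfg.(aws.Config)
			if !ok {
				logging.Debugf("Invalid config type for OpenSearch client: expected aws.Config")
				return
			}
			r.Scope.Region = awsCfg.Region
			r.Client = opensearch.NewFromConfig(awsCfg)
		},
		ConfigGetter: func(c config.Config) config.ResourceType {
			return c.OpenSearchDomain // placeholder config field
		},
		Lister: listOpenSearchDomains,
		// Deletes run concurrently; a single batched wait confirms all of them afterwards.
		Nuker: resource.ConcurrentDeleteThenWaitAll(deleteOpenSearchDomain, waitForOpenSearchDomainsDeleted),
	})
}

// Placeholder lister: a real one would call ListDomainNames and apply cfg filters.
func listOpenSearchDomains(ctx context.Context, client *opensearch.Client, scope resource.Scope, cfg config.ResourceType) ([]*string, error) {
	return nil, nil
}

// Placeholder deleter: a real one would call DeleteDomain for the given identifier.
func deleteOpenSearchDomain(ctx context.Context, client *opensearch.Client, id *string) error {
	return nil
}

// Placeholder waiter: a real one would poll DescribeDomains until every id is gone.
func waitForOpenSearchDomainsDeleted(ctx context.Context, client *opensearch.Client, ids []string) error {
	return nil
}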