diff --git a/cmd/yace/main.go b/cmd/yace/main.go index 17837bd50..4ff3c85a2 100644 --- a/cmd/yace/main.go +++ b/cmd/yace/main.go @@ -29,9 +29,8 @@ import ( "golang.org/x/sync/semaphore" exporter "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg" + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients" "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" - v1 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/v1" - v2 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/v2" "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/config" ) @@ -267,20 +266,11 @@ func startScraper(c *cli.Context) error { featureFlags := c.StringSlice(enableFeatureFlag) s := NewScraper(featureFlags) - var cache cachingFactory - cache, err = v2.NewFactory(logger, jobsCfg, fips) + cache, err := clients.NewFactory(logger, jobsCfg, fips) if err != nil { return fmt.Errorf("failed to construct aws sdk v2 client cache: %w", err) } - // Switch to v1 SDK if feature flag is enabled - for _, featureFlag := range featureFlags { - if featureFlag == config.AwsSdkV1 { - cache = v1.NewFactory(logger, jobsCfg, fips) - logger.Info("Using aws sdk v1") - } - } - ctx, cancelRunningScrape := context.WithCancel(context.Background()) go s.decoupled(ctx, logger, jobsCfg, cache) @@ -325,21 +315,12 @@ func startScraper(c *cli.Context) error { } logger.Info("Reset clients cache") - var cache cachingFactory - cache, err = v2.NewFactory(logger, newJobsCfg, fips) + cache, err := clients.NewFactory(logger, newJobsCfg, fips) if err != nil { logger.Error("Failed to construct aws sdk v2 client cache", "err", err, "path", configFile) return } - // Switch to v1 SDK if feature flag is enabled - for _, featureFlag := range featureFlags { - if featureFlag == config.AwsSdkV1 { - cache = v1.NewFactory(logger, newJobsCfg, fips) - logger.Info("Using aws sdk v1") - } - } - 
cancelRunningScrape() ctx, cancelRunningScrape = context.WithCancel(context.Background()) go s.decoupled(ctx, logger, newJobsCfg, cache) diff --git a/docs/feature_flags.md b/docs/feature_flags.md index d0a09d416..b1fcf2427 100644 --- a/docs/feature_flags.md +++ b/docs/feature_flags.md @@ -4,12 +4,6 @@ List of features or changes that are disabled by default since they are breaking You can enable them using the `-enable-feature` flag with a comma separated list of features. They may be enabled by default in future versions. -## AWS SDK v1 - -`-enable-feature=aws-sdk-v1` - -Uses the v1 version of the aws sdk for go for backward compatibility. By default, YACE now uses AWS SDK v2 which was released in Jan 2021 and comes with large performance gains. - ## Always return info metrics `-enable-feature=always-return-info-metrics` diff --git a/go.mod b/go.mod index 52086a38d..d5ba169ac 100644 --- a/go.mod +++ b/go.mod @@ -48,7 +48,6 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 31888bca7..2715ab77a 100644 --- a/go.sum +++ b/go.sum @@ -68,8 +68,6 @@ github.com/grafana/regexp v0.0.0-20240607082908-2cb410fa05da h1:BML5sNe+bw2uO8t8 github.com/grafana/regexp v0.0.0-20240607082908-2cb410fa05da/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod 
h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -129,7 +127,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/clients/account/client.go b/pkg/clients/account/client.go index 5e103b2d2..0293d9ce4 100644 --- a/pkg/clients/account/client.go +++ b/pkg/clients/account/client.go @@ -12,12 +12,54 @@ // limitations under the License. package account -import "context" +import ( + "context" + "errors" + "log/slog" -type Client interface { - // GetAccount returns the AWS account ID for the configured authenticated client. - GetAccount(ctx context.Context) (string, error) + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" +) - // GetAccountAlias returns the account alias if there's one set, otherwise an empty string. 
- GetAccountAlias(ctx context.Context) (string, error)
+type client struct {
+	logger    *slog.Logger
+	stsClient *sts.Client
+	iamClient *iam.Client
+}
+
+func NewClient(logger *slog.Logger, stsClient *sts.Client, iamClient *iam.Client) Client {
+	return &client{
+		logger:    logger,
+		stsClient: stsClient,
+		iamClient: iamClient,
+	}
+}
+
+func (c client) GetAccount(ctx context.Context) (string, error) {
+	result, err := c.stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
+	if err != nil {
+		return "", err
+	}
+	if result.Account == nil {
+		return "", errors.New("aws sts GetCallerIdentity returned no account")
+	}
+	return *result.Account, nil
+}
+
+func (c client) GetAccountAlias(ctx context.Context) (string, error) {
+	acctAliasOut, err := c.iamClient.ListAccountAliases(ctx, &iam.ListAccountAliasesInput{})
+	if err != nil {
+		return "", err
+	}
+
+	possibleAccountAlias := ""
+
+	// Since a single account can only have one alias, and an authenticated SDK session corresponds to a single account,
+	// the output can have at most one alias.
+	// https://docs.aws.amazon.com/IAM/latest/APIReference/API_ListAccountAliases.html
+	if len(acctAliasOut.AccountAliases) > 0 {
+		possibleAccountAlias = acctAliasOut.AccountAliases[0]
+	}
+
+	return possibleAccountAlias, nil
+}
diff --git a/pkg/clients/account/iface.go b/pkg/clients/account/iface.go
new file mode 100644
index 000000000..5e103b2d2
--- /dev/null
+++ b/pkg/clients/account/iface.go
@@ -0,0 +1,23 @@
+// Copyright 2024 The Prometheus Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. +package account + +import "context" + +type Client interface { + // GetAccount returns the AWS account ID for the configured authenticated client. + GetAccount(ctx context.Context) (string, error) + + // GetAccountAlias returns the account alias if there's one set, otherwise an empty string. + GetAccountAlias(ctx context.Context) (string, error) +} diff --git a/pkg/clients/account/v1/client.go b/pkg/clients/account/v1/client.go deleted file mode 100644 index 97c629ee7..000000000 --- a/pkg/clients/account/v1/client.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v1 - -import ( - "context" - "errors" - "log/slog" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/account" - - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/iam/iamiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" -) - -type client struct { - logger *slog.Logger - stsClient stsiface.STSAPI - iamClient iamiface.IAMAPI -} - -func NewClient(logger *slog.Logger, stsClient stsiface.STSAPI, iamClient iamiface.IAMAPI) account.Client { - return &client{ - logger: logger, - stsClient: stsClient, - iamClient: iamClient, - } -} - -func (c client) GetAccount(ctx context.Context) (string, error) { - result, err := c.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) - if err != nil { - return "", err - } - if result.Account == nil { - return "", errors.New("aws sts GetCallerIdentityWithContext returned no account") - } - return *result.Account, nil -} - -func (c client) GetAccountAlias(ctx context.Context) (string, error) { - acctAliasOut, err := c.iamClient.ListAccountAliasesWithContext(ctx, &iam.ListAccountAliasesInput{}) - if err != nil { - return "", err - } - - possibleAccountAlias := "" - - // Since a single account can only have one alias, and an authenticated SDK session corresponds to a single account, - // the output can have at most one alias. 
- // https://docs.aws.amazon.com/IAM/latest/APIReference/API_ListAccountAliases.html - if len(acctAliasOut.AccountAliases) > 0 { - possibleAccountAlias = *acctAliasOut.AccountAliases[0] - } - - return possibleAccountAlias, nil -} diff --git a/pkg/clients/account/v2/client.go b/pkg/clients/account/v2/client.go deleted file mode 100644 index 253204489..000000000 --- a/pkg/clients/account/v2/client.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v2 - -import ( - "context" - "errors" - "log/slog" - - "github.com/aws/aws-sdk-go-v2/service/iam" - "github.com/aws/aws-sdk-go-v2/service/sts" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/account" -) - -type client struct { - logger *slog.Logger - stsClient *sts.Client - iamClient *iam.Client -} - -func NewClient(logger *slog.Logger, stsClient *sts.Client, iamClient *iam.Client) account.Client { - return &client{ - logger: logger, - stsClient: stsClient, - iamClient: iamClient, - } -} - -func (c client) GetAccount(ctx context.Context) (string, error) { - result, err := c.stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) - if err != nil { - return "", err - } - if result.Account == nil { - return "", errors.New("aws sts GetCallerIdentityWithContext returned no account") - } - return *result.Account, nil -} - -func (c client) GetAccountAlias(ctx context.Context) (string, error) { - acctAliasOut, err := c.iamClient.ListAccountAliases(ctx, &iam.ListAccountAliasesInput{}) - if err != nil { - return "", err - } - - possibleAccountAlias := "" - - // Since a single account can only have one alias, and an authenticated SDK session corresponds to a single account, - // the output can have at most one alias. 
- // https://docs.aws.amazon.com/IAM/latest/APIReference/API_ListAccountAliases.html - if len(acctAliasOut.AccountAliases) > 0 { - possibleAccountAlias = acctAliasOut.AccountAliases[0] - } - - return possibleAccountAlias, nil -} diff --git a/pkg/clients/cloudwatch/client.go b/pkg/clients/cloudwatch/client.go index 9d98ecf83..53eeba020 100644 --- a/pkg/clients/cloudwatch/client.go +++ b/pkg/clients/cloudwatch/client.go @@ -17,82 +17,200 @@ import ( "log/slog" "time" + "github.com/aws/aws-sdk-go-v2/aws" + aws_cloudwatch "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + "github.com/aws/aws-sdk-go-v2/service/cloudwatch/types" + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/promutil" ) -const ( - listMetricsCall = "ListMetrics" - getMetricDataCall = "GetMetricData" - getMetricStatisticsCall = "GetMetricStatistics" -) +type client struct { + logger *slog.Logger + cloudwatchAPI *aws_cloudwatch.Client +} -type Client interface { - // ListMetrics returns the list of metrics and dimensions for a given namespace - // and metric name. Results pagination is handled automatically: the caller can - // optionally pass a non-nil func in order to handle results pages. - ListMetrics(ctx context.Context, namespace string, metric *model.MetricConfig, recentlyActiveOnly bool, fn func(page []*model.Metric)) error +func NewClient(logger *slog.Logger, cloudwatchAPI *aws_cloudwatch.Client) Client { + return &client{ + logger: logger, + cloudwatchAPI: cloudwatchAPI, + } +} - // GetMetricData returns the output of the GetMetricData CloudWatch API. - // Results pagination is handled automatically. 
- GetMetricData(ctx context.Context, getMetricData []*model.CloudwatchData, namespace string, startTime time.Time, endTime time.Time) []MetricDataResult +func (c client) ListMetrics(ctx context.Context, namespace string, metric *model.MetricConfig, recentlyActiveOnly bool, fn func(page []*model.Metric)) error { + filter := &aws_cloudwatch.ListMetricsInput{ + MetricName: aws.String(metric.Name), + Namespace: aws.String(namespace), + } + if recentlyActiveOnly { + filter.RecentlyActive = types.RecentlyActivePt3h + } - // GetMetricStatistics returns the output of the GetMetricStatistics CloudWatch API. - GetMetricStatistics(ctx context.Context, logger *slog.Logger, dimensions []model.Dimension, namespace string, metric *model.MetricConfig) []*model.MetricStatisticsResult -} + c.logger.Debug("ListMetrics", "input", filter) -// ConcurrencyLimiter limits the concurrency when calling AWS CloudWatch APIs. The functions implemented -// by this interface follow the same as a normal semaphore, but accept and operation identifier. Some -// implementations might use this to keep a different semaphore, with different reentrance values, per -// operation. -type ConcurrencyLimiter interface { - // Acquire takes one "ticket" from the concurrency limiter for op. If there's none available, the caller - // routine will be blocked until there's room available. - Acquire(op string) - - // Release gives back one "ticket" to the concurrency limiter identified by op. If there's one or more - // routines waiting for one, one will be woken up. 
- Release(op string) -} + paginator := aws_cloudwatch.NewListMetricsPaginator(c.cloudwatchAPI, filter, func(options *aws_cloudwatch.ListMetricsPaginatorOptions) { + options.StopOnDuplicateToken = true + }) -type MetricDataResult struct { - ID string - DataPoints []DataPoint + for paginator.HasMorePages() { + promutil.CloudwatchAPICounter.WithLabelValues("ListMetrics").Inc() + page, err := paginator.NextPage(ctx) + if err != nil { + promutil.CloudwatchAPIErrorCounter.WithLabelValues("ListMetrics").Inc() + c.logger.Error("ListMetrics error", "err", err) + return err + } + + metricsPage := toModelMetric(page) + c.logger.Debug("ListMetrics", "output", metricsPage) + + fn(metricsPage) + } + + return nil } -type DataPoint struct { - Value *float64 - Timestamp time.Time +func toModelMetric(page *aws_cloudwatch.ListMetricsOutput) []*model.Metric { + modelMetrics := make([]*model.Metric, 0, len(page.Metrics)) + for _, cloudwatchMetric := range page.Metrics { + modelMetric := &model.Metric{ + MetricName: *cloudwatchMetric.MetricName, + Namespace: *cloudwatchMetric.Namespace, + Dimensions: toModelDimensions(cloudwatchMetric.Dimensions), + } + modelMetrics = append(modelMetrics, modelMetric) + } + return modelMetrics } -type limitedConcurrencyClient struct { - client Client - limiter ConcurrencyLimiter +func toModelDimensions(dimensions []types.Dimension) []model.Dimension { + modelDimensions := make([]model.Dimension, 0, len(dimensions)) + for _, dimension := range dimensions { + modelDimension := model.Dimension{ + Name: *dimension.Name, + Value: *dimension.Value, + } + modelDimensions = append(modelDimensions, modelDimension) + } + return modelDimensions } -func NewLimitedConcurrencyClient(client Client, limiter ConcurrencyLimiter) Client { - return &limitedConcurrencyClient{ - client: client, - limiter: limiter, +func (c client) GetMetricData(ctx context.Context, getMetricData []*model.CloudwatchData, namespace string, startTime time.Time, endTime time.Time) 
[]MetricDataResult { + metricDataQueries := make([]types.MetricDataQuery, 0, len(getMetricData)) + exportAllDataPoints := false + for _, data := range getMetricData { + metricStat := &types.MetricStat{ + Metric: &types.Metric{ + Dimensions: toCloudWatchDimensions(data.Dimensions), + MetricName: &data.MetricName, + Namespace: &namespace, + }, + Period: aws.Int32(int32(data.GetMetricDataProcessingParams.Period)), + Stat: &data.GetMetricDataProcessingParams.Statistic, + } + metricDataQueries = append(metricDataQueries, types.MetricDataQuery{ + Id: &data.GetMetricDataProcessingParams.QueryID, + MetricStat: metricStat, + ReturnData: aws.Bool(true), + }) + exportAllDataPoints = exportAllDataPoints || data.MetricMigrationParams.ExportAllDataPoints + } + + input := &aws_cloudwatch.GetMetricDataInput{ + EndTime: &endTime, + StartTime: &startTime, + MetricDataQueries: metricDataQueries, + ScanBy: "TimestampDescending", } + var resp aws_cloudwatch.GetMetricDataOutput + promutil.CloudwatchGetMetricDataAPIMetricsCounter.Add(float64(len(input.MetricDataQueries))) + c.logger.Debug("GetMetricData", "input", input) + + paginator := aws_cloudwatch.NewGetMetricDataPaginator(c.cloudwatchAPI, input, func(options *aws_cloudwatch.GetMetricDataPaginatorOptions) { + options.StopOnDuplicateToken = true + }) + for paginator.HasMorePages() { + promutil.CloudwatchAPICounter.WithLabelValues("GetMetricData").Inc() + promutil.CloudwatchGetMetricDataAPICounter.Inc() + + page, err := paginator.NextPage(ctx) + if err != nil { + promutil.CloudwatchAPIErrorCounter.WithLabelValues("GetMetricData").Inc() + c.logger.Error("GetMetricData error", "err", err) + return nil + } + resp.MetricDataResults = append(resp.MetricDataResults, page.MetricDataResults...) 
+ } + + c.logger.Debug("GetMetricData", "output", resp) + + return toMetricDataResult(resp, exportAllDataPoints) } -func (c limitedConcurrencyClient) GetMetricStatistics(ctx context.Context, logger *slog.Logger, dimensions []model.Dimension, namespace string, metric *model.MetricConfig) []*model.MetricStatisticsResult { - c.limiter.Acquire(getMetricStatisticsCall) - res := c.client.GetMetricStatistics(ctx, logger, dimensions, namespace, metric) - c.limiter.Release(getMetricStatisticsCall) - return res +func toMetricDataResult(resp aws_cloudwatch.GetMetricDataOutput, exportAllDataPoints bool) []MetricDataResult { + output := make([]MetricDataResult, 0, len(resp.MetricDataResults)) + for _, metricDataResult := range resp.MetricDataResults { + mappedResult := MetricDataResult{ + ID: *metricDataResult.Id, + DataPoints: make([]DataPoint, 0, len(metricDataResult.Timestamps)), + } + for i := 0; i < len(metricDataResult.Timestamps); i++ { + mappedResult.DataPoints = append(mappedResult.DataPoints, DataPoint{ + Value: &metricDataResult.Values[i], + Timestamp: metricDataResult.Timestamps[i], + }) + + if !exportAllDataPoints { + break + } + } + output = append(output, mappedResult) + } + return output } -func (c limitedConcurrencyClient) GetMetricData(ctx context.Context, getMetricData []*model.CloudwatchData, namespace string, startTime time.Time, endTime time.Time) []MetricDataResult { - c.limiter.Acquire(getMetricDataCall) - res := c.client.GetMetricData(ctx, getMetricData, namespace, startTime, endTime) - c.limiter.Release(getMetricDataCall) - return res +func (c client) GetMetricStatistics(ctx context.Context, logger *slog.Logger, dimensions []model.Dimension, namespace string, metric *model.MetricConfig) []*model.MetricStatisticsResult { + filter := createGetMetricStatisticsInput(logger, dimensions, &namespace, metric) + c.logger.Debug("GetMetricStatistics", "input", filter) + + resp, err := c.cloudwatchAPI.GetMetricStatistics(ctx, filter) + + 
c.logger.Debug("GetMetricStatistics", "output", resp) + + promutil.CloudwatchAPICounter.WithLabelValues("GetMetricStatistics").Inc() + promutil.CloudwatchGetMetricStatisticsAPICounter.Inc() + + if err != nil { + promutil.CloudwatchAPIErrorCounter.WithLabelValues("GetMetricStatistics").Inc() + c.logger.Error("Failed to get metric statistics", "err", err) + return nil + } + + ptrs := make([]*types.Datapoint, 0, len(resp.Datapoints)) + for _, datapoint := range resp.Datapoints { + ptrs = append(ptrs, &datapoint) + } + + return toModelDataPoints(ptrs) } -func (c limitedConcurrencyClient) ListMetrics(ctx context.Context, namespace string, metric *model.MetricConfig, recentlyActiveOnly bool, fn func(page []*model.Metric)) error { - c.limiter.Acquire(listMetricsCall) - err := c.client.ListMetrics(ctx, namespace, metric, recentlyActiveOnly, fn) - c.limiter.Release(listMetricsCall) - return err +func toModelDataPoints(cwDataPoints []*types.Datapoint) []*model.MetricStatisticsResult { + modelDataPoints := make([]*model.MetricStatisticsResult, 0, len(cwDataPoints)) + + for _, cwDatapoint := range cwDataPoints { + extendedStats := make(map[string]*float64, len(cwDatapoint.ExtendedStatistics)) + for name, value := range cwDatapoint.ExtendedStatistics { + extendedStats[name] = &value + } + modelDataPoints = append(modelDataPoints, &model.MetricStatisticsResult{ + Average: cwDatapoint.Average, + ExtendedStatistics: extendedStats, + Maximum: cwDatapoint.Maximum, + Minimum: cwDatapoint.Minimum, + SampleCount: cwDatapoint.SampleCount, + Sum: cwDatapoint.Sum, + Timestamp: cwDatapoint.Timestamp, + }) + } + return modelDataPoints } diff --git a/pkg/clients/cloudwatch/v2/client_test.go b/pkg/clients/cloudwatch/client_test.go similarity index 75% rename from pkg/clients/cloudwatch/v2/client_test.go rename to pkg/clients/cloudwatch/client_test.go index f45cfe434..eee67806b 100644 --- a/pkg/clients/cloudwatch/v2/client_test.go +++ b/pkg/clients/cloudwatch/client_test.go @@ -10,19 +10,17 @@ 
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -package v2 +package cloudwatch import ( "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + aws_cloudwatch "github.com/aws/aws-sdk-go-v2/service/cloudwatch" "github.com/aws/aws-sdk-go-v2/service/cloudwatch/types" "github.com/stretchr/testify/require" - - cloudwatch_client "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" ) func Test_toMetricDataResult(t *testing.T) { @@ -31,15 +29,15 @@ func Test_toMetricDataResult(t *testing.T) { type testCase struct { name string exportAllDataPoints bool - getMetricDataOutput cloudwatch.GetMetricDataOutput - expectedMetricDataResults []cloudwatch_client.MetricDataResult + getMetricDataOutput aws_cloudwatch.GetMetricDataOutput + expectedMetricDataResults []MetricDataResult } testCases := []testCase{ { name: "all metrics present", exportAllDataPoints: false, - getMetricDataOutput: cloudwatch.GetMetricDataOutput{ + getMetricDataOutput: aws_cloudwatch.GetMetricDataOutput{ MetricDataResults: []types.MetricDataResult{ { Id: aws.String("metric-1"), @@ -53,14 +51,14 @@ func Test_toMetricDataResult(t *testing.T) { }, }, }, - expectedMetricDataResults: []cloudwatch_client.MetricDataResult{ + expectedMetricDataResults: []MetricDataResult{ { - ID: "metric-1", DataPoints: []cloudwatch_client.DataPoint{ + ID: "metric-1", DataPoints: []DataPoint{ {Value: aws.Float64(1.0), Timestamp: ts.Add(10 * time.Minute)}, }, }, { - ID: "metric-2", DataPoints: []cloudwatch_client.DataPoint{ + ID: "metric-2", DataPoints: []DataPoint{ {Value: aws.Float64(2.0), Timestamp: ts}, }, }, @@ -69,7 +67,7 @@ func Test_toMetricDataResult(t *testing.T) { { name: "metric with no values", exportAllDataPoints: false, - getMetricDataOutput: cloudwatch.GetMetricDataOutput{ + getMetricDataOutput: 
aws_cloudwatch.GetMetricDataOutput{ MetricDataResults: []types.MetricDataResult{ { Id: aws.String("metric-1"), @@ -83,22 +81,22 @@ func Test_toMetricDataResult(t *testing.T) { }, }, }, - expectedMetricDataResults: []cloudwatch_client.MetricDataResult{ + expectedMetricDataResults: []MetricDataResult{ { - ID: "metric-1", DataPoints: []cloudwatch_client.DataPoint{ + ID: "metric-1", DataPoints: []DataPoint{ {Value: aws.Float64(1.0), Timestamp: ts.Add(10 * time.Minute)}, }, }, { ID: "metric-2", - DataPoints: []cloudwatch_client.DataPoint{}, + DataPoints: []DataPoint{}, }, }, }, { name: "export all data points", exportAllDataPoints: true, - getMetricDataOutput: cloudwatch.GetMetricDataOutput{ + getMetricDataOutput: aws_cloudwatch.GetMetricDataOutput{ MetricDataResults: []types.MetricDataResult{ { Id: aws.String("metric-1"), @@ -112,16 +110,16 @@ func Test_toMetricDataResult(t *testing.T) { }, }, }, - expectedMetricDataResults: []cloudwatch_client.MetricDataResult{ + expectedMetricDataResults: []MetricDataResult{ { - ID: "metric-1", DataPoints: []cloudwatch_client.DataPoint{ + ID: "metric-1", DataPoints: []DataPoint{ {Value: aws.Float64(1.0), Timestamp: ts.Add(10 * time.Minute)}, {Value: aws.Float64(2.0), Timestamp: ts.Add(5 * time.Minute)}, {Value: aws.Float64(3.0), Timestamp: ts}, }, }, { - ID: "metric-2", DataPoints: []cloudwatch_client.DataPoint{ + ID: "metric-2", DataPoints: []DataPoint{ {Value: aws.Float64(2.0), Timestamp: ts}, }, }, diff --git a/pkg/clients/cloudwatch/iface.go b/pkg/clients/cloudwatch/iface.go new file mode 100644 index 000000000..9d98ecf83 --- /dev/null +++ b/pkg/clients/cloudwatch/iface.go @@ -0,0 +1,98 @@ +// Copyright 2024 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package cloudwatch
+
+import (
+	"context"
+	"log/slog"
+	"time"
+
+	"github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model"
+)
+
+const (
+	listMetricsCall         = "ListMetrics"
+	getMetricDataCall       = "GetMetricData"
+	getMetricStatisticsCall = "GetMetricStatistics"
+)
+
+type Client interface {
+	// ListMetrics returns the list of metrics and dimensions for a given namespace
+	// and metric name. Results pagination is handled automatically: the caller can
+	// optionally pass a non-nil func in order to handle results pages.
+	ListMetrics(ctx context.Context, namespace string, metric *model.MetricConfig, recentlyActiveOnly bool, fn func(page []*model.Metric)) error
+
+	// GetMetricData returns the output of the GetMetricData CloudWatch API.
+	// Results pagination is handled automatically.
+	GetMetricData(ctx context.Context, getMetricData []*model.CloudwatchData, namespace string, startTime time.Time, endTime time.Time) []MetricDataResult
+
+	// GetMetricStatistics returns the output of the GetMetricStatistics CloudWatch API.
+	GetMetricStatistics(ctx context.Context, logger *slog.Logger, dimensions []model.Dimension, namespace string, metric *model.MetricConfig) []*model.MetricStatisticsResult
+}
+
+// ConcurrencyLimiter limits the concurrency when calling AWS CloudWatch APIs. The functions implemented
+// by this interface follow the same semantics as a normal semaphore, but accept an operation identifier. Some
+// implementations might use this to keep a different semaphore, with different reentrance values, per
+// operation.
+type ConcurrencyLimiter interface { + // Acquire takes one "ticket" from the concurrency limiter for op. If there's none available, the caller + // routine will be blocked until there's room available. + Acquire(op string) + + // Release gives back one "ticket" to the concurrency limiter identified by op. If there's one or more + // routines waiting for one, one will be woken up. + Release(op string) +} + +type MetricDataResult struct { + ID string + DataPoints []DataPoint +} + +type DataPoint struct { + Value *float64 + Timestamp time.Time +} + +type limitedConcurrencyClient struct { + client Client + limiter ConcurrencyLimiter +} + +func NewLimitedConcurrencyClient(client Client, limiter ConcurrencyLimiter) Client { + return &limitedConcurrencyClient{ + client: client, + limiter: limiter, + } +} + +func (c limitedConcurrencyClient) GetMetricStatistics(ctx context.Context, logger *slog.Logger, dimensions []model.Dimension, namespace string, metric *model.MetricConfig) []*model.MetricStatisticsResult { + c.limiter.Acquire(getMetricStatisticsCall) + res := c.client.GetMetricStatistics(ctx, logger, dimensions, namespace, metric) + c.limiter.Release(getMetricStatisticsCall) + return res +} + +func (c limitedConcurrencyClient) GetMetricData(ctx context.Context, getMetricData []*model.CloudwatchData, namespace string, startTime time.Time, endTime time.Time) []MetricDataResult { + c.limiter.Acquire(getMetricDataCall) + res := c.client.GetMetricData(ctx, getMetricData, namespace, startTime, endTime) + c.limiter.Release(getMetricDataCall) + return res +} + +func (c limitedConcurrencyClient) ListMetrics(ctx context.Context, namespace string, metric *model.MetricConfig, recentlyActiveOnly bool, fn func(page []*model.Metric)) error { + c.limiter.Acquire(listMetricsCall) + err := c.client.ListMetrics(ctx, namespace, metric, recentlyActiveOnly, fn) + c.limiter.Release(listMetricsCall) + return err +} diff --git a/pkg/clients/cloudwatch/v2/input.go 
b/pkg/clients/cloudwatch/input.go similarity index 99% rename from pkg/clients/cloudwatch/v2/input.go rename to pkg/clients/cloudwatch/input.go index 5c27fb9fb..2965669be 100644 --- a/pkg/clients/cloudwatch/v2/input.go +++ b/pkg/clients/cloudwatch/input.go @@ -10,7 +10,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -package v2 +package cloudwatch import ( "log/slog" diff --git a/pkg/clients/cloudwatch/v1/client.go b/pkg/clients/cloudwatch/v1/client.go deleted file mode 100644 index 92081df02..000000000 --- a/pkg/clients/cloudwatch/v1/client.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v1 - -import ( - "context" - "log/slog" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/cloudwatch" - "github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface" - - cloudwatch_client "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/promutil" -) - -type client struct { - logger *slog.Logger - cloudwatchAPI cloudwatchiface.CloudWatchAPI -} - -func NewClient(logger *slog.Logger, cloudwatchAPI cloudwatchiface.CloudWatchAPI) cloudwatch_client.Client { - return &client{ - logger: logger, - cloudwatchAPI: cloudwatchAPI, - } -} - -func (c client) ListMetrics(ctx context.Context, namespace string, metric *model.MetricConfig, recentlyActiveOnly bool, fn func(page []*model.Metric)) error { - filter := &cloudwatch.ListMetricsInput{ - MetricName: aws.String(metric.Name), - Namespace: aws.String(namespace), - } - if recentlyActiveOnly { - filter.RecentlyActive = aws.String("PT3H") - } - - c.logger.Debug("ListMetrics", "input", filter) - - err := c.cloudwatchAPI.ListMetricsPagesWithContext(ctx, filter, func(page *cloudwatch.ListMetricsOutput, lastPage bool) bool { - promutil.CloudwatchAPICounter.WithLabelValues("ListMetrics").Inc() - - metricsPage := toModelMetric(page) - - c.logger.Debug("ListMetrics", "output", metricsPage, "last_page", lastPage) - - fn(metricsPage) - return !lastPage - }) - if err != nil { - promutil.CloudwatchAPIErrorCounter.WithLabelValues("ListMetrics").Inc() - c.logger.Error("ListMetrics error", "err", err) - return err - } - - return nil -} - -func toModelMetric(page *cloudwatch.ListMetricsOutput) []*model.Metric { - modelMetrics := make([]*model.Metric, 0, len(page.Metrics)) - for _, cloudwatchMetric := range page.Metrics { - modelMetric := &model.Metric{ - MetricName: *cloudwatchMetric.MetricName, - Namespace: 
*cloudwatchMetric.Namespace, - Dimensions: toModelDimensions(cloudwatchMetric.Dimensions), - } - modelMetrics = append(modelMetrics, modelMetric) - } - return modelMetrics -} - -func toModelDimensions(dimensions []*cloudwatch.Dimension) []model.Dimension { - modelDimensions := make([]model.Dimension, 0, len(dimensions)) - for _, dimension := range dimensions { - modelDimension := model.Dimension{ - Name: *dimension.Name, - Value: *dimension.Value, - } - modelDimensions = append(modelDimensions, modelDimension) - } - return modelDimensions -} - -func (c client) GetMetricData(ctx context.Context, getMetricData []*model.CloudwatchData, namespace string, startTime time.Time, endTime time.Time) []cloudwatch_client.MetricDataResult { - metricDataQueries := make([]*cloudwatch.MetricDataQuery, 0, len(getMetricData)) - exportAllDataPoints := false - for _, data := range getMetricData { - metricStat := &cloudwatch.MetricStat{ - Metric: &cloudwatch.Metric{ - Dimensions: toCloudWatchDimensions(data.Dimensions), - MetricName: &data.MetricName, - Namespace: &namespace, - }, - Period: &data.GetMetricDataProcessingParams.Period, - Stat: &data.GetMetricDataProcessingParams.Statistic, - } - metricDataQueries = append(metricDataQueries, &cloudwatch.MetricDataQuery{ - Id: &data.GetMetricDataProcessingParams.QueryID, - MetricStat: metricStat, - ReturnData: aws.Bool(true), - }) - exportAllDataPoints = exportAllDataPoints || data.MetricMigrationParams.ExportAllDataPoints - } - input := &cloudwatch.GetMetricDataInput{ - EndTime: &endTime, - StartTime: &startTime, - MetricDataQueries: metricDataQueries, - ScanBy: aws.String("TimestampDescending"), - } - promutil.CloudwatchGetMetricDataAPIMetricsCounter.Add(float64(len(input.MetricDataQueries))) - c.logger.Debug("GetMetricData", "input", input) - - var resp cloudwatch.GetMetricDataOutput - // Using the paged version of the function - err := c.cloudwatchAPI.GetMetricDataPagesWithContext(ctx, input, - func(page 
*cloudwatch.GetMetricDataOutput, lastPage bool) bool { - promutil.CloudwatchGetMetricDataAPICounter.Inc() - promutil.CloudwatchAPICounter.WithLabelValues("GetMetricData").Inc() - resp.MetricDataResults = append(resp.MetricDataResults, page.MetricDataResults...) - return !lastPage - }) - - c.logger.Debug("GetMetricData", "output", resp) - - if err != nil { - promutil.CloudwatchAPIErrorCounter.WithLabelValues("GetMetricData").Inc() - c.logger.Error("GetMetricData error", "err", err) - return nil - } - return toMetricDataResult(resp, exportAllDataPoints) -} - -func toMetricDataResult(resp cloudwatch.GetMetricDataOutput, exportAllDataPoints bool) []cloudwatch_client.MetricDataResult { - output := make([]cloudwatch_client.MetricDataResult, 0, len(resp.MetricDataResults)) - for _, metricDataResult := range resp.MetricDataResults { - mappedResult := cloudwatch_client.MetricDataResult{ - ID: *metricDataResult.Id, - DataPoints: make([]cloudwatch_client.DataPoint, 0, len(metricDataResult.Timestamps)), - } - for i := 0; i < len(metricDataResult.Timestamps); i++ { - mappedResult.DataPoints = append(mappedResult.DataPoints, cloudwatch_client.DataPoint{ - Value: metricDataResult.Values[i], - Timestamp: *metricDataResult.Timestamps[i], - }) - - if !exportAllDataPoints { - break - } - } - output = append(output, mappedResult) - } - return output -} - -func (c client) GetMetricStatistics(ctx context.Context, logger *slog.Logger, dimensions []model.Dimension, namespace string, metric *model.MetricConfig) []*model.MetricStatisticsResult { - filter := createGetMetricStatisticsInput(dimensions, &namespace, metric, logger) - - c.logger.Debug("GetMetricStatistics", "input", filter) - - resp, err := c.cloudwatchAPI.GetMetricStatisticsWithContext(ctx, filter) - - c.logger.Debug("GetMetricStatistics", "output", resp) - - promutil.CloudwatchGetMetricStatisticsAPICounter.Inc() - promutil.CloudwatchAPICounter.WithLabelValues("GetMetricStatistics").Inc() - - if err != nil { - 
promutil.CloudwatchAPIErrorCounter.WithLabelValues("GetMetricStatistics").Inc() - c.logger.Error("Failed to get metric statistics", "err", err) - return nil - } - - return toModelDataPoints(resp.Datapoints) -} - -func toModelDataPoints(cwDataPoints []*cloudwatch.Datapoint) []*model.MetricStatisticsResult { - modelDataPoints := make([]*model.MetricStatisticsResult, 0, len(cwDataPoints)) - - for _, cwDatapoint := range cwDataPoints { - modelDataPoints = append(modelDataPoints, &model.MetricStatisticsResult{ - Average: cwDatapoint.Average, - ExtendedStatistics: cwDatapoint.ExtendedStatistics, - Maximum: cwDatapoint.Maximum, - Minimum: cwDatapoint.Minimum, - SampleCount: cwDatapoint.SampleCount, - Sum: cwDatapoint.Sum, - Timestamp: cwDatapoint.Timestamp, - }) - } - return modelDataPoints -} diff --git a/pkg/clients/cloudwatch/v1/client_test.go b/pkg/clients/cloudwatch/v1/client_test.go deleted file mode 100644 index b3553b5a4..000000000 --- a/pkg/clients/cloudwatch/v1/client_test.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v1 - -import ( - "testing" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/cloudwatch" - "github.com/stretchr/testify/require" - - cloudwatch_client "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" -) - -func TestDimensionsToCliString(t *testing.T) { - // Setup Test - - // Arrange - dimensions := []model.Dimension{} - expected := "" - - // Act - actual := dimensionsToCliString(dimensions) - - // Assert - if actual != expected { - t.Fatalf("\nexpected: %q\nactual: %q", expected, actual) - } -} - -func Test_toMetricDataResult(t *testing.T) { - ts := time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) - - type testCase struct { - name string - getMetricDataOutput cloudwatch.GetMetricDataOutput - expectedMetricDataResults []cloudwatch_client.MetricDataResult - exportAllDataPoints bool - } - - testCases := []testCase{ - { - name: "all metrics present", - exportAllDataPoints: false, - getMetricDataOutput: cloudwatch.GetMetricDataOutput{ - MetricDataResults: []*cloudwatch.MetricDataResult{ - { - Id: aws.String("metric-1"), - Values: []*float64{aws.Float64(1.0), aws.Float64(2.0), aws.Float64(3.0)}, - Timestamps: []*time.Time{aws.Time(ts.Add(10 * time.Minute)), aws.Time(ts.Add(5 * time.Minute)), aws.Time(ts)}, - }, - { - Id: aws.String("metric-2"), - Values: []*float64{aws.Float64(2.0)}, - Timestamps: []*time.Time{aws.Time(ts)}, - }, - }, - }, - expectedMetricDataResults: []cloudwatch_client.MetricDataResult{ - { - ID: "metric-1", DataPoints: []cloudwatch_client.DataPoint{ - {Value: aws.Float64(1.0), Timestamp: ts.Add(10 * time.Minute)}, - }, - }, - { - ID: "metric-2", DataPoints: []cloudwatch_client.DataPoint{ - {Value: aws.Float64(2.0), Timestamp: ts}, - }, - }, - }, - }, - { - name: "metric with no values", - exportAllDataPoints: false, - getMetricDataOutput: cloudwatch.GetMetricDataOutput{ - 
MetricDataResults: []*cloudwatch.MetricDataResult{ - { - Id: aws.String("metric-1"), - Values: []*float64{aws.Float64(1.0), aws.Float64(2.0), aws.Float64(3.0)}, - Timestamps: []*time.Time{aws.Time(ts.Add(10 * time.Minute)), aws.Time(ts.Add(5 * time.Minute)), aws.Time(ts)}, - }, - { - Id: aws.String("metric-2"), - Values: []*float64{}, - Timestamps: []*time.Time{}, - }, - }, - }, - expectedMetricDataResults: []cloudwatch_client.MetricDataResult{ - { - ID: "metric-1", DataPoints: []cloudwatch_client.DataPoint{ - {Value: aws.Float64(1.0), Timestamp: ts.Add(10 * time.Minute)}, - }, - }, - { - ID: "metric-2", - DataPoints: []cloudwatch_client.DataPoint{}, - }, - }, - }, - { - name: "export all data points", - exportAllDataPoints: true, - getMetricDataOutput: cloudwatch.GetMetricDataOutput{ - MetricDataResults: []*cloudwatch.MetricDataResult{ - { - Id: aws.String("metric-1"), - Values: []*float64{aws.Float64(1.0), aws.Float64(2.0), aws.Float64(3.0)}, - Timestamps: []*time.Time{aws.Time(ts.Add(10 * time.Minute)), aws.Time(ts.Add(5 * time.Minute)), aws.Time(ts)}, - }, - { - Id: aws.String("metric-2"), - Values: []*float64{aws.Float64(2.0)}, - Timestamps: []*time.Time{aws.Time(ts)}, - }, - }, - }, - expectedMetricDataResults: []cloudwatch_client.MetricDataResult{ - { - ID: "metric-1", DataPoints: []cloudwatch_client.DataPoint{ - {Value: aws.Float64(1.0), Timestamp: ts.Add(10 * time.Minute)}, - {Value: aws.Float64(2.0), Timestamp: ts.Add(5 * time.Minute)}, - {Value: aws.Float64(3.0), Timestamp: ts}, - }, - }, - { - ID: "metric-2", DataPoints: []cloudwatch_client.DataPoint{ - {Value: aws.Float64(2.0), Timestamp: ts}, - }, - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - metricDataResults := toMetricDataResult(tc.getMetricDataOutput, tc.exportAllDataPoints) - require.Equal(t, tc.expectedMetricDataResults, metricDataResults) - }) - } -} diff --git a/pkg/clients/cloudwatch/v1/input.go b/pkg/clients/cloudwatch/v1/input.go deleted file 
mode 100644 index 309c03dbe..000000000 --- a/pkg/clients/cloudwatch/v1/input.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package v1 - -import ( - "log/slog" - "strconv" - "strings" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/cloudwatch" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/promutil" -) - -func toCloudWatchDimensions(dimensions []model.Dimension) []*cloudwatch.Dimension { - cwDim := make([]*cloudwatch.Dimension, 0, len(dimensions)) - for _, dim := range dimensions { - // Don't take pointers directly to loop variables - cDim := dim - cwDim = append(cwDim, &cloudwatch.Dimension{ - Name: &cDim.Name, - Value: &cDim.Value, - }) - } - return cwDim -} - -func createGetMetricStatisticsInput(dimensions []model.Dimension, namespace *string, metric *model.MetricConfig, logger *slog.Logger) *cloudwatch.GetMetricStatisticsInput { - period := metric.Period - length := metric.Length - delay := metric.Delay - endTime := time.Now().Add(-time.Duration(delay) * time.Second) - startTime := time.Now().Add(-(time.Duration(length) + time.Duration(delay)) * time.Second) - - var statistics []*string - var extendedStatistics []*string - for _, statistic := range metric.Statistics { - if promutil.Percentile.MatchString(statistic) { - extendedStatistics = 
append(extendedStatistics, aws.String(statistic)) - } else { - statistics = append(statistics, aws.String(statistic)) - } - } - - output := &cloudwatch.GetMetricStatisticsInput{ - Dimensions: toCloudWatchDimensions(dimensions), - Namespace: namespace, - StartTime: &startTime, - EndTime: &endTime, - Period: &period, - MetricName: &metric.Name, - Statistics: statistics, - ExtendedStatistics: extendedStatistics, - } - - logger.Debug("CLI helper - " + - "aws cloudwatch get-metric-statistics" + - " --metric-name " + metric.Name + - " --dimensions " + dimensionsToCliString(dimensions) + - " --namespace " + *namespace + - " --statistics " + *statistics[0] + - " --period " + strconv.FormatInt(period, 10) + - " --start-time " + startTime.Format(time.RFC3339) + - " --end-time " + endTime.Format(time.RFC3339)) - - return output -} - -func dimensionsToCliString(dimensions []model.Dimension) string { - out := strings.Builder{} - for _, dim := range dimensions { - out.WriteString("Name=") - out.WriteString(dim.Name) - out.WriteString(",Value=") - out.WriteString(dim.Value) - out.WriteString(" ") - } - return out.String() -} diff --git a/pkg/clients/cloudwatch/v2/client.go b/pkg/clients/cloudwatch/v2/client.go deleted file mode 100644 index 0ba6933bd..000000000 --- a/pkg/clients/cloudwatch/v2/client.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v2 - -import ( - "context" - "log/slog" - "time" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/cloudwatch" - "github.com/aws/aws-sdk-go-v2/service/cloudwatch/types" - - cloudwatch_client "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/promutil" -) - -type client struct { - logger *slog.Logger - cloudwatchAPI *cloudwatch.Client -} - -func NewClient(logger *slog.Logger, cloudwatchAPI *cloudwatch.Client) cloudwatch_client.Client { - return &client{ - logger: logger, - cloudwatchAPI: cloudwatchAPI, - } -} - -func (c client) ListMetrics(ctx context.Context, namespace string, metric *model.MetricConfig, recentlyActiveOnly bool, fn func(page []*model.Metric)) error { - filter := &cloudwatch.ListMetricsInput{ - MetricName: aws.String(metric.Name), - Namespace: aws.String(namespace), - } - if recentlyActiveOnly { - filter.RecentlyActive = types.RecentlyActivePt3h - } - - c.logger.Debug("ListMetrics", "input", filter) - - paginator := cloudwatch.NewListMetricsPaginator(c.cloudwatchAPI, filter, func(options *cloudwatch.ListMetricsPaginatorOptions) { - options.StopOnDuplicateToken = true - }) - - for paginator.HasMorePages() { - promutil.CloudwatchAPICounter.WithLabelValues("ListMetrics").Inc() - page, err := paginator.NextPage(ctx) - if err != nil { - promutil.CloudwatchAPIErrorCounter.WithLabelValues("ListMetrics").Inc() - c.logger.Error("ListMetrics error", "err", err) - return err - } - - metricsPage := toModelMetric(page) - c.logger.Debug("ListMetrics", "output", metricsPage) - - fn(metricsPage) - } - - return nil -} - -func toModelMetric(page *cloudwatch.ListMetricsOutput) []*model.Metric { - modelMetrics := make([]*model.Metric, 0, len(page.Metrics)) - for _, cloudwatchMetric := range page.Metrics { - modelMetric := &model.Metric{ - 
MetricName: *cloudwatchMetric.MetricName, - Namespace: *cloudwatchMetric.Namespace, - Dimensions: toModelDimensions(cloudwatchMetric.Dimensions), - } - modelMetrics = append(modelMetrics, modelMetric) - } - return modelMetrics -} - -func toModelDimensions(dimensions []types.Dimension) []model.Dimension { - modelDimensions := make([]model.Dimension, 0, len(dimensions)) - for _, dimension := range dimensions { - modelDimension := model.Dimension{ - Name: *dimension.Name, - Value: *dimension.Value, - } - modelDimensions = append(modelDimensions, modelDimension) - } - return modelDimensions -} - -func (c client) GetMetricData(ctx context.Context, getMetricData []*model.CloudwatchData, namespace string, startTime time.Time, endTime time.Time) []cloudwatch_client.MetricDataResult { - metricDataQueries := make([]types.MetricDataQuery, 0, len(getMetricData)) - exportAllDataPoints := false - for _, data := range getMetricData { - metricStat := &types.MetricStat{ - Metric: &types.Metric{ - Dimensions: toCloudWatchDimensions(data.Dimensions), - MetricName: &data.MetricName, - Namespace: &namespace, - }, - Period: aws.Int32(int32(data.GetMetricDataProcessingParams.Period)), - Stat: &data.GetMetricDataProcessingParams.Statistic, - } - metricDataQueries = append(metricDataQueries, types.MetricDataQuery{ - Id: &data.GetMetricDataProcessingParams.QueryID, - MetricStat: metricStat, - ReturnData: aws.Bool(true), - }) - exportAllDataPoints = exportAllDataPoints || data.MetricMigrationParams.ExportAllDataPoints - } - - input := &cloudwatch.GetMetricDataInput{ - EndTime: &endTime, - StartTime: &startTime, - MetricDataQueries: metricDataQueries, - ScanBy: "TimestampDescending", - } - var resp cloudwatch.GetMetricDataOutput - promutil.CloudwatchGetMetricDataAPIMetricsCounter.Add(float64(len(input.MetricDataQueries))) - c.logger.Debug("GetMetricData", "input", input) - - paginator := cloudwatch.NewGetMetricDataPaginator(c.cloudwatchAPI, input, func(options 
*cloudwatch.GetMetricDataPaginatorOptions) { - options.StopOnDuplicateToken = true - }) - for paginator.HasMorePages() { - promutil.CloudwatchAPICounter.WithLabelValues("GetMetricData").Inc() - promutil.CloudwatchGetMetricDataAPICounter.Inc() - - page, err := paginator.NextPage(ctx) - if err != nil { - promutil.CloudwatchAPIErrorCounter.WithLabelValues("GetMetricData").Inc() - c.logger.Error("GetMetricData error", "err", err) - return nil - } - resp.MetricDataResults = append(resp.MetricDataResults, page.MetricDataResults...) - } - - c.logger.Debug("GetMetricData", "output", resp) - - return toMetricDataResult(resp, exportAllDataPoints) -} - -func toMetricDataResult(resp cloudwatch.GetMetricDataOutput, exportAllDataPoints bool) []cloudwatch_client.MetricDataResult { - output := make([]cloudwatch_client.MetricDataResult, 0, len(resp.MetricDataResults)) - for _, metricDataResult := range resp.MetricDataResults { - mappedResult := cloudwatch_client.MetricDataResult{ - ID: *metricDataResult.Id, - DataPoints: make([]cloudwatch_client.DataPoint, 0, len(metricDataResult.Timestamps)), - } - for i := 0; i < len(metricDataResult.Timestamps); i++ { - mappedResult.DataPoints = append(mappedResult.DataPoints, cloudwatch_client.DataPoint{ - Value: &metricDataResult.Values[i], - Timestamp: metricDataResult.Timestamps[i], - }) - - if !exportAllDataPoints { - break - } - } - output = append(output, mappedResult) - } - return output -} - -func (c client) GetMetricStatistics(ctx context.Context, logger *slog.Logger, dimensions []model.Dimension, namespace string, metric *model.MetricConfig) []*model.MetricStatisticsResult { - filter := createGetMetricStatisticsInput(logger, dimensions, &namespace, metric) - c.logger.Debug("GetMetricStatistics", "input", filter) - - resp, err := c.cloudwatchAPI.GetMetricStatistics(ctx, filter) - - c.logger.Debug("GetMetricStatistics", "output", resp) - - promutil.CloudwatchAPICounter.WithLabelValues("GetMetricStatistics").Inc() - 
promutil.CloudwatchGetMetricStatisticsAPICounter.Inc() - - if err != nil { - promutil.CloudwatchAPIErrorCounter.WithLabelValues("GetMetricStatistics").Inc() - c.logger.Error("Failed to get metric statistics", "err", err) - return nil - } - - ptrs := make([]*types.Datapoint, 0, len(resp.Datapoints)) - for _, datapoint := range resp.Datapoints { - ptrs = append(ptrs, &datapoint) - } - - return toModelDataPoints(ptrs) -} - -func toModelDataPoints(cwDataPoints []*types.Datapoint) []*model.MetricStatisticsResult { - modelDataPoints := make([]*model.MetricStatisticsResult, 0, len(cwDataPoints)) - - for _, cwDatapoint := range cwDataPoints { - extendedStats := make(map[string]*float64, len(cwDatapoint.ExtendedStatistics)) - for name, value := range cwDatapoint.ExtendedStatistics { - extendedStats[name] = &value - } - modelDataPoints = append(modelDataPoints, &model.MetricStatisticsResult{ - Average: cwDatapoint.Average, - ExtendedStatistics: extendedStats, - Maximum: cwDatapoint.Maximum, - Minimum: cwDatapoint.Minimum, - SampleCount: cwDatapoint.SampleCount, - Sum: cwDatapoint.Sum, - Timestamp: cwDatapoint.Timestamp, - }) - } - return modelDataPoints -} diff --git a/pkg/clients/factory.go b/pkg/clients/factory.go index 286fe5a43..6cd907e74 100644 --- a/pkg/clients/factory.go +++ b/pkg/clients/factory.go @@ -13,16 +13,462 @@ package clients import ( + "context" + "fmt" + "log/slog" + "os" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + aws_config "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/amp" + "github.com/aws/aws-sdk-go-v2/service/apigateway" + "github.com/aws/aws-sdk-go-v2/service/apigatewayv2" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + "github.com/aws/aws-sdk-go-v2/service/databasemigrationservice" + "github.com/aws/aws-sdk-go-v2/service/ec2" + 
"github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" + "github.com/aws/aws-sdk-go-v2/service/shield" + "github.com/aws/aws-sdk-go-v2/service/storagegateway" + "github.com/aws/aws-sdk-go-v2/service/sts" + aws_logging "github.com/aws/smithy-go/logging" + "go.uber.org/atomic" + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/account" cloudwatch_client "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging" "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" ) -// Factory is an interface to abstract away all logic required to produce the different -// YACE specific clients which wrap AWS clients -type Factory interface { - GetCloudwatchClient(region string, role model.Role, concurrency cloudwatch_client.ConcurrencyConfig) cloudwatch_client.Client - GetTaggingClient(region string, role model.Role, concurrencyLimit int) tagging.Client - GetAccountClient(region string, role model.Role) account.Client +type awsRegion = string + +type CachingFactory struct { + logger *slog.Logger + stsOptions func(*sts.Options) + clients map[model.Role]map[awsRegion]*cachedClients + mu sync.Mutex + refreshed *atomic.Bool + cleared *atomic.Bool + fipsEnabled bool + endpointURLOverride string +} + +type cachedClients struct { + awsConfig *aws.Config + // if we know that this job is only used for static + // then we don't have to construct as many cached connections + // later on + onlyStatic bool + cloudwatch cloudwatch_client.Client + tagging tagging.Client + account account.Client +} + +// NewFactory creates a new client factory to use when fetching data from AWS with sdk v2 +func NewFactory(logger *slog.Logger, jobsCfg model.JobsConfig, fips bool) (*CachingFactory, error) { + var options []func(*aws_config.LoadOptions) error + options = append(options, 
aws_config.WithLogger(aws_logging.LoggerFunc(func(classification aws_logging.Classification, format string, v ...interface{}) { + switch classification { + case aws_logging.Debug: + if logger.Enabled(context.Background(), slog.LevelDebug) { + logger.Debug(fmt.Sprintf(format, v...)) + } + case aws_logging.Warn: + logger.Warn(fmt.Sprintf(format, v...)) + default: // AWS logging only supports debug or warn, log everything else as error + logger.Error(fmt.Sprintf(format, v...), "err", "unexpected aws error classification", "classification", classification) + } + }))) + + options = append(options, aws_config.WithLogConfigurationWarnings(true)) + + endpointURLOverride := os.Getenv("AWS_ENDPOINT_URL") + + options = append(options, aws_config.WithRetryMaxAttempts(5)) + + c, err := aws_config.LoadDefaultConfig(context.TODO(), options...) + if err != nil { + return nil, fmt.Errorf("failed to load default aws config: %w", err) + } + + stsOptions := createStsOptions(jobsCfg.StsRegion, logger.Enabled(context.Background(), slog.LevelDebug), endpointURLOverride, fips) + cache := map[model.Role]map[awsRegion]*cachedClients{} + for _, discoveryJob := range jobsCfg.DiscoveryJobs { + for _, role := range discoveryJob.Roles { + if _, ok := cache[role]; !ok { + cache[role] = map[awsRegion]*cachedClients{} + } + for _, region := range discoveryJob.Regions { + regionConfig := awsConfigForRegion(role, &c, region, stsOptions) + cache[role][region] = &cachedClients{ + awsConfig: regionConfig, + onlyStatic: false, + } + } + } + } + + for _, staticJob := range jobsCfg.StaticJobs { + for _, role := range staticJob.Roles { + if _, ok := cache[role]; !ok { + cache[role] = map[awsRegion]*cachedClients{} + } + for _, region := range staticJob.Regions { + // Discovery job client definitions have precedence + if _, exists := cache[role][region]; !exists { + regionConfig := awsConfigForRegion(role, &c, region, stsOptions) + cache[role][region] = &cachedClients{ + awsConfig: regionConfig, + onlyStatic: 
true, + } + } + } + } + } + + for _, customNamespaceJob := range jobsCfg.CustomNamespaceJobs { + for _, role := range customNamespaceJob.Roles { + if _, ok := cache[role]; !ok { + cache[role] = map[awsRegion]*cachedClients{} + } + for _, region := range customNamespaceJob.Regions { + // Discovery job client definitions have precedence + if _, exists := cache[role][region]; !exists { + regionConfig := awsConfigForRegion(role, &c, region, stsOptions) + cache[role][region] = &cachedClients{ + awsConfig: regionConfig, + onlyStatic: true, + } + } + } + } + } + + return &CachingFactory{ + logger: logger, + clients: cache, + fipsEnabled: fips, + stsOptions: stsOptions, + endpointURLOverride: endpointURLOverride, + cleared: atomic.NewBool(false), + refreshed: atomic.NewBool(false), + }, nil +} + +func (c *CachingFactory) GetCloudwatchClient(region string, role model.Role, concurrency cloudwatch_client.ConcurrencyConfig) cloudwatch_client.Client { + if !c.refreshed.Load() { + // if we have not refreshed then we need to lock in case we are accessing concurrently + c.mu.Lock() + defer c.mu.Unlock() + } + if client := c.clients[role][region].cloudwatch; client != nil { + return cloudwatch_client.NewLimitedConcurrencyClient(client, concurrency.NewLimiter()) + } + c.clients[role][region].cloudwatch = cloudwatch_client.NewClient(c.logger, c.createCloudwatchClient(c.clients[role][region].awsConfig)) + return cloudwatch_client.NewLimitedConcurrencyClient(c.clients[role][region].cloudwatch, concurrency.NewLimiter()) +} + +func (c *CachingFactory) GetTaggingClient(region string, role model.Role, concurrencyLimit int) tagging.Client { + if !c.refreshed.Load() { + // if we have not refreshed then we need to lock in case we are accessing concurrently + c.mu.Lock() + defer c.mu.Unlock() + } + if client := c.clients[role][region].tagging; client != nil { + return tagging.NewLimitedConcurrencyClient(client, concurrencyLimit) + } + c.clients[role][region].tagging = tagging.NewClient( + 
c.logger, + c.createTaggingClient(c.clients[role][region].awsConfig), + c.createAutoScalingClient(c.clients[role][region].awsConfig), + c.createAPIGatewayClient(c.clients[role][region].awsConfig), + c.createAPIGatewayV2Client(c.clients[role][region].awsConfig), + c.createEC2Client(c.clients[role][region].awsConfig), + c.createDMSClient(c.clients[role][region].awsConfig), + c.createPrometheusClient(c.clients[role][region].awsConfig), + c.createStorageGatewayClient(c.clients[role][region].awsConfig), + c.createShieldClient(c.clients[role][region].awsConfig), + ) + return tagging.NewLimitedConcurrencyClient(c.clients[role][region].tagging, concurrencyLimit) +} + +func (c *CachingFactory) GetAccountClient(region string, role model.Role) account.Client { + if !c.refreshed.Load() { + // if we have not refreshed then we need to lock in case we are accessing concurrently + c.mu.Lock() + defer c.mu.Unlock() + } + if client := c.clients[role][region].account; client != nil { + return client + } + + stsClient := c.createStsClient(c.clients[role][region].awsConfig) + iamClient := c.createIAMClient(c.clients[role][region].awsConfig) + c.clients[role][region].account = account.NewClient(c.logger, stsClient, iamClient) + return c.clients[role][region].account +} + +func (c *CachingFactory) Refresh() { + if c.refreshed.Load() { + return + } + c.mu.Lock() + defer c.mu.Unlock() + // Avoid double refresh in the event Refresh() is called concurrently + if c.refreshed.Load() { + return + } + + for _, regionClients := range c.clients { + for _, cache := range regionClients { + cache.cloudwatch = cloudwatch_client.NewClient(c.logger, c.createCloudwatchClient(cache.awsConfig)) + if cache.onlyStatic { + continue + } + + cache.tagging = tagging.NewClient( + c.logger, + c.createTaggingClient(cache.awsConfig), + c.createAutoScalingClient(cache.awsConfig), + c.createAPIGatewayClient(cache.awsConfig), + c.createAPIGatewayV2Client(cache.awsConfig), + c.createEC2Client(cache.awsConfig), + 
c.createDMSClient(cache.awsConfig), + c.createPrometheusClient(cache.awsConfig), + c.createStorageGatewayClient(cache.awsConfig), + c.createShieldClient(cache.awsConfig), + ) + + cache.account = account.NewClient(c.logger, c.createStsClient(cache.awsConfig), c.createIAMClient(cache.awsConfig)) + } + } + + c.refreshed.Store(true) + c.cleared.Store(false) +} + +func (c *CachingFactory) Clear() { + if c.cleared.Load() { + return + } + // Prevent concurrent reads/write if clear is called during execution + c.mu.Lock() + defer c.mu.Unlock() + // Avoid double clear in the event Clear() is called concurrently + if c.cleared.Load() { + return + } + + for _, regions := range c.clients { + for _, cache := range regions { + cache.cloudwatch = nil + cache.account = nil + cache.tagging = nil + } + } + + c.refreshed.Store(false) + c.cleared.Store(true) +} + +func (c *CachingFactory) createCloudwatchClient(regionConfig *aws.Config) *cloudwatch.Client { + return cloudwatch.NewFromConfig(*regionConfig, func(options *cloudwatch.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = aws.String(c.endpointURLOverride) + } + + // Setting an explicit retryer will override the default settings on the config + options.Retryer = retry.NewStandard(func(options *retry.StandardOptions) { + options.MaxAttempts = 5 + options.MaxBackoff = 3 * time.Second + }) + + if c.fipsEnabled { + options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled + } + }) +} + +func (c *CachingFactory) createTaggingClient(regionConfig *aws.Config) *resourcegroupstaggingapi.Client { + return resourcegroupstaggingapi.NewFromConfig(*regionConfig, func(options *resourcegroupstaggingapi.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | 
aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = aws.String(c.endpointURLOverride) + } + // The FIPS setting is ignored because FIPS is not available for resource groups tagging apis + // If enabled the SDK will try to use non-existent FIPS URLs, https://github.com/aws/aws-sdk-go-v2/issues/2138#issuecomment-1570791988 + // AWS FIPS Reference: https://aws.amazon.com/compliance/fips/ + }) +} + +func (c *CachingFactory) createAutoScalingClient(assumedConfig *aws.Config) *autoscaling.Client { + return autoscaling.NewFromConfig(*assumedConfig, func(options *autoscaling.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = aws.String(c.endpointURLOverride) + } + // The FIPS setting is ignored because FIPS is not available for EC2 autoscaling apis + // If enabled the SDK will try to use non-existent FIPS URLs, https://github.com/aws/aws-sdk-go-v2/issues/2138#issuecomment-1570791988 + // AWS FIPS Reference: https://aws.amazon.com/compliance/fips/ + // EC2 autoscaling has FIPS compliant URLs for govcloud, but they do not use any FIPS prefixing, and should work + // with sdk v2s EndpointResolverV2 + }) +} + +func (c *CachingFactory) createAPIGatewayClient(assumedConfig *aws.Config) *apigateway.Client { + return apigateway.NewFromConfig(*assumedConfig, func(options *apigateway.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = aws.String(c.endpointURLOverride) + } + if c.fipsEnabled { + options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled + } + }) +} + +func (c *CachingFactory) createAPIGatewayV2Client(assumedConfig *aws.Config) *apigatewayv2.Client { + return 
apigatewayv2.NewFromConfig(*assumedConfig, func(options *apigatewayv2.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = aws.String(c.endpointURLOverride) + } + if c.fipsEnabled { + options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled + } + }) +} + +func (c *CachingFactory) createEC2Client(assumedConfig *aws.Config) *ec2.Client { + return ec2.NewFromConfig(*assumedConfig, func(options *ec2.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = aws.String(c.endpointURLOverride) + } + if c.fipsEnabled { + options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled + } + }) +} + +func (c *CachingFactory) createDMSClient(assumedConfig *aws.Config) *databasemigrationservice.Client { + return databasemigrationservice.NewFromConfig(*assumedConfig, func(options *databasemigrationservice.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = aws.String(c.endpointURLOverride) + } + if c.fipsEnabled { + options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled + } + }) +} + +func (c *CachingFactory) createStorageGatewayClient(assumedConfig *aws.Config) *storagegateway.Client { + return storagegateway.NewFromConfig(*assumedConfig, func(options *storagegateway.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = 
aws.String(c.endpointURLOverride) + } + if c.fipsEnabled { + options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled + } + }) +} + +func (c *CachingFactory) createPrometheusClient(assumedConfig *aws.Config) *amp.Client { + return amp.NewFromConfig(*assumedConfig, func(options *amp.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = aws.String(c.endpointURLOverride) + } + // The FIPS setting is ignored because FIPS is not available for amp apis + // If enabled the SDK will try to use non-existent FIPS URLs, https://github.com/aws/aws-sdk-go-v2/issues/2138#issuecomment-1570791988 + // AWS FIPS Reference: https://aws.amazon.com/compliance/fips/ + }) +} + +func (c *CachingFactory) createStsClient(awsConfig *aws.Config) *sts.Client { + return sts.NewFromConfig(*awsConfig, c.stsOptions) +} + +func (c *CachingFactory) createIAMClient(awsConfig *aws.Config) *iam.Client { + return iam.NewFromConfig(*awsConfig) +} + +func (c *CachingFactory) createShieldClient(awsConfig *aws.Config) *shield.Client { + return shield.NewFromConfig(*awsConfig, func(options *shield.Options) { + if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if c.endpointURLOverride != "" { + options.BaseEndpoint = aws.String(c.endpointURLOverride) + } + if c.fipsEnabled { + options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled + } + }) +} + +func createStsOptions(stsRegion string, isDebugLoggingEnabled bool, endpointURLOverride string, fipsEnabled bool) func(*sts.Options) { + return func(options *sts.Options) { + if stsRegion != "" { + options.Region = stsRegion + } + if isDebugLoggingEnabled { + options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody + } + if 
endpointURLOverride != "" { + options.BaseEndpoint = aws.String(endpointURLOverride) + } + if fipsEnabled { + options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled + } + } +} + +var defaultRole = model.Role{} + +func awsConfigForRegion(r model.Role, c *aws.Config, region awsRegion, stsOptions func(*sts.Options)) *aws.Config { + regionalConfig := c.Copy() + regionalConfig.Region = region + + if r == defaultRole { + return ®ionalConfig + } + + // based on https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials/stscreds#hdr-Assume_Role + // found via https://github.com/aws/aws-sdk-go-v2/issues/1382 + regionalSts := sts.NewFromConfig(*c, stsOptions) + credentials := stscreds.NewAssumeRoleProvider(regionalSts, r.RoleArn, func(options *stscreds.AssumeRoleOptions) { + if r.ExternalID != "" { + options.ExternalID = aws.String(r.ExternalID) + } + }) + regionalConfig.Credentials = aws.NewCredentialsCache(credentials) + + return ®ionalConfig } diff --git a/pkg/clients/v2/factory_test.go b/pkg/clients/factory_test.go similarity index 99% rename from pkg/clients/v2/factory_test.go rename to pkg/clients/factory_test.go index 6d7836daa..9d3f2d1b5 100644 --- a/pkg/clients/v2/factory_test.go +++ b/pkg/clients/factory_test.go @@ -10,7 +10,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -package v2 +package clients import ( "context" diff --git a/pkg/clients/iface.go b/pkg/clients/iface.go new file mode 100644 index 000000000..286fe5a43 --- /dev/null +++ b/pkg/clients/iface.go @@ -0,0 +1,28 @@ +// Copyright 2024 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package clients + +import ( + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/account" + cloudwatch_client "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging" + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" +) + +// Factory is an interface to abstract away all logic required to produce the different +// YACE specific clients which wrap AWS clients +type Factory interface { + GetCloudwatchClient(region string, role model.Role, concurrency cloudwatch_client.ConcurrencyConfig) cloudwatch_client.Client + GetTaggingClient(region string, role model.Role, concurrencyLimit int) tagging.Client + GetAccountClient(region string, role model.Role) account.Client +} diff --git a/pkg/clients/tagging/client.go b/pkg/clients/tagging/client.go index c43310e64..1cba11c48 100644 --- a/pkg/clients/tagging/client.go +++ b/pkg/clients/tagging/client.go @@ -14,32 +14,154 @@ package tagging import ( "context" - "errors" + "fmt" + "log/slog" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/amp" + "github.com/aws/aws-sdk-go-v2/service/apigateway" + "github.com/aws/aws-sdk-go-v2/service/apigatewayv2" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/service/databasemigrationservice" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" + 
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi/types" + "github.com/aws/aws-sdk-go-v2/service/shield" + "github.com/aws/aws-sdk-go-v2/service/storagegateway" + + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/config" "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/promutil" ) -type Client interface { - GetResources(ctx context.Context, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) +type client struct { + logger *slog.Logger + taggingAPI *resourcegroupstaggingapi.Client + autoscalingAPI *autoscaling.Client + apiGatewayAPI *apigateway.Client + apiGatewayV2API *apigatewayv2.Client + ec2API *ec2.Client + dmsAPI *databasemigrationservice.Client + prometheusSvcAPI *amp.Client + storageGatewayAPI *storagegateway.Client + shieldAPI *shield.Client } -var ErrExpectedToFindResources = errors.New("expected to discover resources but none were found") - -type limitedConcurrencyClient struct { - client Client - sem chan struct{} +func NewClient( + logger *slog.Logger, + taggingAPI *resourcegroupstaggingapi.Client, + autoscalingAPI *autoscaling.Client, + apiGatewayAPI *apigateway.Client, + apiGatewayV2API *apigatewayv2.Client, + ec2API *ec2.Client, + dmsClient *databasemigrationservice.Client, + prometheusClient *amp.Client, + storageGatewayAPI *storagegateway.Client, + shieldAPI *shield.Client, +) Client { + return &client{ + logger: logger, + taggingAPI: taggingAPI, + autoscalingAPI: autoscalingAPI, + apiGatewayAPI: apiGatewayAPI, + apiGatewayV2API: apiGatewayV2API, + ec2API: ec2API, + dmsAPI: dmsClient, + prometheusSvcAPI: prometheusClient, + storageGatewayAPI: storageGatewayAPI, + shieldAPI: shieldAPI, + } } -func NewLimitedConcurrencyClient(client Client, maxConcurrency int) Client { - return &limitedConcurrencyClient{ - client: client, - sem: make(chan struct{}, maxConcurrency), +func (c client) GetResources(ctx 
context.Context, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { + svc := config.SupportedServices.GetService(job.Namespace) + var resources []*model.TaggedResource + shouldHaveDiscoveredResources := false + + if len(svc.ResourceFilters) > 0 { + shouldHaveDiscoveredResources = true + filters := make([]string, 0, len(svc.ResourceFilters)) + for _, filter := range svc.ResourceFilters { + filters = append(filters, *filter) + } + var tagFilters []types.TagFilter + if len(job.SearchTags) > 0 { + for i := range job.SearchTags { + // Because everything with the AWS APIs is pointers we need a pointer to the `Key` field from the SearchTag. + // We can't take a pointer to any fields from loop variable or the pointer will always be the same and this logic will be broken. + st := job.SearchTags[i] + + // AWS's GetResources has a TagFilter option which matches the semantics of our SearchTags where all filters must match + // Their value matching implementation is different though so instead of mapping the Key and Value we only map the Keys. + // Their API docs say, "If you don't specify a value for a key, the response returns all resources that are tagged with that key, with any or no value." + // which makes this a safe way to reduce the amount of data we need to filter out. 
+ // https://docs.aws.amazon.com/resourcegroupstagging/latest/APIReference/API_GetResources.html#resourcegrouptagging-GetResources-request-TagFilters + tagFilters = append(tagFilters, types.TagFilter{Key: &st.Key}) + } + } + inputparams := &resourcegroupstaggingapi.GetResourcesInput{ + ResourceTypeFilters: filters, + ResourcesPerPage: aws.Int32(int32(100)), // max allowed value according to API docs + TagFilters: tagFilters, + } + + paginator := resourcegroupstaggingapi.NewGetResourcesPaginator(c.taggingAPI, inputparams, func(options *resourcegroupstaggingapi.GetResourcesPaginatorOptions) { + options.StopOnDuplicateToken = true + }) + for paginator.HasMorePages() { + promutil.ResourceGroupTaggingAPICounter.Inc() + page, err := paginator.NextPage(ctx) + if err != nil { + return nil, err + } + + for _, resourceTagMapping := range page.ResourceTagMappingList { + resource := model.TaggedResource{ + ARN: *resourceTagMapping.ResourceARN, + Namespace: job.Namespace, + Region: region, + Tags: make([]model.Tag, 0, len(resourceTagMapping.Tags)), + } + + for _, t := range resourceTagMapping.Tags { + resource.Tags = append(resource.Tags, model.Tag{Key: *t.Key, Value: *t.Value}) + } + + if resource.FilterThroughTags(job.SearchTags) { + resources = append(resources, &resource) + } else { + c.logger.Debug("Skipping resource because search tags do not match", "arn", resource.ARN) + } + } + } + + c.logger.Debug("GetResourcesPages finished", "total", len(resources)) + } + + if ext, ok := ServiceFilters[svc.Namespace]; ok { + if ext.ResourceFunc != nil { + shouldHaveDiscoveredResources = true + newResources, err := ext.ResourceFunc(ctx, c, job, region) + if err != nil { + return nil, fmt.Errorf("failed to apply ResourceFunc for %s, %w", svc.Namespace, err) + } + resources = append(resources, newResources...) 
+ c.logger.Debug("ResourceFunc finished", "total", len(resources)) + } + + if ext.FilterFunc != nil { + filteredResources, err := ext.FilterFunc(ctx, c, resources) + if err != nil { + return nil, fmt.Errorf("failed to apply FilterFunc for %s, %w", svc.Namespace, err) + } + resources = filteredResources + c.logger.Debug("FilterFunc finished", "total", len(resources)) + } + } + + if shouldHaveDiscoveredResources && len(resources) == 0 { + return nil, ErrExpectedToFindResources } -} -func (c limitedConcurrencyClient) GetResources(ctx context.Context, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { - c.sem <- struct{}{} - res, err := c.client.GetResources(ctx, job, region) - <-c.sem - return res, err + return resources, nil } diff --git a/pkg/clients/tagging/v2/filters.go b/pkg/clients/tagging/filters.go similarity index 99% rename from pkg/clients/tagging/v2/filters.go rename to pkg/clients/tagging/filters.go index 8be43eb62..3dbc00801 100644 --- a/pkg/clients/tagging/v2/filters.go +++ b/pkg/clients/tagging/filters.go @@ -10,7 +10,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -package v2 +package tagging import ( "context" diff --git a/pkg/clients/tagging/iface.go b/pkg/clients/tagging/iface.go new file mode 100644 index 000000000..c43310e64 --- /dev/null +++ b/pkg/clients/tagging/iface.go @@ -0,0 +1,45 @@ +// Copyright 2024 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +package tagging + +import ( + "context" + "errors" + + "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" +) + +type Client interface { + GetResources(ctx context.Context, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) +} + +var ErrExpectedToFindResources = errors.New("expected to discover resources but none were found") + +type limitedConcurrencyClient struct { + client Client + sem chan struct{} +} + +func NewLimitedConcurrencyClient(client Client, maxConcurrency int) Client { + return &limitedConcurrencyClient{ + client: client, + sem: make(chan struct{}, maxConcurrency), + } +} + +func (c limitedConcurrencyClient) GetResources(ctx context.Context, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { + c.sem <- struct{}{} + res, err := c.client.GetResources(ctx, job, region) + <-c.sem + return res, err +} diff --git a/pkg/clients/tagging/implementation_test.go b/pkg/clients/tagging/implementation_test.go deleted file mode 100644 index 12f6a9605..000000000 --- a/pkg/clients/tagging/implementation_test.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package tagging_test - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - v1 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging/v1" - v2 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging/v2" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/config" -) - -func Test_Services_Have_Filters_In_V1_and_V2(t *testing.T) { - for _, service := range config.SupportedServices { - namespace := service.Namespace - t.Run(fmt.Sprintf("%s has filter definitions in v1 and v2", namespace), func(t *testing.T) { - v1Filters, v1Exists := v1.ServiceFilters[namespace] - v2Filters, v2Exists := v2.ServiceFilters[namespace] - - require.Equal(t, v1Exists, v2Exists, "Service filters are only implemented for v1 or v2 but should be implemented for both") - - v1FilterFuncNil := v1Filters.FilterFunc == nil - v2FilterFuncNil := v2Filters.FilterFunc == nil - assert.Equal(t, v1FilterFuncNil, v2FilterFuncNil, "FilterFunc is only implemented for v1 or v2 but should be implemented for both") - - v1ResourceFuncNil := v1Filters.ResourceFunc == nil - v2ResourceFuncNil := v2Filters.ResourceFunc == nil - assert.Equal(t, v1ResourceFuncNil, v2ResourceFuncNil, "ResourceFunc is only implemented for v1 or v2 but should be implemented for both") - }) - } -} diff --git a/pkg/clients/tagging/v1/client.go b/pkg/clients/tagging/v1/client.go deleted file mode 100644 index f0f6624f6..000000000 --- a/pkg/clients/tagging/v1/client.go +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package v1 - -import ( - "context" - "fmt" - "log/slog" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface" - "github.com/aws/aws-sdk-go/service/apigatewayv2/apigatewayv2iface" - "github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface" - "github.com/aws/aws-sdk-go/service/databasemigrationservice/databasemigrationserviceiface" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/prometheusservice/prometheusserviceiface" - "github.com/aws/aws-sdk-go/service/resourcegroupstaggingapi" - "github.com/aws/aws-sdk-go/service/resourcegroupstaggingapi/resourcegroupstaggingapiiface" - "github.com/aws/aws-sdk-go/service/shield/shieldiface" - "github.com/aws/aws-sdk-go/service/storagegateway/storagegatewayiface" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/config" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/promutil" -) - -type client struct { - logger *slog.Logger - taggingAPI resourcegroupstaggingapiiface.ResourceGroupsTaggingAPIAPI - autoscalingAPI autoscalingiface.AutoScalingAPI - apiGatewayAPI apigatewayiface.APIGatewayAPI - apiGatewayV2API apigatewayv2iface.ApiGatewayV2API - ec2API ec2iface.EC2API - dmsAPI databasemigrationserviceiface.DatabaseMigrationServiceAPI - prometheusSvcAPI prometheusserviceiface.PrometheusServiceAPI - storageGatewayAPI 
storagegatewayiface.StorageGatewayAPI - shieldAPI shieldiface.ShieldAPI -} - -func NewClient( - logger *slog.Logger, - taggingAPI resourcegroupstaggingapiiface.ResourceGroupsTaggingAPIAPI, - autoscalingAPI autoscalingiface.AutoScalingAPI, - apiGatewayAPI apigatewayiface.APIGatewayAPI, - apiGatewayV2API apigatewayv2iface.ApiGatewayV2API, - ec2API ec2iface.EC2API, - dmsClient databasemigrationserviceiface.DatabaseMigrationServiceAPI, - prometheusClient prometheusserviceiface.PrometheusServiceAPI, - storageGatewayAPI storagegatewayiface.StorageGatewayAPI, - shieldAPI shieldiface.ShieldAPI, -) tagging.Client { - return &client{ - logger: logger, - taggingAPI: taggingAPI, - autoscalingAPI: autoscalingAPI, - apiGatewayAPI: apiGatewayAPI, - apiGatewayV2API: apiGatewayV2API, - ec2API: ec2API, - dmsAPI: dmsClient, - prometheusSvcAPI: prometheusClient, - storageGatewayAPI: storageGatewayAPI, - shieldAPI: shieldAPI, - } -} - -func (c client) GetResources(ctx context.Context, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { - svc := config.SupportedServices.GetService(job.Namespace) - var resources []*model.TaggedResource - shouldHaveDiscoveredResources := false - - if len(svc.ResourceFilters) > 0 { - shouldHaveDiscoveredResources = true - - var tagFilters []*resourcegroupstaggingapi.TagFilter - if len(job.SearchTags) > 0 { - for i := range job.SearchTags { - // Because everything with the AWS APIs is pointers we need a pointer to the `Key` field from the SearchTag. - // We can't take a pointer to any fields from loop variable or the pointer will always be the same and this logic will be broken. - st := job.SearchTags[i] - - // AWS's GetResources has a TagFilter option which matches the semantics of our SearchTags where all filters must match - // Their value matching implementation is different though so instead of mapping the Key and Value we only map the Keys. 
- // Their API docs say, "If you don't specify a value for a key, the response returns all resources that are tagged with that key, with any or no value." - // which makes this a safe way to reduce the amount of data we need to filter out. - // https://docs.aws.amazon.com/resourcegroupstagging/latest/APIReference/API_GetResources.html#resourcegrouptagging-GetResources-request-TagFilters - tagFilters = append(tagFilters, &resourcegroupstaggingapi.TagFilter{Key: &st.Key}) - } - } - - inputparams := &resourcegroupstaggingapi.GetResourcesInput{ - ResourceTypeFilters: svc.ResourceFilters, - ResourcesPerPage: aws.Int64(100), // max allowed value according to API docs - TagFilters: tagFilters, - } - pageNum := 0 - - err := c.taggingAPI.GetResourcesPagesWithContext(ctx, inputparams, func(page *resourcegroupstaggingapi.GetResourcesOutput, lastPage bool) bool { - pageNum++ - promutil.ResourceGroupTaggingAPICounter.Inc() - - for _, resourceTagMapping := range page.ResourceTagMappingList { - resource := model.TaggedResource{ - ARN: aws.StringValue(resourceTagMapping.ResourceARN), - Namespace: job.Namespace, - Region: region, - Tags: make([]model.Tag, 0, len(resourceTagMapping.Tags)), - } - - for _, t := range resourceTagMapping.Tags { - resource.Tags = append(resource.Tags, model.Tag{Key: *t.Key, Value: *t.Value}) - } - - if resource.FilterThroughTags(job.SearchTags) { - resources = append(resources, &resource) - } else { - c.logger.Debug("Skipping resource because search tags do not match", "arn", resource.ARN) - } - } - return !lastPage - }) - if err != nil { - return nil, err - } - - c.logger.Debug("GetResourcesPages finished", "total", len(resources)) - } - - if ext, ok := ServiceFilters[svc.Namespace]; ok { - if ext.ResourceFunc != nil { - shouldHaveDiscoveredResources = true - newResources, err := ext.ResourceFunc(ctx, c, job, region) - if err != nil { - return nil, fmt.Errorf("failed to apply ResourceFunc for %s, %w", svc.Namespace, err) - } - resources = 
append(resources, newResources...) - c.logger.Debug("ResourceFunc finished", "total", len(resources)) - } - - if ext.FilterFunc != nil { - filteredResources, err := ext.FilterFunc(ctx, c, resources) - if err != nil { - return nil, fmt.Errorf("failed to apply FilterFunc for %s, %w", svc.Namespace, err) - } - resources = filteredResources - c.logger.Debug("FilterFunc finished", "total", len(resources)) - } - } - - if shouldHaveDiscoveredResources && len(resources) == 0 { - return nil, tagging.ErrExpectedToFindResources - } - - return resources, nil -} diff --git a/pkg/clients/tagging/v1/filters.go b/pkg/clients/tagging/v1/filters.go deleted file mode 100644 index b02820450..000000000 --- a/pkg/clients/tagging/v1/filters.go +++ /dev/null @@ -1,372 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v1 - -import ( - "context" - "fmt" - "strings" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/service/apigateway" - "github.com/aws/aws-sdk-go/service/apigatewayv2" - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/aws/aws-sdk-go/service/databasemigrationservice" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/prometheusservice" - "github.com/aws/aws-sdk-go/service/shield" - "github.com/aws/aws-sdk-go/service/storagegateway" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/promutil" -) - -type ServiceFilter struct { - // ResourceFunc can be used to fetch additional resources - ResourceFunc func(context.Context, client, model.DiscoveryJob, string) ([]*model.TaggedResource, error) - - // FilterFunc can be used to the input resources or to drop based on some condition - FilterFunc func(context.Context, client, []*model.TaggedResource) ([]*model.TaggedResource, error) -} - -// ServiceFilters maps a service namespace to (optional) ServiceFilter -var ServiceFilters = map[string]ServiceFilter{ - "AWS/ApiGateway": { - // ApiGateway ARNs use the Id (for v1 REST APIs) and ApiId (for v2 APIs) instead of - // the ApiName (display name). See https://docs.aws.amazon.com/apigateway/latest/developerguide/arn-format-reference.html - // However, in metrics, the ApiId dimension uses the ApiName as value. - // - // Here we use the ApiGateway API to map resource correctly. For backward compatibility, - // in v1 REST APIs we change the ARN to replace the ApiId with ApiName, while for v2 APIs - // we leave the ARN as-is. - FilterFunc: func(ctx context.Context, client client, inputResources []*model.TaggedResource) ([]*model.TaggedResource, error) { - var limit int64 = 500 // max number of results per page. 
default=25, max=500 - const maxPages = 10 - input := apigateway.GetRestApisInput{Limit: &limit} - output := apigateway.GetRestApisOutput{} - var pageNum int - - err := client.apiGatewayAPI.GetRestApisPagesWithContext(ctx, &input, func(page *apigateway.GetRestApisOutput, _ bool) bool { - promutil.APIGatewayAPICounter.Inc() - pageNum++ - output.Items = append(output.Items, page.Items...) - return pageNum <= maxPages - }) - if err != nil { - return nil, fmt.Errorf("error calling apiGatewayAPI.GetRestApisPages, %w", err) - } - - outputV2, err := client.apiGatewayV2API.GetApisWithContext(ctx, &apigatewayv2.GetApisInput{}) - promutil.APIGatewayAPIV2Counter.Inc() - if err != nil { - return nil, fmt.Errorf("error calling apiGatewayAPIv2.GetApis, %w", err) - } - - var outputResources []*model.TaggedResource - for _, resource := range inputResources { - for i, gw := range output.Items { - if strings.HasSuffix(resource.ARN, "/restapis/"+*gw.Id) { - r := resource - r.ARN = strings.ReplaceAll(resource.ARN, *gw.Id, *gw.Name) - outputResources = append(outputResources, r) - output.Items = append(output.Items[:i], output.Items[i+1:]...) - break - } - } - for i, gw := range outputV2.Items { - if strings.HasSuffix(resource.ARN, "/apis/"+*gw.ApiId) { - outputResources = append(outputResources, resource) - outputV2.Items = append(outputV2.Items[:i], outputV2.Items[i+1:]...) 
- break - } - } - } - return outputResources, nil - }, - }, - "AWS/AutoScaling": { - ResourceFunc: func(ctx context.Context, client client, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { - pageNum := 0 - var resources []*model.TaggedResource - err := client.autoscalingAPI.DescribeAutoScalingGroupsPagesWithContext(ctx, &autoscaling.DescribeAutoScalingGroupsInput{}, - func(page *autoscaling.DescribeAutoScalingGroupsOutput, _ bool) bool { - pageNum++ - promutil.AutoScalingAPICounter.Inc() - - for _, asg := range page.AutoScalingGroups { - resource := model.TaggedResource{ - ARN: aws.StringValue(asg.AutoScalingGroupARN), - Namespace: job.Namespace, - Region: region, - } - - for _, t := range asg.Tags { - resource.Tags = append(resource.Tags, model.Tag{Key: *t.Key, Value: *t.Value}) - } - - if resource.FilterThroughTags(job.SearchTags) { - resources = append(resources, &resource) - } - } - return pageNum < 100 - }, - ) - if err != nil { - return nil, fmt.Errorf("error calling autoscalingAPI.DescribeAutoScalingGroups, %w", err) - } - return resources, nil - }, - }, - "AWS/DMS": { - // Append the replication instance identifier to DMS task and instance ARNs - FilterFunc: func(ctx context.Context, client client, inputResources []*model.TaggedResource) ([]*model.TaggedResource, error) { - if len(inputResources) == 0 { - return inputResources, nil - } - - replicationInstanceIdentifiers := make(map[string]string) - pageNum := 0 - if err := client.dmsAPI.DescribeReplicationInstancesPagesWithContext(ctx, nil, - func(page *databasemigrationservice.DescribeReplicationInstancesOutput, _ bool) bool { - pageNum++ - promutil.DmsAPICounter.Inc() - - for _, instance := range page.ReplicationInstances { - replicationInstanceIdentifiers[aws.StringValue(instance.ReplicationInstanceArn)] = aws.StringValue(instance.ReplicationInstanceIdentifier) - } - - return pageNum < 100 - }, - ); err != nil { - return nil, fmt.Errorf("error calling 
dmsAPI.DescribeReplicationInstances, %w", err) - } - pageNum = 0 - if err := client.dmsAPI.DescribeReplicationTasksPagesWithContext(ctx, nil, - func(page *databasemigrationservice.DescribeReplicationTasksOutput, _ bool) bool { - pageNum++ - promutil.DmsAPICounter.Inc() - - for _, task := range page.ReplicationTasks { - taskInstanceArn := aws.StringValue(task.ReplicationInstanceArn) - if instanceIdentifier, ok := replicationInstanceIdentifiers[taskInstanceArn]; ok { - replicationInstanceIdentifiers[aws.StringValue(task.ReplicationTaskArn)] = instanceIdentifier - } - } - - return pageNum < 100 - }, - ); err != nil { - return nil, fmt.Errorf("error calling dmsAPI.DescribeReplicationTasks, %w", err) - } - - var outputResources []*model.TaggedResource - for _, resource := range inputResources { - r := resource - // Append the replication instance identifier to replication instance and task ARNs - if instanceIdentifier, ok := replicationInstanceIdentifiers[r.ARN]; ok { - r.ARN = fmt.Sprintf("%s/%s", r.ARN, instanceIdentifier) - } - outputResources = append(outputResources, r) - } - return outputResources, nil - }, - }, - "AWS/EC2Spot": { - ResourceFunc: func(ctx context.Context, client client, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { - pageNum := 0 - var resources []*model.TaggedResource - err := client.ec2API.DescribeSpotFleetRequestsPagesWithContext(ctx, &ec2.DescribeSpotFleetRequestsInput{}, - func(page *ec2.DescribeSpotFleetRequestsOutput, _ bool) bool { - pageNum++ - promutil.Ec2APICounter.Inc() - - for _, ec2Spot := range page.SpotFleetRequestConfigs { - resource := model.TaggedResource{ - ARN: aws.StringValue(ec2Spot.SpotFleetRequestId), - Namespace: job.Namespace, - Region: region, - } - - for _, t := range ec2Spot.Tags { - resource.Tags = append(resource.Tags, model.Tag{Key: *t.Key, Value: *t.Value}) - } - - if resource.FilterThroughTags(job.SearchTags) { - resources = append(resources, &resource) - } - } - return pageNum < 100 - 
}, - ) - if err != nil { - return nil, fmt.Errorf("error calling describing ec2API.DescribeSpotFleetRequests, %w", err) - } - return resources, nil - }, - }, - "AWS/Prometheus": { - ResourceFunc: func(ctx context.Context, client client, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { - pageNum := 0 - var resources []*model.TaggedResource - err := client.prometheusSvcAPI.ListWorkspacesPagesWithContext(ctx, &prometheusservice.ListWorkspacesInput{}, - func(page *prometheusservice.ListWorkspacesOutput, _ bool) bool { - pageNum++ - promutil.ManagedPrometheusAPICounter.Inc() - - for _, ws := range page.Workspaces { - resource := model.TaggedResource{ - ARN: aws.StringValue(ws.Arn), - Namespace: job.Namespace, - Region: region, - } - - for key, value := range ws.Tags { - resource.Tags = append(resource.Tags, model.Tag{Key: key, Value: *value}) - } - - if resource.FilterThroughTags(job.SearchTags) { - resources = append(resources, &resource) - } - } - return pageNum < 100 - }, - ) - if err != nil { - return nil, fmt.Errorf("error while calling prometheusSvcAPI.ListWorkspaces, %w", err) - } - return resources, nil - }, - }, - "AWS/StorageGateway": { - ResourceFunc: func(ctx context.Context, client client, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { - pageNum := 0 - var resources []*model.TaggedResource - err := client.storageGatewayAPI.ListGatewaysPagesWithContext(ctx, &storagegateway.ListGatewaysInput{}, - func(page *storagegateway.ListGatewaysOutput, _ bool) bool { - pageNum++ - promutil.StoragegatewayAPICounter.Inc() - - for _, gwa := range page.Gateways { - resource := model.TaggedResource{ - ARN: fmt.Sprintf("%s/%s", *gwa.GatewayId, *gwa.GatewayName), - Namespace: job.Namespace, - Region: region, - } - - tagsRequest := &storagegateway.ListTagsForResourceInput{ - ResourceARN: gwa.GatewayARN, - } - tagsResponse, _ := client.storageGatewayAPI.ListTagsForResource(tagsRequest) - promutil.StoragegatewayAPICounter.Inc() 
- - for _, t := range tagsResponse.Tags { - resource.Tags = append(resource.Tags, model.Tag{Key: *t.Key, Value: *t.Value}) - } - - if resource.FilterThroughTags(job.SearchTags) { - resources = append(resources, &resource) - } - } - - return pageNum < 100 - }, - ) - if err != nil { - return nil, fmt.Errorf("error calling storageGatewayAPI.ListGateways, %w", err) - } - return resources, nil - }, - }, - "AWS/TransitGateway": { - ResourceFunc: func(ctx context.Context, client client, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { - pageNum := 0 - var resources []*model.TaggedResource - err := client.ec2API.DescribeTransitGatewayAttachmentsPagesWithContext(ctx, &ec2.DescribeTransitGatewayAttachmentsInput{}, - func(page *ec2.DescribeTransitGatewayAttachmentsOutput, _ bool) bool { - pageNum++ - promutil.Ec2APICounter.Inc() - - for _, tgwa := range page.TransitGatewayAttachments { - resource := model.TaggedResource{ - ARN: fmt.Sprintf("%s/%s", *tgwa.TransitGatewayId, *tgwa.TransitGatewayAttachmentId), - Namespace: job.Namespace, - Region: region, - } - - for _, t := range tgwa.Tags { - resource.Tags = append(resource.Tags, model.Tag{Key: *t.Key, Value: *t.Value}) - } - - if resource.FilterThroughTags(job.SearchTags) { - resources = append(resources, &resource) - } - } - return pageNum < 100 - }, - ) - if err != nil { - return nil, fmt.Errorf("error calling ec2API.DescribeTransitGatewayAttachments, %w", err) - } - return resources, nil - }, - }, - "AWS/DDoSProtection": { - // Resource discovery only targets the protections, protections are global, so they will only be discoverable in us-east-1. - // Outside us-east-1 no resources are going to be found. We use the shield.ListProtections API to get the protections + - // protected resources to add to the tagged resources. This data is eventually usable for joining with metrics. 
- ResourceFunc: func(ctx context.Context, c client, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { - var output []*model.TaggedResource - pageNum := 0 - // Default page size is only 20 which can easily lead to throttling - input := &shield.ListProtectionsInput{MaxResults: aws.Int64(1000)} - err := c.shieldAPI.ListProtectionsPagesWithContext(ctx, input, func(page *shield.ListProtectionsOutput, _ bool) bool { - promutil.ShieldAPICounter.Inc() - for _, protection := range page.Protections { - protectedResourceArn := *protection.ResourceArn - protectionArn := *protection.ProtectionArn - protectedResource, err := arn.Parse(protectedResourceArn) - if err != nil { - continue - } - - // Shield covers regional services, - // EC2 (arn:aws:ec2:::eip-allocation/*) - // load balancers (arn:aws:elasticloadbalancing:::loadbalancer:*) - // where the region of the protectedResource ARN should match the region for the job to prevent - // duplicating resources across all regions - // Shield also covers other global services, - // global accelerator (arn:aws:globalaccelerator:::accelerator/*) - // route53 (arn:aws:route53:::hostedzone/*) - // where the protectedResource contains no region. 
Just like other global services the metrics for - // these land in us-east-1 so any protected resource without a region should be added when the job - // is for us-east-1 - if protectedResource.Region == region || (protectedResource.Region == "" && region == "us-east-1") { - taggedResource := &model.TaggedResource{ - ARN: protectedResourceArn, - Namespace: job.Namespace, - Region: region, - Tags: []model.Tag{{Key: "ProtectionArn", Value: protectionArn}}, - } - output = append(output, taggedResource) - } - } - return pageNum < 100 - }) - if err != nil { - return nil, fmt.Errorf("error calling shiled.ListProtections, %w", err) - } - return output, nil - }, - }, -} diff --git a/pkg/clients/tagging/v1/filters_test.go b/pkg/clients/tagging/v1/filters_test.go deleted file mode 100644 index 83108ec7d..000000000 --- a/pkg/clients/tagging/v1/filters_test.go +++ /dev/null @@ -1,463 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v1 - -import ( - "context" - "reflect" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/apigateway" - "github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface" - "github.com/aws/aws-sdk-go/service/apigatewayv2" - "github.com/aws/aws-sdk-go/service/apigatewayv2/apigatewayv2iface" - "github.com/aws/aws-sdk-go/service/databasemigrationservice" - "github.com/aws/aws-sdk-go/service/databasemigrationservice/databasemigrationserviceiface" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/config" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" -) - -func TestValidServiceNames(t *testing.T) { - for svc, filter := range ServiceFilters { - if config.SupportedServices.GetService(svc) == nil { - t.Errorf("invalid service name '%s'", svc) - t.Fail() - } - - if filter.FilterFunc == nil && filter.ResourceFunc == nil { - t.Errorf("no filter functions defined for service name '%s'", svc) - t.FailNow() - } - } -} - -func TestApiGatewayFilterFunc(t *testing.T) { - tests := []struct { - name string - iface client - inputResources []*model.TaggedResource - outputResources []*model.TaggedResource - }{ - { - "api gateway resources skip stages", - client{ - apiGatewayAPI: apiGatewayClient{ - getRestApisOutput: &apigateway.GetRestApisOutput{ - Items: []*apigateway.RestApi{ - { - ApiKeySource: nil, - BinaryMediaTypes: nil, - CreatedDate: nil, - Description: nil, - DisableExecuteApiEndpoint: nil, - EndpointConfiguration: nil, - Id: aws.String("gwid1234"), - MinimumCompressionSize: nil, - Name: aws.String("apiname"), - Policy: nil, - Tags: nil, - Version: nil, - Warnings: nil, - }, - }, - Position: nil, - }, - }, - apiGatewayV2API: apiGatewayV2Client{ - getRestApisOutput: &apigatewayv2.GetApisOutput{ - Items: []*apigatewayv2.Api{}, - }, - }, - }, - []*model.TaggedResource{ - { - ARN: "arn:aws:apigateway:us-east-1::/restapis/gwid1234/stages/main", - 
Namespace: "apigateway", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value", - }, - }, - }, - { - ARN: "arn:aws:apigateway:us-east-1::/restapis/gwid1234", - Namespace: "apigateway", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 2", - }, - }, - }, - }, - []*model.TaggedResource{ - { - ARN: "arn:aws:apigateway:us-east-1::/restapis/apiname", - Namespace: "apigateway", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 2", - }, - }, - }, - }, - }, - { - "api gateway v2", - client{ - apiGatewayAPI: apiGatewayClient{ - getRestApisOutput: &apigateway.GetRestApisOutput{ - Items: []*apigateway.RestApi{}, - }, - }, - apiGatewayV2API: apiGatewayV2Client{ - getRestApisOutput: &apigatewayv2.GetApisOutput{ - Items: []*apigatewayv2.Api{ - { - CreatedDate: nil, - Description: nil, - DisableExecuteApiEndpoint: nil, - ApiId: aws.String("gwid9876"), - Name: aws.String("apiv2name"), - Tags: nil, - Version: nil, - Warnings: nil, - }, - }, - NextToken: nil, - }, - }, - }, - []*model.TaggedResource{ - { - ARN: "arn:aws:apigateway:us-east-1::/apis/gwid9876/stages/$default", - Namespace: "apigateway", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value", - }, - }, - }, - { - ARN: "arn:aws:apigateway:us-east-1::/apis/gwid9876", - Namespace: "apigateway", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 2", - }, - }, - }, - }, - []*model.TaggedResource{ - { - ARN: "arn:aws:apigateway:us-east-1::/apis/gwid9876", - Namespace: "apigateway", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 2", - }, - }, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - apigateway := ServiceFilters["AWS/ApiGateway"] - - outputResources, err := apigateway.FilterFunc(context.Background(), test.iface, test.inputResources) - if err != nil { - t.Logf("Error from FilterFunc: %v", err) - t.FailNow() - } 
- if len(outputResources) != len(test.outputResources) { - t.Logf("len(outputResources) = %d, want %d", len(outputResources), len(test.outputResources)) - t.Fail() - } - for i, resource := range outputResources { - if len(test.outputResources) <= i { - break - } - wantResource := *test.outputResources[i] - if !reflect.DeepEqual(*resource, wantResource) { - t.Errorf("outputResources[%d] = %+v, want %+v", i, *resource, wantResource) - } - } - }) - } -} - -func TestDMSFilterFunc(t *testing.T) { - tests := []struct { - name string - iface client - inputResources []*model.TaggedResource - outputResources []*model.TaggedResource - }{ - { - "empty input resources", - client{}, - []*model.TaggedResource{}, - []*model.TaggedResource{}, - }, - { - "replication tasks and instances", - client{ - dmsAPI: dmsClient{ - describeReplicationInstancesOutput: &databasemigrationservice.DescribeReplicationInstancesOutput{ - ReplicationInstances: []*databasemigrationservice.ReplicationInstance{ - { - ReplicationInstanceArn: aws.String("arn:aws:dms:us-east-1:123123123123:rep:ABCDEFG1234567890"), - ReplicationInstanceIdentifier: aws.String("repl-instance-identifier-1"), - }, - { - ReplicationInstanceArn: aws.String("arn:aws:dms:us-east-1:123123123123:rep:ZZZZZZZZZZZZZZZZZ"), - ReplicationInstanceIdentifier: aws.String("repl-instance-identifier-2"), - }, - { - ReplicationInstanceArn: aws.String("arn:aws:dms:us-east-1:123123123123:rep:YYYYYYYYYYYYYYYYY"), - ReplicationInstanceIdentifier: aws.String("repl-instance-identifier-3"), - }, - }, - }, - describeReplicationTasksOutput: &databasemigrationservice.DescribeReplicationTasksOutput{ - ReplicationTasks: []*databasemigrationservice.ReplicationTask{ - { - ReplicationTaskArn: aws.String("arn:aws:dms:us-east-1:123123123123:task:9999999999999999"), - ReplicationInstanceArn: aws.String("arn:aws:dms:us-east-1:123123123123:rep:ZZZZZZZZZZZZZZZZZ"), - }, - { - ReplicationTaskArn: aws.String("arn:aws:dms:us-east-1:123123123123:task:2222222222222222"), 
- ReplicationInstanceArn: aws.String("arn:aws:dms:us-east-1:123123123123:rep:ZZZZZZZZZZZZZZZZZ"), - }, - { - ReplicationTaskArn: aws.String("arn:aws:dms:us-east-1:123123123123:task:3333333333333333"), - ReplicationInstanceArn: aws.String("arn:aws:dms:us-east-1:123123123123:rep:WWWWWWWWWWWWWWWWW"), - }, - }, - }, - }, - }, - []*model.TaggedResource{ - { - ARN: "arn:aws:dms:us-east-1:123123123123:rep:ABCDEFG1234567890", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value", - }, - }, - }, - { - ARN: "arn:aws:dms:us-east-1:123123123123:rep:WXYZ987654321", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 2", - }, - }, - }, - { - ARN: "arn:aws:dms:us-east-1:123123123123:task:9999999999999999", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 3", - }, - }, - }, - { - ARN: "arn:aws:dms:us-east-1:123123123123:task:5555555555555555", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 4", - }, - }, - }, - { - ARN: "arn:aws:dms:us-east-1:123123123123:subgrp:demo-subgrp", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 5", - }, - }, - }, - { - ARN: "arn:aws:dms:us-east-1:123123123123:endpoint:1111111111111111", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 6", - }, - }, - }, - }, - []*model.TaggedResource{ - { - ARN: "arn:aws:dms:us-east-1:123123123123:rep:ABCDEFG1234567890/repl-instance-identifier-1", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value", - }, - }, - }, - { - ARN: "arn:aws:dms:us-east-1:123123123123:rep:WXYZ987654321", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 2", - }, - }, - }, - { - ARN: 
"arn:aws:dms:us-east-1:123123123123:task:9999999999999999/repl-instance-identifier-2", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 3", - }, - }, - }, - { - ARN: "arn:aws:dms:us-east-1:123123123123:task:5555555555555555", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 4", - }, - }, - }, - { - ARN: "arn:aws:dms:us-east-1:123123123123:subgrp:demo-subgrp", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 5", - }, - }, - }, - { - ARN: "arn:aws:dms:us-east-1:123123123123:endpoint:1111111111111111", - Namespace: "dms", - Region: "us-east-1", - Tags: []model.Tag{ - { - Key: "Test", - Value: "Value 6", - }, - }, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - dms := ServiceFilters["AWS/DMS"] - - outputResources, err := dms.FilterFunc(context.Background(), test.iface, test.inputResources) - if err != nil { - t.Logf("Error from FilterFunc: %v", err) - t.FailNow() - } - if len(outputResources) != len(test.outputResources) { - t.Logf("len(outputResources) = %d, want %d", len(outputResources), len(test.outputResources)) - t.Fail() - } - for i, resource := range outputResources { - if len(test.outputResources) <= i { - break - } - wantResource := *test.outputResources[i] - if !reflect.DeepEqual(*resource, wantResource) { - t.Errorf("outputResources[%d] = %+v, want %+v", i, *resource, wantResource) - } - } - }) - } -} - -type apiGatewayClient struct { - apigatewayiface.APIGatewayAPI - getRestApisOutput *apigateway.GetRestApisOutput -} - -func (apigateway apiGatewayClient) GetRestApisPagesWithContext(_ aws.Context, _ *apigateway.GetRestApisInput, fn func(*apigateway.GetRestApisOutput, bool) bool, _ ...request.Option) error { - fn(apigateway.getRestApisOutput, true) - return nil -} - -type apiGatewayV2Client struct { - apigatewayv2iface.ApiGatewayV2API - getRestApisOutput 
*apigatewayv2.GetApisOutput -} - -func (apigateway apiGatewayV2Client) GetApisWithContext(_ aws.Context, _ *apigatewayv2.GetApisInput, _ ...request.Option) (*apigatewayv2.GetApisOutput, error) { - return apigateway.getRestApisOutput, nil -} - -type dmsClient struct { - databasemigrationserviceiface.DatabaseMigrationServiceAPI - describeReplicationInstancesOutput *databasemigrationservice.DescribeReplicationInstancesOutput - describeReplicationTasksOutput *databasemigrationservice.DescribeReplicationTasksOutput -} - -func (dms dmsClient) DescribeReplicationInstancesPagesWithContext(_ aws.Context, _ *databasemigrationservice.DescribeReplicationInstancesInput, fn func(*databasemigrationservice.DescribeReplicationInstancesOutput, bool) bool, _ ...request.Option) error { - fn(dms.describeReplicationInstancesOutput, true) - return nil -} - -func (dms dmsClient) DescribeReplicationTasksPagesWithContext(_ aws.Context, _ *databasemigrationservice.DescribeReplicationTasksInput, fn func(*databasemigrationservice.DescribeReplicationTasksOutput, bool) bool, _ ...request.Option) error { - fn(dms.describeReplicationTasksOutput, true) - return nil -} diff --git a/pkg/clients/tagging/v2/client.go b/pkg/clients/tagging/v2/client.go deleted file mode 100644 index 5f0d704ef..000000000 --- a/pkg/clients/tagging/v2/client.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v2 - -import ( - "context" - "fmt" - "log/slog" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/amp" - "github.com/aws/aws-sdk-go-v2/service/apigateway" - "github.com/aws/aws-sdk-go-v2/service/apigatewayv2" - "github.com/aws/aws-sdk-go-v2/service/autoscaling" - "github.com/aws/aws-sdk-go-v2/service/databasemigrationservice" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" - "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi/types" - "github.com/aws/aws-sdk-go-v2/service/shield" - "github.com/aws/aws-sdk-go-v2/service/storagegateway" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/config" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/promutil" -) - -type client struct { - logger *slog.Logger - taggingAPI *resourcegroupstaggingapi.Client - autoscalingAPI *autoscaling.Client - apiGatewayAPI *apigateway.Client - apiGatewayV2API *apigatewayv2.Client - ec2API *ec2.Client - dmsAPI *databasemigrationservice.Client - prometheusSvcAPI *amp.Client - storageGatewayAPI *storagegateway.Client - shieldAPI *shield.Client -} - -func NewClient( - logger *slog.Logger, - taggingAPI *resourcegroupstaggingapi.Client, - autoscalingAPI *autoscaling.Client, - apiGatewayAPI *apigateway.Client, - apiGatewayV2API *apigatewayv2.Client, - ec2API *ec2.Client, - dmsClient *databasemigrationservice.Client, - prometheusClient *amp.Client, - storageGatewayAPI *storagegateway.Client, - shieldAPI *shield.Client, -) tagging.Client { - return &client{ - logger: logger, - taggingAPI: taggingAPI, - autoscalingAPI: autoscalingAPI, - apiGatewayAPI: apiGatewayAPI, - apiGatewayV2API: apiGatewayV2API, - ec2API: ec2API, - dmsAPI: dmsClient, - prometheusSvcAPI: prometheusClient, - 
storageGatewayAPI: storageGatewayAPI, - shieldAPI: shieldAPI, - } -} - -func (c client) GetResources(ctx context.Context, job model.DiscoveryJob, region string) ([]*model.TaggedResource, error) { - svc := config.SupportedServices.GetService(job.Namespace) - var resources []*model.TaggedResource - shouldHaveDiscoveredResources := false - - if len(svc.ResourceFilters) > 0 { - shouldHaveDiscoveredResources = true - filters := make([]string, 0, len(svc.ResourceFilters)) - for _, filter := range svc.ResourceFilters { - filters = append(filters, *filter) - } - var tagFilters []types.TagFilter - if len(job.SearchTags) > 0 { - for i := range job.SearchTags { - // Because everything with the AWS APIs is pointers we need a pointer to the `Key` field from the SearchTag. - // We can't take a pointer to any fields from loop variable or the pointer will always be the same and this logic will be broken. - st := job.SearchTags[i] - - // AWS's GetResources has a TagFilter option which matches the semantics of our SearchTags where all filters must match - // Their value matching implementation is different though so instead of mapping the Key and Value we only map the Keys. - // Their API docs say, "If you don't specify a value for a key, the response returns all resources that are tagged with that key, with any or no value." - // which makes this a safe way to reduce the amount of data we need to filter out. 
- // https://docs.aws.amazon.com/resourcegroupstagging/latest/APIReference/API_GetResources.html#resourcegrouptagging-GetResources-request-TagFilters - tagFilters = append(tagFilters, types.TagFilter{Key: &st.Key}) - } - } - inputparams := &resourcegroupstaggingapi.GetResourcesInput{ - ResourceTypeFilters: filters, - ResourcesPerPage: aws.Int32(int32(100)), // max allowed value according to API docs - TagFilters: tagFilters, - } - - paginator := resourcegroupstaggingapi.NewGetResourcesPaginator(c.taggingAPI, inputparams, func(options *resourcegroupstaggingapi.GetResourcesPaginatorOptions) { - options.StopOnDuplicateToken = true - }) - for paginator.HasMorePages() { - promutil.ResourceGroupTaggingAPICounter.Inc() - page, err := paginator.NextPage(ctx) - if err != nil { - return nil, err - } - - for _, resourceTagMapping := range page.ResourceTagMappingList { - resource := model.TaggedResource{ - ARN: *resourceTagMapping.ResourceARN, - Namespace: job.Namespace, - Region: region, - Tags: make([]model.Tag, 0, len(resourceTagMapping.Tags)), - } - - for _, t := range resourceTagMapping.Tags { - resource.Tags = append(resource.Tags, model.Tag{Key: *t.Key, Value: *t.Value}) - } - - if resource.FilterThroughTags(job.SearchTags) { - resources = append(resources, &resource) - } else { - c.logger.Debug("Skipping resource because search tags do not match", "arn", resource.ARN) - } - } - } - - c.logger.Debug("GetResourcesPages finished", "total", len(resources)) - } - - if ext, ok := ServiceFilters[svc.Namespace]; ok { - if ext.ResourceFunc != nil { - shouldHaveDiscoveredResources = true - newResources, err := ext.ResourceFunc(ctx, c, job, region) - if err != nil { - return nil, fmt.Errorf("failed to apply ResourceFunc for %s, %w", svc.Namespace, err) - } - resources = append(resources, newResources...) 
- c.logger.Debug("ResourceFunc finished", "total", len(resources)) - } - - if ext.FilterFunc != nil { - filteredResources, err := ext.FilterFunc(ctx, c, resources) - if err != nil { - return nil, fmt.Errorf("failed to apply FilterFunc for %s, %w", svc.Namespace, err) - } - resources = filteredResources - c.logger.Debug("FilterFunc finished", "total", len(resources)) - } - } - - if shouldHaveDiscoveredResources && len(resources) == 0 { - return nil, tagging.ErrExpectedToFindResources - } - - return resources, nil -} diff --git a/pkg/clients/v1/factory.go b/pkg/clients/v1/factory.go deleted file mode 100644 index d6879ddd5..000000000 --- a/pkg/clients/v1/factory.go +++ /dev/null @@ -1,542 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v1 - -import ( - "context" - "log/slog" - "os" - "sync" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/apigateway" - "github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface" - "github.com/aws/aws-sdk-go/service/apigatewayv2" - "github.com/aws/aws-sdk-go/service/apigatewayv2/apigatewayv2iface" - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface" - "github.com/aws/aws-sdk-go/service/cloudwatch" - "github.com/aws/aws-sdk-go/service/databasemigrationservice" - "github.com/aws/aws-sdk-go/service/databasemigrationservice/databasemigrationserviceiface" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/iam/iamiface" - "github.com/aws/aws-sdk-go/service/prometheusservice" - "github.com/aws/aws-sdk-go/service/prometheusservice/prometheusserviceiface" - "github.com/aws/aws-sdk-go/service/resourcegroupstaggingapi" - "github.com/aws/aws-sdk-go/service/shield" - "github.com/aws/aws-sdk-go/service/shield/shieldiface" - "github.com/aws/aws-sdk-go/service/storagegateway" - "github.com/aws/aws-sdk-go/service/storagegateway/storagegatewayiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" - "go.uber.org/atomic" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/account" - account_v1 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/account/v1" - cloudwatch_client "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" - cloudwatch_v1 
"github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch/v1" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging" - tagging_v1 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging/v1" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" -) - -type CachingFactory struct { - stsRegion string - session *session.Session - endpointResolver endpoints.ResolverFunc - stscache map[model.Role]stsiface.STSAPI - iamcache map[model.Role]iamiface.IAMAPI - clients map[model.Role]map[string]*cachedClients - cleared *atomic.Bool - refreshed *atomic.Bool - mu sync.Mutex - fips bool - logger *slog.Logger -} - -type cachedClients struct { - // if we know that this job is only used for static - // then we don't have to construct as many cached connections - // later on - onlyStatic bool - cloudwatch cloudwatch_client.Client - tagging tagging.Client - account account.Client -} - -// Ensure the struct properly implements the interface -var _ clients.Factory = &CachingFactory{} - -// NewFactory creates a new client factory to use when fetching data from AWS with sdk v1 -func NewFactory(logger *slog.Logger, jobsCfg model.JobsConfig, fips bool) *CachingFactory { - stscache := map[model.Role]stsiface.STSAPI{} - iamcache := map[model.Role]iamiface.IAMAPI{} - cache := map[model.Role]map[string]*cachedClients{} - - for _, discoveryJob := range jobsCfg.DiscoveryJobs { - for _, role := range discoveryJob.Roles { - if _, ok := stscache[role]; !ok { - stscache[role] = nil - } - if _, ok := iamcache[role]; !ok { - iamcache[role] = nil - } - if _, ok := cache[role]; !ok { - cache[role] = map[string]*cachedClients{} - } - for _, region := range discoveryJob.Regions { - cache[role][region] = &cachedClients{} - } - } - } - - for _, staticJob := range jobsCfg.StaticJobs { - for _, role := range staticJob.Roles { - if _, ok := stscache[role]; !ok { - stscache[role] = nil - } - if _, 
ok := iamcache[role]; !ok { - iamcache[role] = nil - } - - if _, ok := cache[role]; !ok { - cache[role] = map[string]*cachedClients{} - } - - for _, region := range staticJob.Regions { - // Only write a new region in if the region does not exist - if _, ok := cache[role][region]; !ok { - cache[role][region] = &cachedClients{ - onlyStatic: true, - } - } - } - } - } - - for _, customNamespaceJob := range jobsCfg.CustomNamespaceJobs { - for _, role := range customNamespaceJob.Roles { - if _, ok := stscache[role]; !ok { - stscache[role] = nil - } - if _, ok := iamcache[role]; !ok { - iamcache[role] = nil - } - - if _, ok := cache[role]; !ok { - cache[role] = map[string]*cachedClients{} - } - - for _, region := range customNamespaceJob.Regions { - // Only write a new region in if the region does not exist - if _, ok := cache[role][region]; !ok { - cache[role][region] = &cachedClients{ - onlyStatic: true, - } - } - } - } - } - - endpointResolver := endpoints.DefaultResolver().EndpointFor - - endpointURLOverride := os.Getenv("AWS_ENDPOINT_URL") - if endpointURLOverride != "" { - // allow override of all endpoints for local testing - endpointResolver = func(_ string, _ string, _ ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - return endpoints.ResolvedEndpoint{ - URL: endpointURLOverride, - }, nil - } - } - - return &CachingFactory{ - stsRegion: jobsCfg.StsRegion, - session: nil, - endpointResolver: endpointResolver, - stscache: stscache, - iamcache: iamcache, - clients: cache, - fips: fips, - cleared: atomic.NewBool(false), - refreshed: atomic.NewBool(false), - logger: logger, - } -} - -func (c *CachingFactory) Clear() { - if c.cleared.Load() { - return - } - - c.mu.Lock() - defer c.mu.Unlock() - - if c.cleared.Load() { - return - } - - for role := range c.stscache { - c.stscache[role] = nil - } - - for role := range c.iamcache { - c.iamcache[role] = nil - } - - for role, regions := range c.clients { - for region := range regions { - cachedClient := 
c.clients[role][region] - cachedClient.account = nil - cachedClient.cloudwatch = nil - cachedClient.tagging = nil - } - } - c.cleared.Store(true) - c.refreshed.Store(false) -} - -func (c *CachingFactory) Refresh() { - if c.refreshed.Load() { - return - } - - c.mu.Lock() - defer c.mu.Unlock() - // Double check Refresh wasn't called concurrently - if c.refreshed.Load() { - return - } - - // sessions really only need to be constructed once at runtime - if c.session == nil { - c.session = createAWSSession(c.endpointResolver, c.logger) - } - - for role := range c.stscache { - c.stscache[role] = createStsSession(c.session, role, c.stsRegion, c.fips, c.logger) - } - - for role := range c.iamcache { - c.iamcache[role] = createIamSession(c.session, role, c.fips, c.logger) - } - - for role, regions := range c.clients { - for region := range regions { - cachedClient := c.clients[role][region] - // if the role is just used in static jobs, then we - // can skip creating other sessions and potentially running - // into permissions errors or taking up needless cycles - cachedClient.cloudwatch = createCloudWatchClient(c.logger, c.session, ®ion, role, c.fips) - if cachedClient.onlyStatic { - continue - } - cachedClient.tagging = createTaggingClient(c.logger, c.session, ®ion, role, c.fips) - cachedClient.account = createAccountClient(c.logger, c.stscache[role], c.iamcache[role]) - } - } - - c.cleared.Store(false) - c.refreshed.Store(true) -} - -func createCloudWatchClient(logger *slog.Logger, s *session.Session, region *string, role model.Role, fips bool) cloudwatch_client.Client { - return cloudwatch_v1.NewClient( - logger, - createCloudwatchSession(s, region, role, fips, logger), - ) -} - -func createTaggingClient(logger *slog.Logger, session *session.Session, region *string, role model.Role, fips bool) tagging.Client { - // The createSession function for a service which does not support FIPS does not take a fips parameter - // This currently applies to createTagSession(Resource 
Groups Tagging), ASG (EC2 autoscaling), and Prometheus (Amazon Managed Prometheus) - // AWS FIPS Reference: https://aws.amazon.com/compliance/fips/ - return tagging_v1.NewClient( - logger, - createTagSession(session, region, role, logger), - createASGSession(session, region, role, logger), - createAPIGatewaySession(session, region, role, fips, logger), - createAPIGatewayV2Session(session, region, role, fips, logger), - createEC2Session(session, region, role, fips, logger), - createDMSSession(session, region, role, fips, logger), - createPrometheusSession(session, region, role, logger), - createStorageGatewaySession(session, region, role, fips, logger), - createShieldSession(session, region, role, fips, logger), - ) -} - -func createAccountClient(logger *slog.Logger, sts stsiface.STSAPI, iam iamiface.IAMAPI) account.Client { - return account_v1.NewClient(logger, sts, iam) -} - -func (c *CachingFactory) GetCloudwatchClient(region string, role model.Role, concurrency cloudwatch_client.ConcurrencyConfig) cloudwatch_client.Client { - if !c.refreshed.Load() { - // if we have not refreshed then we need to lock in case we are accessing concurrently - c.mu.Lock() - defer c.mu.Unlock() - } - if client := c.clients[role][region].cloudwatch; client != nil { - return cloudwatch_client.NewLimitedConcurrencyClient(client, concurrency.NewLimiter()) - } - c.clients[role][region].cloudwatch = createCloudWatchClient(c.logger, c.session, ®ion, role, c.fips) - return cloudwatch_client.NewLimitedConcurrencyClient(c.clients[role][region].cloudwatch, concurrency.NewLimiter()) -} - -func (c *CachingFactory) GetTaggingClient(region string, role model.Role, concurrencyLimit int) tagging.Client { - if !c.refreshed.Load() { - // if we have not refreshed then we need to lock in case we are accessing concurrently - c.mu.Lock() - defer c.mu.Unlock() - } - if client := c.clients[role][region].tagging; client != nil { - return tagging.NewLimitedConcurrencyClient(client, concurrencyLimit) - } - 
c.clients[role][region].tagging = createTaggingClient(c.logger, c.session, ®ion, role, c.fips) - return tagging.NewLimitedConcurrencyClient(c.clients[role][region].tagging, concurrencyLimit) -} - -func (c *CachingFactory) GetAccountClient(region string, role model.Role) account.Client { - if !c.refreshed.Load() { - // if we have not refreshed then we need to lock in case we are accessing concurrently - c.mu.Lock() - defer c.mu.Unlock() - } - if client := c.clients[role][region].account; client != nil { - return client - } - c.clients[role][region].account = createAccountClient(c.logger, c.stscache[role], c.iamcache[role]) - return c.clients[role][region].account -} - -func setExternalID(ID string) func(p *stscreds.AssumeRoleProvider) { - return func(p *stscreds.AssumeRoleProvider) { - if ID != "" { - p.ExternalID = aws.String(ID) - } - } -} - -func setSTSCreds(sess *session.Session, config *aws.Config, role model.Role) *aws.Config { - if role.RoleArn != "" { - config.Credentials = stscreds.NewCredentials( - sess, role.RoleArn, setExternalID(role.ExternalID)) - } - return config -} - -func getAwsRetryer() aws.RequestRetryer { - return client.DefaultRetryer{ - NumMaxRetries: 5, - // MaxThrottleDelay and MinThrottleDelay used for throttle errors - MaxThrottleDelay: 10 * time.Second, - MinThrottleDelay: 1 * time.Second, - // For other errors - MaxRetryDelay: 3 * time.Second, - MinRetryDelay: 1 * time.Second, - } -} - -func createAWSSession(resolver endpoints.ResolverFunc, logger *slog.Logger) *session.Session { - config := aws.Config{ - CredentialsChainVerboseErrors: aws.Bool(true), - EndpointResolver: resolver, - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - Config: config, - })) - return sess -} - -func createStsSession(sess *session.Session, role 
model.Role, region string, fips bool, logger *slog.Logger) *sts.STS { - maxStsRetries := 5 - config := &aws.Config{MaxRetries: &maxStsRetries} - - if region != "" { - config = config.WithRegion(region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint) - } - - if fips { - config.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return sts.New(sess, setSTSCreds(sess, config, role)) -} - -func createIamSession(sess *session.Session, role model.Role, fips bool, logger *slog.Logger) *iam.IAM { - maxStsRetries := 5 - config := &aws.Config{MaxRetries: &maxStsRetries} - - if fips { - config.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return iam.New(sess, setSTSCreds(sess, config, role)) -} - -func createCloudwatchSession(sess *session.Session, region *string, role model.Role, fips bool, logger *slog.Logger) *cloudwatch.CloudWatch { - config := &aws.Config{Region: region, Retryer: getAwsRetryer()} - - if fips { - config.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return cloudwatch.New(sess, setSTSCreds(sess, config, role)) -} - -func createTagSession(sess *session.Session, region *string, role model.Role, logger *slog.Logger) *resourcegroupstaggingapi.ResourceGroupsTaggingAPI { - maxResourceGroupTaggingRetries := 5 - config := &aws.Config{ - Region: region, - MaxRetries: &maxResourceGroupTaggingRetries, - CredentialsChainVerboseErrors: aws.Bool(true), - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - 
return resourcegroupstaggingapi.New(sess, setSTSCreds(sess, config, role)) -} - -func createAPIGatewaySession(sess *session.Session, region *string, role model.Role, fips bool, logger *slog.Logger) apigatewayiface.APIGatewayAPI { - maxAPIGatewayAPIRetries := 5 - config := &aws.Config{Region: region, MaxRetries: &maxAPIGatewayAPIRetries} - if fips { - config.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return apigateway.New(sess, setSTSCreds(sess, config, role)) -} - -func createAPIGatewayV2Session(sess *session.Session, region *string, role model.Role, fips bool, logger *slog.Logger) apigatewayv2iface.ApiGatewayV2API { - maxAPIGatewayAPIRetries := 5 - config := &aws.Config{Region: region, MaxRetries: &maxAPIGatewayAPIRetries} - if fips { - config.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return apigatewayv2.New(sess, setSTSCreds(sess, config, role)) -} - -func createASGSession(sess *session.Session, region *string, role model.Role, logger *slog.Logger) autoscalingiface.AutoScalingAPI { - maxAutoScalingAPIRetries := 5 - config := &aws.Config{Region: region, MaxRetries: &maxAutoScalingAPIRetries} - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return autoscaling.New(sess, setSTSCreds(sess, config, role)) -} - -func createStorageGatewaySession(sess *session.Session, region *string, role model.Role, fips bool, logger *slog.Logger) storagegatewayiface.StorageGatewayAPI { - maxStorageGatewayAPIRetries := 5 - config := &aws.Config{Region: region, MaxRetries: &maxStorageGatewayAPIRetries} - - if fips { - config.UseFIPSEndpoint = 
endpoints.FIPSEndpointStateEnabled - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return storagegateway.New(sess, setSTSCreds(sess, config, role)) -} - -func createEC2Session(sess *session.Session, region *string, role model.Role, fips bool, logger *slog.Logger) ec2iface.EC2API { - maxEC2APIRetries := 10 - config := &aws.Config{Region: region, MaxRetries: &maxEC2APIRetries} - if fips { - config.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return ec2.New(sess, setSTSCreds(sess, config, role)) -} - -func createPrometheusSession(sess *session.Session, region *string, role model.Role, logger *slog.Logger) prometheusserviceiface.PrometheusServiceAPI { - maxPrometheusAPIRetries := 10 - config := &aws.Config{Region: region, MaxRetries: &maxPrometheusAPIRetries} - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return prometheusservice.New(sess, setSTSCreds(sess, config, role)) -} - -func createDMSSession(sess *session.Session, region *string, role model.Role, fips bool, logger *slog.Logger) databasemigrationserviceiface.DatabaseMigrationServiceAPI { - maxDMSAPIRetries := 5 - config := &aws.Config{Region: region, MaxRetries: &maxDMSAPIRetries} - if fips { - config.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return databasemigrationservice.New(sess, setSTSCreds(sess, config, role)) -} - -func createShieldSession(sess *session.Session, region *string, role model.Role, fips bool, logger *slog.Logger) shieldiface.ShieldAPI { - maxShieldAPIRetries := 5 - config := 
&aws.Config{Region: region, MaxRetries: &maxShieldAPIRetries} - if fips { - config.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - - if logger != nil && logger.Enabled(context.Background(), slog.LevelDebug) { - config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) - } - - return shield.New(sess, setSTSCreds(sess, config, role)) -} diff --git a/pkg/clients/v1/factory_test.go b/pkg/clients/v1/factory_test.go deleted file mode 100644 index d18532d3b..000000000 --- a/pkg/clients/v1/factory_test.go +++ /dev/null @@ -1,1270 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v1 - -import ( - "fmt" - "sync" - "testing" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/awstesting/mock" - "github.com/aws/aws-sdk-go/service/sts/stsiface" - "github.com/prometheus/common/promslog" - "github.com/stretchr/testify/require" - "go.uber.org/atomic" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" -) - -func cmpCache(t *testing.T, initialCache *CachingFactory, cache *CachingFactory) { - for role := range initialCache.stscache { - if _, ok := cache.stscache[role]; !ok { - t.Logf("`role` not in sts cache %s", role.RoleArn) - t.Fail() - } - } - - for role, regionMap := range initialCache.clients { - if _, ok := cache.clients[role]; !ok { - t.Logf("`role` not in client cache %s", role.RoleArn) - t.Fail() - continue - } - - for region, client := range regionMap { - if _, ok := cache.clients[role][region]; !ok { - t.Logf("`region` %s not found in role %s", region, role.RoleArn) - t.Fail() - } - - if client == nil { - t.Logf("`client cache` is nil for region %s and role %v", region, role) - continue - } - - if cache.clients[role][region] == nil { - t.Logf("comparison `client cache` is nil for region %s and role %v", region, role) - continue - } - - if *client != *cache.clients[role][region] { - t.Logf("`client` %v is not equal to %v for role %v in region %s", *client, *cache.clients[role][region], role, region) - t.Logf("The cache for this client is %v", cache.clients[role]) - t.Logf("The cache for the comparison client is %v", client) - t.Fail() - } - } - } -} - -func TestNewClientCache(t *testing.T) { - tests := []struct { - descrip string - jobsCfg model.JobsConfig - fips bool - cache *CachingFactory - }{ - { - "an empty config gives an empty cache", - 
model.JobsConfig{}, - false, - &CachingFactory{ - logger: promslog.NewNopLogger(), - refreshed: atomic.NewBool(false), - cleared: atomic.NewBool(false), - }, - }, - { - "if fips is set then the clients has fips", - model.JobsConfig{}, - true, - &CachingFactory{ - fips: true, - logger: promslog.NewNopLogger(), - refreshed: atomic.NewBool(false), - cleared: atomic.NewBool(false), - }, - }, - { - "a ScrapeConf with only discovery jobs creates a cache", - model.JobsConfig{ - DiscoveryJobs: []model.DiscoveryJob{ - { - Regions: []string{"us-east-1", "us-west-2", "ap-northeast-3"}, - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn", - ExternalID: "thing", - }, - }, - }, - { - Regions: []string{"ap-northeast-3"}, - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn5", - }, - }, - }, - }, - }, - false, - &CachingFactory{ - stscache: map[model.Role]stsiface.STSAPI{ - {RoleArn: "some-arn"}: nil, - {RoleArn: "some-arn", ExternalID: "thing"}: nil, - {RoleArn: "some-arn2"}: nil, - {RoleArn: "some-arn5"}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {RoleArn: "some-arn"}: { - "ap-northeast-3": &cachedClients{}, - "us-east-1": &cachedClients{}, - "us-west-2": &cachedClients{}, - }, - {RoleArn: "some-arn", ExternalID: "thing"}: { - "ap-northeast-3": &cachedClients{}, - "us-east-1": &cachedClients{}, - "us-west-2": &cachedClients{}, - }, - {RoleArn: "some-arn2"}: { - "ap-northeast-3": &cachedClients{}, - "us-east-1": &cachedClients{}, - "us-west-2": &cachedClients{}, - }, - {RoleArn: "some-arn5"}: { - "ap-northeast-3": &cachedClients{}, - }, - }, - logger: promslog.NewNopLogger(), - refreshed: atomic.NewBool(false), - cleared: atomic.NewBool(false), - }, - }, - { - "a ScrapeConf with only static jobs creates a cache", - model.JobsConfig{ - StaticJobs: []model.StaticJob{ - { - Name: "scrape-thing", - Regions: []string{"us-east-1", "eu-west-2"}, - Roles: []model.Role{ - { - 
RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn3", - }, - }, - }, - { - Name: "scrape-other-thing", - Regions: []string{"us-east-1"}, - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn", - ExternalID: "thing", - }, - }, - }, - { - Name: "scrape-third-thing", - Regions: []string{"ap-northeast-1"}, - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn4", - }, - }, - }, - }, - }, - false, - &CachingFactory{ - stscache: map[model.Role]stsiface.STSAPI{ - {RoleArn: "some-arn"}: nil, - {RoleArn: "some-arn", ExternalID: "thing"}: nil, - {RoleArn: "some-arn2"}: nil, - {RoleArn: "some-arn3"}: nil, - {RoleArn: "some-arn4"}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {RoleArn: "some-arn"}: { - "ap-northeast-1": &cachedClients{onlyStatic: true}, - "eu-west-2": &cachedClients{onlyStatic: true}, - "us-east-1": &cachedClients{onlyStatic: true}, - }, - {RoleArn: "some-arn", ExternalID: "thing"}: { - "us-east-1": &cachedClients{onlyStatic: true}, - }, - {RoleArn: "some-arn2"}: { - "ap-northeast-1": &cachedClients{onlyStatic: true}, - "eu-west-2": &cachedClients{onlyStatic: true}, - "us-east-1": &cachedClients{onlyStatic: true}, - }, - {RoleArn: "some-arn3"}: { - "eu-west-2": &cachedClients{onlyStatic: true}, - "us-east-1": &cachedClients{onlyStatic: true}, - }, - {RoleArn: "some-arn4"}: { - "ap-northeast-1": &cachedClients{onlyStatic: true}, - }, - }, - logger: promslog.NewNopLogger(), - refreshed: atomic.NewBool(false), - cleared: atomic.NewBool(false), - }, - }, - { - "a ScrapeConf with some overlapping static and discovery jobs creates a cache", - model.JobsConfig{ - DiscoveryJobs: []model.DiscoveryJob{ - { - Regions: []string{"us-east-1", "us-west-2", "ap-northeast-3"}, - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn3", - }, - }, - }, - { 
- Regions: []string{"ap-northeast-3"}, - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn5", - }, - }, - }, - }, - StaticJobs: []model.StaticJob{ - { - Name: "scrape-thing", - Regions: []string{"us-east-1", "eu-west-2"}, - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn3", - }, - }, - }, - { - Name: "scrape-other-thing", - Regions: []string{"us-east-1"}, - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn", - ExternalID: "thing", - }, - }, - }, - { - Name: "scrape-third-thing", - Regions: []string{"ap-northeast-1"}, - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn4", - }, - }, - }, - }, - }, - false, - &CachingFactory{ - stscache: map[model.Role]stsiface.STSAPI{ - {RoleArn: "some-arn"}: nil, - {RoleArn: "some-arn", ExternalID: "thing"}: nil, - {RoleArn: "some-arn2"}: nil, - {RoleArn: "some-arn3"}: nil, - {RoleArn: "some-arn4"}: nil, - {RoleArn: "some-arn5"}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {RoleArn: "some-arn"}: { - "ap-northeast-3": &cachedClients{}, - "us-east-1": &cachedClients{}, - "ap-northeast-1": &cachedClients{onlyStatic: true}, - "eu-west-2": &cachedClients{onlyStatic: true}, - "us-west-2": &cachedClients{}, - }, - {RoleArn: "some-arn", ExternalID: "thing"}: { - "us-east-1": &cachedClients{onlyStatic: true}, - }, - {RoleArn: "some-arn2"}: { - "ap-northeast-3": &cachedClients{}, - "ap-northeast-1": &cachedClients{onlyStatic: true}, - "eu-west-2": &cachedClients{onlyStatic: true}, - "us-east-1": &cachedClients{}, - "us-west-2": &cachedClients{}, - }, - {RoleArn: "some-arn3"}: { - "ap-northeast-3": &cachedClients{}, - "eu-west-2": &cachedClients{onlyStatic: true}, - "us-east-1": &cachedClients{}, - "us-west-2": &cachedClients{}, - }, - {RoleArn: "some-arn4"}: { - "ap-northeast-1": &cachedClients{onlyStatic: true}, 
- }, - {RoleArn: "some-arn5"}: { - "ap-northeast-3": &cachedClients{}, - }, - }, - logger: promslog.NewNopLogger(), - refreshed: atomic.NewBool(false), - cleared: atomic.NewBool(false), - }, - }, - { - "a ScrapeConf with only custom dimension jobs creates a cache", - model.JobsConfig{ - CustomNamespaceJobs: []model.CustomNamespaceJob{ - { - Name: "scrape-thing", - Regions: []string{"us-east-1", "eu-west-2"}, - Namespace: "CustomDimension", - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn3", - }, - }, - }, - { - Name: "scrape-other-thing", - Regions: []string{"us-east-1"}, - Namespace: "CustomDimension", - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn", - ExternalID: "thing", - }, - }, - }, - { - Name: "scrape-third-thing", - Regions: []string{"ap-northeast-1"}, - Namespace: "CustomDimension", - Roles: []model.Role{ - { - RoleArn: "some-arn", - }, - { - RoleArn: "some-arn2", - }, - { - RoleArn: "some-arn4", - }, - }, - }, - }, - }, - false, - &CachingFactory{ - stscache: map[model.Role]stsiface.STSAPI{ - {RoleArn: "some-arn"}: nil, - {RoleArn: "some-arn", ExternalID: "thing"}: nil, - {RoleArn: "some-arn2"}: nil, - {RoleArn: "some-arn3"}: nil, - {RoleArn: "some-arn4"}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {RoleArn: "some-arn"}: { - "ap-northeast-1": &cachedClients{onlyStatic: true}, - "eu-west-2": &cachedClients{onlyStatic: true}, - "us-east-1": &cachedClients{onlyStatic: true}, - }, - {RoleArn: "some-arn", ExternalID: "thing"}: { - "us-east-1": &cachedClients{onlyStatic: true}, - }, - {RoleArn: "some-arn2"}: { - "ap-northeast-1": &cachedClients{onlyStatic: true}, - "eu-west-2": &cachedClients{onlyStatic: true}, - "us-east-1": &cachedClients{onlyStatic: true}, - }, - {RoleArn: "some-arn3"}: { - "eu-west-2": &cachedClients{onlyStatic: true}, - "us-east-1": &cachedClients{onlyStatic: true}, - }, - {RoleArn: 
"some-arn4"}: { - "ap-northeast-1": &cachedClients{onlyStatic: true}, - }, - }, - logger: promslog.NewNopLogger(), - refreshed: atomic.NewBool(false), - cleared: atomic.NewBool(false), - }, - }, - } - - for _, l := range tests { - test := l - t.Run(test.descrip, func(t *testing.T) { - t.Parallel() - cache := NewFactory(promslog.NewNopLogger(), test.jobsCfg, test.fips) - t.Logf("the cache is: %v", cache) - - if test.cache.cleared.Load() != cache.cleared.Load() { - t.Logf("`cleared` not equal got %v, expected %v", cache.cleared, test.cache.cleared) - t.Fail() - } - - if test.cache.refreshed.Load() != cache.refreshed.Load() { - t.Logf("`refreshed` not equal got %v, expected %v", cache.refreshed, test.cache.refreshed) - t.Fail() - } - - if test.cache.fips != cache.fips { - t.Logf("`fips` not equal got %v, expected %v", cache.fips, test.cache.fips) - t.Fail() - } - - // Strict equality requires each set containing each other - cmpCache(t, test.cache, cache) - cmpCache(t, cache, test.cache) - }) - } -} - -func TestClear(t *testing.T) { - region := "us-east-1" - role := model.Role{} - - tests := []struct { - description string - cache *CachingFactory - }{ - { - "a new clear clears all clients", - &CachingFactory{ - session: mock.Session, - cleared: atomic.NewBool(false), - mu: sync.Mutex{}, - stscache: map[model.Role]stsiface.STSAPI{ - {}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {}: { - "us-east-1": &cachedClients{ - cloudwatch: createCloudWatchClient(promslog.NewNopLogger(), mock.Session, ®ion, role, false), - tagging: createTaggingClient(promslog.NewNopLogger(), mock.Session, ®ion, role, false), - account: createAccountClient(promslog.NewNopLogger(), nil, nil), - onlyStatic: true, - }, - }, - }, - logger: promslog.NewNopLogger(), - refreshed: atomic.NewBool(false), - }, - }, - { - "A second call to clear does nothing", - &CachingFactory{ - cleared: atomic.NewBool(true), - mu: sync.Mutex{}, - session: mock.Session, - stscache: 
map[model.Role]stsiface.STSAPI{ - {}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {}: { - "us-east-1": &cachedClients{ - cloudwatch: nil, - tagging: nil, - account: nil, - }, - }, - }, - logger: promslog.NewNopLogger(), - refreshed: atomic.NewBool(false), - }, - }, - } - - for _, l := range tests { - test := l - t.Run(test.description, func(t *testing.T) { - test.cache.Clear() - if !test.cache.cleared.Load() { - t.Log("Cache cleared flag not set") - t.Fail() - } - if test.cache.refreshed.Load() { - t.Log("Cache cleared flag set") - t.Fail() - } - - for role, client := range test.cache.stscache { - if client != nil { - t.Logf("STS `client` %v not cleared", role) - t.Fail() - } - } - - for role, regionMap := range test.cache.clients { - for region, client := range regionMap { - if client.cloudwatch != nil { - t.Logf("`cloudwatch client` %v in region %v is not nil", role, region) - t.Fail() - } - if client.tagging != nil { - t.Logf("`tagging client` %v in region %v is not nil", role, region) - t.Fail() - } - if client.account != nil { - t.Logf("`asg client` %v in region %v is not nil", role, region) - t.Fail() - } - } - } - }) - } -} - -func TestRefresh(t *testing.T) { - region := "us-east-1" - role := model.Role{} - - tests := []struct { - descrip string - cache *CachingFactory - cloudwatch bool - }{ - { - "a new refresh creates clients", - &CachingFactory{ - session: mock.Session, - refreshed: atomic.NewBool(false), - cleared: atomic.NewBool(false), - mu: sync.Mutex{}, - stscache: map[model.Role]stsiface.STSAPI{ - {}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {}: { - "us-east-1": &cachedClients{ - cloudwatch: nil, - tagging: nil, - account: nil, - }, - }, - }, - logger: promslog.NewNopLogger(), - }, - false, - }, - { - "a new refresh with static only creates only cloudwatch", - &CachingFactory{ - session: mock.Session, - refreshed: atomic.NewBool(false), - cleared: atomic.NewBool(false), - mu: sync.Mutex{}, - stscache: 
map[model.Role]stsiface.STSAPI{ - {}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {}: { - "us-east-1": &cachedClients{ - cloudwatch: nil, - tagging: nil, - account: nil, - onlyStatic: true, - }, - }, - }, - logger: promslog.NewNopLogger(), - }, - true, - }, - { - "A second call to refreshed does nothing", - &CachingFactory{ - refreshed: atomic.NewBool(true), - cleared: atomic.NewBool(false), - mu: sync.Mutex{}, - session: mock.Session, - stscache: map[model.Role]stsiface.STSAPI{ - {}: createStsSession(mock.Session, role, "", false, nil), - }, - clients: map[model.Role]map[string]*cachedClients{ - {}: { - "us-east-1": &cachedClients{ - cloudwatch: createCloudWatchClient(promslog.NewNopLogger(), mock.Session, ®ion, role, false), - tagging: createTaggingClient(promslog.NewNopLogger(), mock.Session, ®ion, role, false), - account: createAccountClient(promslog.NewNopLogger(), createStsSession(mock.Session, role, "", false, nil), createIamSession(mock.Session, role, false, nil)), - }, - }, - }, - logger: promslog.NewNopLogger(), - }, - false, - }, - } - - for _, l := range tests { - test := l - t.Run(test.descrip, func(t *testing.T) { - t.Parallel() - test.cache.Refresh() - - if !test.cache.refreshed.Load() { - t.Log("Cache refreshed flag not set") - t.Fail() - } - - if test.cache.cleared.Load() { - t.Log("Cache cleared flag set") - t.Fail() - } - - for role, client := range test.cache.stscache { - if client == nil { - t.Logf("STS `client` %v not refreshed", role) - t.Fail() - } - } - - for role, regionMap := range test.cache.clients { - for region, client := range regionMap { - if client.cloudwatch == nil { - t.Logf("`cloudwatch client` %v in region %v still nil", role, region) - t.Fail() - } - if test.cloudwatch { - continue - } - if client.tagging == nil { - t.Logf("`tagging client` %v in region %v still nil", role, region) - t.Fail() - } - if client.account == nil { - t.Logf("`asg client` %v in region %v still nil", role, region) - t.Fail() - } - 
} - } - }) - } -} - -func TestClientCacheGetCloudwatchClient(t *testing.T) { - testGetAWSClient( - t, "Cloudwatch", - func(t *testing.T, cache *CachingFactory, region string, role model.Role) { - iface := cache.GetCloudwatchClient(region, role, cloudwatch.ConcurrencyConfig{SingleLimit: 1}) - if iface == nil { - t.Fail() - return - } - }) -} - -func TestClientCacheGetTagging(t *testing.T) { - testGetAWSClient( - t, "Tagging", - func(t *testing.T, cache *CachingFactory, region string, role model.Role) { - iface := cache.GetTaggingClient(region, role, 1) - if iface == nil { - t.Fail() - return - } - }) -} - -func TestClientCacheGetAccount(t *testing.T) { - testGetAWSClient( - t, "Account", - func(t *testing.T, cache *CachingFactory, region string, role model.Role) { - iface := cache.GetAccountClient(region, role) - if iface == nil { - t.Fail() - return - } - }) -} - -func testGetAWSClient( - t *testing.T, - name string, - testClientGet func(*testing.T, *CachingFactory, string, model.Role), -) { - region := "us-east-1" - role := model.Role{} - tests := []struct { - descrip string - cache *CachingFactory - parallelRun bool - }{ - { - "locks during unrefreshed parallel call", - &CachingFactory{ - refreshed: atomic.NewBool(false), - mu: sync.Mutex{}, - session: mock.Session, - stscache: map[model.Role]stsiface.STSAPI{ - {}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {}: { - "us-east-1": &cachedClients{ - cloudwatch: createCloudWatchClient(promslog.NewNopLogger(), mock.Session, ®ion, role, false), - tagging: createTaggingClient(promslog.NewNopLogger(), mock.Session, ®ion, role, false), - account: createAccountClient(promslog.NewNopLogger(), createStsSession(mock.Session, role, "", false, nil), createIamSession(mock.Session, role, false, nil)), - }, - }, - }, - logger: promslog.NewNopLogger(), - }, - true, - }, - { - "returns clients if available", - &CachingFactory{ - refreshed: atomic.NewBool(true), - session: mock.Session, - mu: sync.Mutex{}, - 
stscache: map[model.Role]stsiface.STSAPI{ - {}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {}: { - "us-east-1": &cachedClients{ - cloudwatch: createCloudWatchClient(promslog.NewNopLogger(), mock.Session, ®ion, role, false), - tagging: createTaggingClient(promslog.NewNopLogger(), mock.Session, ®ion, role, false), - account: createAccountClient(promslog.NewNopLogger(), createStsSession(mock.Session, role, "", false, nil), createIamSession(mock.Session, role, false, nil)), - }, - }, - }, - logger: promslog.NewNopLogger(), - }, - false, - }, - { - "creates a new clients if not available", - &CachingFactory{ - refreshed: atomic.NewBool(true), - session: mock.Session, - mu: sync.Mutex{}, - stscache: map[model.Role]stsiface.STSAPI{ - {}: nil, - }, - clients: map[model.Role]map[string]*cachedClients{ - {}: { - "us-east-1": &cachedClients{}, - }, - }, - logger: promslog.NewNopLogger(), - }, - false, - }, - } - - for _, l := range tests { - test := l - t.Run(name+" "+test.descrip, func(t *testing.T) { - t.Parallel() - if test.parallelRun { - go testClientGet(t, test.cache, region, role) - } - testClientGet(t, test.cache, region, role) - - if test.cache.clients[role][region] == nil { - t.Log("cache is nil when it should be populated") - t.Fail() - } - }) - } -} - -func TestSetExternalID(t *testing.T) { - tests := []struct { - descrip string - ID string - isSet bool - }{ - { - "sets the external ID if not empty", - "should-be-set", - true, - }, - { - "external ID not set if empty", - "", - false, - }, - } - - for _, l := range tests { - test := l - t.Run(test.descrip, func(t *testing.T) { - f := setExternalID(test.ID) - p := &stscreds.AssumeRoleProvider{} - f(p) - if test.isSet { - if *p.ExternalID != test.ID { - t.Fail() - } - } - }) - } -} - -func TestSetSTSCreds(t *testing.T) { - tests := []struct { - descrip string - role model.Role - credentialsNil bool - externalID string - }{ - { - "sets the sts creds if the role arn is set", - model.Role{ - 
RoleArn: "this:arn", - }, - false, - "", - }, - { - "does not set the creds if role arn is not set", - model.Role{}, - true, - "", - }, - { - "does not set the creds if role arn is not set & external id is set", - model.Role{ - ExternalID: "thing", - }, - true, - "", - }, - } - - for _, l := range tests { - test := l - t.Run(test.descrip, func(t *testing.T) { - t.Parallel() - conf := setSTSCreds(mock.Session, &aws.Config{}, test.role) - if test.credentialsNil { - if conf.Credentials != nil { - t.Fail() - } - } else { - if conf.Credentials == nil { - t.Fail() - } - } - }) - } -} - -func TestCreateAWSSession(t *testing.T) { - tests := []struct { - descrip string - }{ - { - "creates an aws clients", - }, - } - - for _, l := range tests { - test := l - t.Run(test.descrip, func(t *testing.T) { - s := createAWSSession(endpoints.DefaultResolver().EndpointFor, nil) - if s == nil { - t.Fail() - } - }) - } -} - -func TestCreateStsSession(t *testing.T) { - tests := []struct { - descrip string - role model.Role - stsRegion string - }{ - { - "creates an sts clients with an empty role", - model.Role{}, - "", - }, - { - "creates an sts clients with region", - model.Role{}, - "eu-west-1", - }, - { - "creates an sts clients with an empty external id", - model.Role{ - RoleArn: "some:arn", - }, - "", - }, - { - "creates an sts clients with an empty role arn", - model.Role{ - ExternalID: "some-id", - }, - "", - }, - { - "creates an sts clients with an sts full role", - model.Role{ - RoleArn: "some:arn", - ExternalID: "some-id", - }, - "", - }, - } - - for _, l := range tests { - test := l - t.Run(test.descrip, func(t *testing.T) { - t.Parallel() - // just exercise the code path - iface := createStsSession(mock.Session, test.role, test.stsRegion, false, nil) - if iface == nil { - t.Fail() - } - }) - } -} - -func TestCreateCloudwatchSession(t *testing.T) { - testAWSClient( - t, - "Cloudwatch", - func(t *testing.T, s *session.Session, region *string, role model.Role, fips bool) { - iface 
:= createCloudwatchSession(s, region, role, fips, nil) - if iface == nil { - t.Fail() - } - }) -} - -func TestCreateTagSession(t *testing.T) { - testAWSClient( - t, - "Tag", - func(t *testing.T, s *session.Session, region *string, role model.Role, _ bool) { - iface := createTagSession(s, region, role, promslog.NewNopLogger()) - if iface == nil { - t.Fail() - } - }) -} - -func TestCreateAPIGatewaySession(t *testing.T) { - testAWSClient( - t, - "APIGateway", - func(t *testing.T, s *session.Session, region *string, role model.Role, fips bool) { - iface := createAPIGatewaySession(s, region, role, fips, promslog.NewNopLogger()) - if iface == nil { - t.Fail() - } - }) -} - -func TestCreateAPIGatewayV2Session(t *testing.T) { - testAWSClient( - t, - "APIGatewayV2", - func(t *testing.T, s *session.Session, region *string, role model.Role, fips bool) { - iface := createAPIGatewayV2Session(s, region, role, fips, promslog.NewNopLogger()) - if iface == nil { - t.Fail() - } - }) -} - -func TestCreateASGSession(t *testing.T) { - testAWSClient( - t, - "ASG", - func(t *testing.T, s *session.Session, region *string, role model.Role, _ bool) { - iface := createASGSession(s, region, role, promslog.NewNopLogger()) - if iface == nil { - t.Fail() - } - }) -} - -func TestCreateEC2Session(t *testing.T) { - testAWSClient( - t, - "EC2", - func(t *testing.T, s *session.Session, region *string, role model.Role, fips bool) { - iface := createEC2Session(s, region, role, fips, promslog.NewNopLogger()) - if iface == nil { - t.Fail() - } - }) -} - -func TestCreatePrometheusSession(t *testing.T) { - testAWSClient( - t, - "Prometheus", - func(t *testing.T, s *session.Session, region *string, role model.Role, _ bool) { - iface := createPrometheusSession(s, region, role, promslog.NewNopLogger()) - if iface == nil { - t.Fail() - } - }) -} - -func TestCreateDMSSession(t *testing.T) { - testAWSClient( - t, - "DMS", - func(t *testing.T, s *session.Session, region *string, role model.Role, fips bool) { - 
iface := createDMSSession(s, region, role, fips, promslog.NewNopLogger()) - if iface == nil { - t.Fail() - } - }) -} - -func TestCreateStorageGatewaySession(t *testing.T) { - testAWSClient( - t, - "StorageGateway", - func(t *testing.T, s *session.Session, region *string, role model.Role, fips bool) { - iface := createStorageGatewaySession(s, region, role, fips, promslog.NewNopLogger()) - if iface == nil { - t.Fail() - } - }) -} - -func TestSTSResolvesFIPSEnabledEndpoints(t *testing.T) { - type testcase struct { - region string - expectedEndpoint string - } - - for _, tc := range []testcase{ - { - region: "us-east-1", - expectedEndpoint: "http://sts-fips.us-east-1.amazonaws.com", - }, - { - region: "us-west-1", - expectedEndpoint: "http://sts-fips.us-west-1.amazonaws.com", - }, - { - region: "us-gov-east-1", - expectedEndpoint: "http://sts.us-gov-east-1.amazonaws.com", - }, - } { - t.Run(tc.region, func(t *testing.T) { - var resolverError error - resolvedEndpoint := endpoints.ResolvedEndpoint{} - called := false - - mockSession := mock.Session - mockEndpoint := *mockSession.Config.Endpoint - previousResolver := mock.Session.Config.EndpointResolver - - // restore mock endpoint after - t.Cleanup(func() { - mockSession.Config.Endpoint = aws.String(mockEndpoint) - mockSession.Config.EndpointResolver = previousResolver - }) - - mockResolverFunc := mock.Session.Config.EndpointResolver.EndpointFor - mockSession.Config.EndpointResolver = endpoints.ResolverFunc(func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - resolvedEndpoint, resolverError = mockResolverFunc(service, region, opts...) 
- - called = true - - return endpoints.ResolvedEndpoint{URL: mockEndpoint}, resolverError - }) - - mockSession.Config.Endpoint = nil - - sess := createStsSession(mock.Session, model.Role{}, tc.region, true, promslog.NewNopLogger()) - require.NotNil(t, sess) - - require.True(t, called, "expected endpoint resolver to be called") - require.NoError(t, resolverError, "no error expected when resolving endpoint") - require.Equal(t, tc.expectedEndpoint, resolvedEndpoint.URL) - }) - } -} - -func TestRaceConditionRefreshClear(t *testing.T) { - t.Parallel() - - // Create a factory with the test config - factory := NewFactory(promslog.NewNopLogger(), model.JobsConfig{}, false) - - // Number of concurrent operations to perform - iterations := 100 - - // Use WaitGroup to synchronize goroutines - var wg sync.WaitGroup - wg.Add(iterations) // For both Refresh and Clear calls - - // Start function to run concurrent operations - for i := 0; i < iterations; i++ { - // Launch goroutine to call Refresh - go func() { - defer wg.Done() - factory.Refresh() - factory.Clear() - }() - } - - // Create a channel to signal completion - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - - // Wait for either completion or timeout - select { - case <-done: - // Test completed successfully - case <-time.After(60 * time.Second): - require.Fail(t, "Test timed out after 60 seconds") - } -} - -func testAWSClient( - t *testing.T, - name string, - testClientCreation func(*testing.T, *session.Session, *string, model.Role, bool), -) { - tests := []struct { - descrip string - region string - role model.Role - fips bool - }{ - { - fmt.Sprintf("%s client without role and fips is created", name), - "us-east-1", - model.Role{}, - false, - }, - { - fmt.Sprintf("%s client without role and with fips is created", name), - "us-east-1", - model.Role{}, - true, - }, - { - fmt.Sprintf("%s client with roleARN and without external id is created", name), - "us-east-1", - model.Role{ - RoleArn: 
"some:arn", - }, - false, - }, - { - fmt.Sprintf("%s client with roleARN and with external id is created", name), - "us-east-1", - model.Role{ - RoleArn: "some:arn", - ExternalID: "some-id", - }, - false, - }, - { - fmt.Sprintf("%s client without roleARN and with external id is created", name), - "us-east-1", - model.Role{ - ExternalID: "some-id", - }, - false, - }, - } - - for _, l := range tests { - test := l - t.Run(test.descrip, func(t *testing.T) { - t.Parallel() - testClientCreation(t, mock.Session, &test.region, test.role, test.fips) - }) - } -} diff --git a/pkg/clients/v2/factory.go b/pkg/clients/v2/factory.go deleted file mode 100644 index aef4c7256..000000000 --- a/pkg/clients/v2/factory.go +++ /dev/null @@ -1,481 +0,0 @@ -// Copyright 2024 The Prometheus Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-package v2 - -import ( - "context" - "fmt" - "log/slog" - "os" - "sync" - "time" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/retry" - aws_config "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/service/amp" - "github.com/aws/aws-sdk-go-v2/service/apigateway" - "github.com/aws/aws-sdk-go-v2/service/apigatewayv2" - "github.com/aws/aws-sdk-go-v2/service/autoscaling" - "github.com/aws/aws-sdk-go-v2/service/cloudwatch" - "github.com/aws/aws-sdk-go-v2/service/databasemigrationservice" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/aws/aws-sdk-go-v2/service/iam" - "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" - "github.com/aws/aws-sdk-go-v2/service/shield" - "github.com/aws/aws-sdk-go-v2/service/storagegateway" - "github.com/aws/aws-sdk-go-v2/service/sts" - aws_logging "github.com/aws/smithy-go/logging" - "go.uber.org/atomic" - - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/account" - account_v2 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/account/v2" - cloudwatch_client "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch" - cloudwatch_v2 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch/v2" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging" - tagging_v2 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/tagging/v2" - "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model" -) - -type awsRegion = string - -type CachingFactory struct { - logger *slog.Logger - stsOptions func(*sts.Options) - clients map[model.Role]map[awsRegion]*cachedClients - mu sync.Mutex - refreshed *atomic.Bool - cleared *atomic.Bool - fipsEnabled bool - endpointURLOverride 
string -} - -type cachedClients struct { - awsConfig *aws.Config - // if we know that this job is only used for static - // then we don't have to construct as many cached connections - // later on - onlyStatic bool - cloudwatch cloudwatch_client.Client - tagging tagging.Client - account account.Client -} - -// Ensure the struct properly implements the interface -var _ clients.Factory = &CachingFactory{} - -// NewFactory creates a new client factory to use when fetching data from AWS with sdk v2 -func NewFactory(logger *slog.Logger, jobsCfg model.JobsConfig, fips bool) (*CachingFactory, error) { - var options []func(*aws_config.LoadOptions) error - options = append(options, aws_config.WithLogger(aws_logging.LoggerFunc(func(classification aws_logging.Classification, format string, v ...interface{}) { - switch classification { - case aws_logging.Debug: - if logger.Enabled(context.Background(), slog.LevelDebug) { - logger.Debug(fmt.Sprintf(format, v...)) - } - case aws_logging.Warn: - logger.Warn(fmt.Sprintf(format, v...)) - default: // AWS logging only supports debug or warn, log everything else as error - logger.Error(fmt.Sprintf(format, v...), "err", "unexected aws error classification", "classification", classification) - } - }))) - - options = append(options, aws_config.WithLogConfigurationWarnings(true)) - - endpointURLOverride := os.Getenv("AWS_ENDPOINT_URL") - - options = append(options, aws_config.WithRetryMaxAttempts(5)) - - c, err := aws_config.LoadDefaultConfig(context.TODO(), options...) 
- if err != nil { - return nil, fmt.Errorf("failed to load default aws config: %w", err) - } - - stsOptions := createStsOptions(jobsCfg.StsRegion, logger.Enabled(context.Background(), slog.LevelDebug), endpointURLOverride, fips) - cache := map[model.Role]map[awsRegion]*cachedClients{} - for _, discoveryJob := range jobsCfg.DiscoveryJobs { - for _, role := range discoveryJob.Roles { - if _, ok := cache[role]; !ok { - cache[role] = map[awsRegion]*cachedClients{} - } - for _, region := range discoveryJob.Regions { - regionConfig := awsConfigForRegion(role, &c, region, stsOptions) - cache[role][region] = &cachedClients{ - awsConfig: regionConfig, - onlyStatic: false, - } - } - } - } - - for _, staticJob := range jobsCfg.StaticJobs { - for _, role := range staticJob.Roles { - if _, ok := cache[role]; !ok { - cache[role] = map[awsRegion]*cachedClients{} - } - for _, region := range staticJob.Regions { - // Discovery job client definitions have precedence - if _, exists := cache[role][region]; !exists { - regionConfig := awsConfigForRegion(role, &c, region, stsOptions) - cache[role][region] = &cachedClients{ - awsConfig: regionConfig, - onlyStatic: true, - } - } - } - } - } - - for _, customNamespaceJob := range jobsCfg.CustomNamespaceJobs { - for _, role := range customNamespaceJob.Roles { - if _, ok := cache[role]; !ok { - cache[role] = map[awsRegion]*cachedClients{} - } - for _, region := range customNamespaceJob.Regions { - // Discovery job client definitions have precedence - if _, exists := cache[role][region]; !exists { - regionConfig := awsConfigForRegion(role, &c, region, stsOptions) - cache[role][region] = &cachedClients{ - awsConfig: regionConfig, - onlyStatic: true, - } - } - } - } - } - - return &CachingFactory{ - logger: logger, - clients: cache, - fipsEnabled: fips, - stsOptions: stsOptions, - endpointURLOverride: endpointURLOverride, - cleared: atomic.NewBool(false), - refreshed: atomic.NewBool(false), - }, nil -} - -func (c *CachingFactory) 
GetCloudwatchClient(region string, role model.Role, concurrency cloudwatch_client.ConcurrencyConfig) cloudwatch_client.Client { - if !c.refreshed.Load() { - // if we have not refreshed then we need to lock in case we are accessing concurrently - c.mu.Lock() - defer c.mu.Unlock() - } - if client := c.clients[role][region].cloudwatch; client != nil { - return cloudwatch_client.NewLimitedConcurrencyClient(client, concurrency.NewLimiter()) - } - c.clients[role][region].cloudwatch = cloudwatch_v2.NewClient(c.logger, c.createCloudwatchClient(c.clients[role][region].awsConfig)) - return cloudwatch_client.NewLimitedConcurrencyClient(c.clients[role][region].cloudwatch, concurrency.NewLimiter()) -} - -func (c *CachingFactory) GetTaggingClient(region string, role model.Role, concurrencyLimit int) tagging.Client { - if !c.refreshed.Load() { - // if we have not refreshed then we need to lock in case we are accessing concurrently - c.mu.Lock() - defer c.mu.Unlock() - } - if client := c.clients[role][region].tagging; client != nil { - return tagging.NewLimitedConcurrencyClient(client, concurrencyLimit) - } - c.clients[role][region].tagging = tagging_v2.NewClient( - c.logger, - c.createTaggingClient(c.clients[role][region].awsConfig), - c.createAutoScalingClient(c.clients[role][region].awsConfig), - c.createAPIGatewayClient(c.clients[role][region].awsConfig), - c.createAPIGatewayV2Client(c.clients[role][region].awsConfig), - c.createEC2Client(c.clients[role][region].awsConfig), - c.createDMSClient(c.clients[role][region].awsConfig), - c.createPrometheusClient(c.clients[role][region].awsConfig), - c.createStorageGatewayClient(c.clients[role][region].awsConfig), - c.createShieldClient(c.clients[role][region].awsConfig), - ) - return tagging.NewLimitedConcurrencyClient(c.clients[role][region].tagging, concurrencyLimit) -} - -func (c *CachingFactory) GetAccountClient(region string, role model.Role) account.Client { - if !c.refreshed.Load() { - // if we have not refreshed then we need 
to lock in case we are accessing concurrently - c.mu.Lock() - defer c.mu.Unlock() - } - if client := c.clients[role][region].account; client != nil { - return client - } - - stsClient := c.createStsClient(c.clients[role][region].awsConfig) - iamClient := c.createIAMClient(c.clients[role][region].awsConfig) - c.clients[role][region].account = account_v2.NewClient(c.logger, stsClient, iamClient) - return c.clients[role][region].account -} - -func (c *CachingFactory) Refresh() { - if c.refreshed.Load() { - return - } - c.mu.Lock() - defer c.mu.Unlock() - // Avoid double refresh in the event Refresh() is called concurrently - if c.refreshed.Load() { - return - } - - for _, regionClients := range c.clients { - for _, cache := range regionClients { - cache.cloudwatch = cloudwatch_v2.NewClient(c.logger, c.createCloudwatchClient(cache.awsConfig)) - if cache.onlyStatic { - continue - } - - cache.tagging = tagging_v2.NewClient( - c.logger, - c.createTaggingClient(cache.awsConfig), - c.createAutoScalingClient(cache.awsConfig), - c.createAPIGatewayClient(cache.awsConfig), - c.createAPIGatewayV2Client(cache.awsConfig), - c.createEC2Client(cache.awsConfig), - c.createDMSClient(cache.awsConfig), - c.createPrometheusClient(cache.awsConfig), - c.createStorageGatewayClient(cache.awsConfig), - c.createShieldClient(cache.awsConfig), - ) - - cache.account = account_v2.NewClient(c.logger, c.createStsClient(cache.awsConfig), c.createIAMClient(cache.awsConfig)) - } - } - - c.refreshed.Store(true) - c.cleared.Store(false) -} - -func (c *CachingFactory) Clear() { - if c.cleared.Load() { - return - } - // Prevent concurrent reads/write if clear is called during execution - c.mu.Lock() - defer c.mu.Unlock() - // Avoid double clear in the event Refresh() is called concurrently - if c.cleared.Load() { - return - } - - for _, regions := range c.clients { - for _, cache := range regions { - cache.cloudwatch = nil - cache.account = nil - cache.tagging = nil - } - } - - c.refreshed.Store(false) - 
c.cleared.Store(true) -} - -func (c *CachingFactory) createCloudwatchClient(regionConfig *aws.Config) *cloudwatch.Client { - return cloudwatch.NewFromConfig(*regionConfig, func(options *cloudwatch.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - options.BaseEndpoint = aws.String(c.endpointURLOverride) - } - - // Setting an explicit retryer will override the default settings on the config - options.Retryer = retry.NewStandard(func(options *retry.StandardOptions) { - options.MaxAttempts = 5 - options.MaxBackoff = 3 * time.Second - }) - - if c.fipsEnabled { - options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled - } - }) -} - -func (c *CachingFactory) createTaggingClient(regionConfig *aws.Config) *resourcegroupstaggingapi.Client { - return resourcegroupstaggingapi.NewFromConfig(*regionConfig, func(options *resourcegroupstaggingapi.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - options.BaseEndpoint = aws.String(c.endpointURLOverride) - } - // The FIPS setting is ignored because FIPS is not available for resource groups tagging apis - // If enabled the SDK will try to use non-existent FIPS URLs, https://github.com/aws/aws-sdk-go-v2/issues/2138#issuecomment-1570791988 - // AWS FIPS Reference: https://aws.amazon.com/compliance/fips/ - }) -} - -func (c *CachingFactory) createAutoScalingClient(assumedConfig *aws.Config) *autoscaling.Client { - return autoscaling.NewFromConfig(*assumedConfig, func(options *autoscaling.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - 
options.BaseEndpoint = aws.String(c.endpointURLOverride) - } - // The FIPS setting is ignored because FIPS is not available for EC2 autoscaling apis - // If enabled the SDK will try to use non-existent FIPS URLs, https://github.com/aws/aws-sdk-go-v2/issues/2138#issuecomment-1570791988 - // AWS FIPS Reference: https://aws.amazon.com/compliance/fips/ - // EC2 autoscaling has FIPS compliant URLs for govcloud, but they do not use any FIPS prefixing, and should work - // with sdk v2s EndpointResolverV2 - }) -} - -func (c *CachingFactory) createAPIGatewayClient(assumedConfig *aws.Config) *apigateway.Client { - return apigateway.NewFromConfig(*assumedConfig, func(options *apigateway.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - options.BaseEndpoint = aws.String(c.endpointURLOverride) - } - if c.fipsEnabled { - options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled - } - }) -} - -func (c *CachingFactory) createAPIGatewayV2Client(assumedConfig *aws.Config) *apigatewayv2.Client { - return apigatewayv2.NewFromConfig(*assumedConfig, func(options *apigatewayv2.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - options.BaseEndpoint = aws.String(c.endpointURLOverride) - } - if c.fipsEnabled { - options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled - } - }) -} - -func (c *CachingFactory) createEC2Client(assumedConfig *aws.Config) *ec2.Client { - return ec2.NewFromConfig(*assumedConfig, func(options *ec2.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - options.BaseEndpoint = 
aws.String(c.endpointURLOverride) - } - if c.fipsEnabled { - options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled - } - }) -} - -func (c *CachingFactory) createDMSClient(assumedConfig *aws.Config) *databasemigrationservice.Client { - return databasemigrationservice.NewFromConfig(*assumedConfig, func(options *databasemigrationservice.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - options.BaseEndpoint = aws.String(c.endpointURLOverride) - } - if c.fipsEnabled { - options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled - } - }) -} - -func (c *CachingFactory) createStorageGatewayClient(assumedConfig *aws.Config) *storagegateway.Client { - return storagegateway.NewFromConfig(*assumedConfig, func(options *storagegateway.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - options.BaseEndpoint = aws.String(c.endpointURLOverride) - } - if c.fipsEnabled { - options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled - } - }) -} - -func (c *CachingFactory) createPrometheusClient(assumedConfig *aws.Config) *amp.Client { - return amp.NewFromConfig(*assumedConfig, func(options *amp.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - options.BaseEndpoint = aws.String(c.endpointURLOverride) - } - // The FIPS setting is ignored because FIPS is not available for amp apis - // If enabled the SDK will try to use non-existent FIPS URLs, https://github.com/aws/aws-sdk-go-v2/issues/2138#issuecomment-1570791988 - // AWS FIPS Reference: https://aws.amazon.com/compliance/fips/ - 
}) -} - -func (c *CachingFactory) createStsClient(awsConfig *aws.Config) *sts.Client { - return sts.NewFromConfig(*awsConfig, c.stsOptions) -} - -func (c *CachingFactory) createIAMClient(awsConfig *aws.Config) *iam.Client { - return iam.NewFromConfig(*awsConfig) -} - -func (c *CachingFactory) createShieldClient(awsConfig *aws.Config) *shield.Client { - return shield.NewFromConfig(*awsConfig, func(options *shield.Options) { - if c.logger != nil && c.logger.Enabled(context.Background(), slog.LevelDebug) { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if c.endpointURLOverride != "" { - options.BaseEndpoint = aws.String(c.endpointURLOverride) - } - if c.fipsEnabled { - options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled - } - }) -} - -func createStsOptions(stsRegion string, isDebugLoggingEnabled bool, endpointURLOverride string, fipsEnabled bool) func(*sts.Options) { - return func(options *sts.Options) { - if stsRegion != "" { - options.Region = stsRegion - } - if isDebugLoggingEnabled { - options.ClientLogMode = aws.LogRequestWithBody | aws.LogResponseWithBody - } - if endpointURLOverride != "" { - options.BaseEndpoint = aws.String(endpointURLOverride) - } - if fipsEnabled { - options.EndpointOptions.UseFIPSEndpoint = aws.FIPSEndpointStateEnabled - } - } -} - -var defaultRole = model.Role{} - -func awsConfigForRegion(r model.Role, c *aws.Config, region awsRegion, stsOptions func(*sts.Options)) *aws.Config { - regionalConfig := c.Copy() - regionalConfig.Region = region - - if r == defaultRole { - return ®ionalConfig - } - - // based on https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials/stscreds#hdr-Assume_Role - // found via https://github.com/aws/aws-sdk-go-v2/issues/1382 - regionalSts := sts.NewFromConfig(*c, stsOptions) - credentials := stscreds.NewAssumeRoleProvider(regionalSts, r.RoleArn, func(options *stscreds.AssumeRoleOptions) { - if r.ExternalID != "" { - options.ExternalID = 
aws.String(r.ExternalID) - } - }) - regionalConfig.Credentials = aws.NewCredentialsCache(credentials) - - return ®ionalConfig -} diff --git a/pkg/config/feature_flags.go b/pkg/config/feature_flags.go index c3642f509..1473898a2 100644 --- a/pkg/config/feature_flags.go +++ b/pkg/config/feature_flags.go @@ -16,9 +16,6 @@ import "context" type flagsCtxKey struct{} -// AwsSdkV1 is a feature flag used to enable the use of aws sdk v1 (v2 is the default) -const AwsSdkV1 = "aws-sdk-v1" - // AlwaysReturnInfoMetrics is a feature flag used to enable the return of info metrics even when there are no corresponding CloudWatch metrics const AlwaysReturnInfoMetrics = "always-return-info-metrics"