Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions internal/aws/awsutil/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package awsutil // import "github.com/open-telemetry/opentelemetry-collector-contrib/internal/aws/awsutil"

import "github.com/aws/aws-sdk-go/aws"

// AWSSessionSettings defines the common session configs for AWS components
type AWSSessionSettings struct {
// Maximum number of concurrent calls to AWS X-Ray to upload documents.
Expand Down Expand Up @@ -36,6 +38,8 @@ type AWSSessionSettings struct {
IMDSRetries int `mapstructure:"imds_retries"`
// External ID to verify third party role assumption
ExternalID string `mapstructure:"external_id"`
// Log Level for AWS SDK API calls
LogLevel *aws.LogLevelType `mapstructure:"log_level"`
}

func CreateDefaultSessionConfig() AWSSessionSettings {
Expand Down
57 changes: 43 additions & 14 deletions internal/aws/awsutil/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"log"
"net/http"
"net/url"
Expand All @@ -30,23 +31,36 @@ import (
"golang.org/x/net/http2"
)

type zapAWSLogger struct {
*zap.Logger
}

func (z *zapAWSLogger) Log(args ...interface{}) {
z.Info(fmt.Sprintln(args...))
}

type ConnAttr interface {
newAWSSession(logger *zap.Logger, cfg *AWSSessionSettings, region string) (*session.Session, error)
getEC2Region(s *session.Session, imdsRetries int) (string, error)
getEC2Region(logger *zap.Logger, logLevel *aws.LogLevelType, s *session.Session, imdsRetries int) (string, error)
}

// Conn implements connAttr interface.
type Conn struct{}

func (c *Conn) getEC2Region(s *session.Session, imdsRetries int) (string, error) {
func (c *Conn) getEC2Region(logger *zap.Logger, logLevel *aws.LogLevelType, s *session.Session, imdsRetries int) (string, error) {
region, err := ec2metadata.New(s, &aws.Config{
Retryer: override.NewIMDSRetryer(imdsRetries),
EC2MetadataEnableFallback: aws.Bool(false),
LogLevel: logLevel,
Logger: &zapAWSLogger{logger},
}).Region()
if err == nil {
return region, err
}
return ec2metadata.New(s, &aws.Config{}).Region()
return ec2metadata.New(s, &aws.Config{
LogLevel: logLevel,
Logger: &zapAWSLogger{logger},
}).Region()
}

type stsCredentialProvider struct {
Expand Down Expand Up @@ -200,7 +214,7 @@ func GetAWSConfigSession(logger *zap.Logger, cn ConnAttr, cfg *AWSSessionSetting
if err != nil {
logger.Error("Unable to retrieve default session", zap.Error(err))
} else {
awsRegion, err = cn.getEC2Region(es, cfg.IMDSRetries)
awsRegion, err = cn.getEC2Region(logger, cfg.LogLevel, es, cfg.IMDSRetries)
if err != nil {
logger.Error("Unable to retrieve the region from the EC2 instance", zap.Error(err))
} else {
Expand All @@ -226,6 +240,8 @@ func GetAWSConfigSession(logger *zap.Logger, cn ConnAttr, cfg *AWSSessionSetting
Endpoint: aws.String(cfg.Endpoint),
HTTPClient: http,
CredentialsChainVerboseErrors: aws.Bool(true),
LogLevel: cfg.LogLevel,
Logger: &zapAWSLogger{logger},
}
return config, s, nil
}
Expand Down Expand Up @@ -277,10 +293,12 @@ func (c *Conn) newAWSSession(logger *zap.Logger, cfg *AWSSessionSettings, region
logger.Warn("could not get default session before trying to get role sts", zap.Error(err))
return nil, err
}
stsCreds := newStsCredentials(s, cfg.RoleARN, region)
stsCreds := newStsCredentials(logger, cfg.LogLevel, s, cfg.RoleARN, region)

s, err = session.NewSession(&aws.Config{
Credentials: stsCreds,
LogLevel: cfg.LogLevel,
Logger: &zapAWSLogger{logger},
})
if err != nil {
logger.Error("Error in creating session object : ", zap.Error(err))
Expand All @@ -299,7 +317,7 @@ func getSTSCreds(logger *zap.Logger, region string, cfg *AWSSessionSettings) (*c
return nil, err
}

stsCred := getSTSCredsFromRegionEndpoint(logger, t, region, cfg.RoleARN, cfg.ExternalID)
stsCred := getSTSCredsFromRegionEndpoint(logger, cfg.LogLevel, t, region, cfg.RoleARN, cfg.ExternalID)
// Make explicit call to fetch credentials.
_, err = stsCred.Get()
if err != nil {
Expand All @@ -309,7 +327,7 @@ func getSTSCreds(logger *zap.Logger, region string, cfg *AWSSessionSettings) (*c

if awsErr.Code() == sts.ErrCodeRegionDisabledException {
logger.Error("Region ", zap.String("region", region), zap.Error(awsErr))
stsCred = getSTSCredsFromPrimaryRegionEndpoint(logger, t, cfg.RoleARN, region, cfg.ExternalID)
stsCred = getSTSCredsFromPrimaryRegionEndpoint(logger, cfg.LogLevel, t, cfg.RoleARN, region, cfg.ExternalID)
}
}
}
Expand All @@ -319,14 +337,19 @@ func getSTSCreds(logger *zap.Logger, region string, cfg *AWSSessionSettings) (*c
// getSTSCredsFromRegionEndpoint fetches STS credentials for provided roleARN from regional endpoint.
// AWS STS recommends that you provide both the Region and endpoint when you make calls to a Regional endpoint.
// Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code
func getSTSCredsFromRegionEndpoint(logger *zap.Logger, sess *session.Session, region string,
func getSTSCredsFromRegionEndpoint(logger *zap.Logger, logLevel *aws.LogLevelType, sess *session.Session, region string,
roleArn, externalID string,
) *credentials.Credentials {
regionalEndpoint := getSTSRegionalEndpoint(region)
// if regionalEndpoint is "", the STS endpoint is Global endpoint for classic regions except ap-east-1 - (HKG)
// for other opt-in regions, region value will create STS regional endpoint.
// This will be only in the case, if provided region is not present in aws_regions.go
c := &aws.Config{Region: aws.String(region), Endpoint: &regionalEndpoint}
c := &aws.Config{
Region: aws.String(region),
Endpoint: &regionalEndpoint,
LogLevel: logLevel,
Logger: &zapAWSLogger{logger},
}
st := sts.New(sess, c)
logger.Info("STS Endpoint ", zap.String("endpoint", st.Endpoint))
options := []func(*stscreds.AssumeRoleProvider){}
Expand All @@ -340,18 +363,18 @@ func getSTSCredsFromRegionEndpoint(logger *zap.Logger, sess *session.Session, re

// getSTSCredsFromPrimaryRegionEndpoint fetches STS credentials for provided roleARN from primary region endpoint in
// the respective partition.
func getSTSCredsFromPrimaryRegionEndpoint(logger *zap.Logger, t *session.Session, roleArn string,
func getSTSCredsFromPrimaryRegionEndpoint(logger *zap.Logger, logLevel *aws.LogLevelType, t *session.Session, roleArn string,
region string, externalID string,
) *credentials.Credentials {
logger.Info("Credentials for provided RoleARN being fetched from STS primary region endpoint.")
partitionID := getPartition(region)
switch partitionID {
case endpoints.AwsPartitionID:
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.UsEast1RegionID, roleArn, externalID)
return getSTSCredsFromRegionEndpoint(logger, logLevel, t, endpoints.UsEast1RegionID, roleArn, externalID)
case endpoints.AwsCnPartitionID:
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.CnNorth1RegionID, roleArn, externalID)
return getSTSCredsFromRegionEndpoint(logger, logLevel, t, endpoints.CnNorth1RegionID, roleArn, externalID)
case endpoints.AwsUsGovPartitionID:
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.UsGovWest1RegionID, roleArn, externalID)
return getSTSCredsFromRegionEndpoint(logger, logLevel, t, endpoints.UsGovWest1RegionID, roleArn, externalID)
}

return nil
Expand All @@ -375,6 +398,8 @@ func GetDefaultSession(logger *zap.Logger, cfg *AWSSessionSettings) (*session.Se
logger.Debug("Fallback shared config file(s)", zap.Strings("files", cfgFiles))
awsConfig := aws.Config{
Credentials: getRootCredentials(cfg),
LogLevel: cfg.LogLevel,
Logger: &zapAWSLogger{logger},
}
result, serr := session.NewSessionWithOptions(session.Options{
Config: awsConfig,
Expand Down Expand Up @@ -465,12 +490,14 @@ func getCredentialProviderChain(cfg *AWSSessionSettings) []credentials.Provider
return credProviders
}

func newStsCredentials(c client.ConfigProvider, roleARN string, region string) *credentials.Credentials {
func newStsCredentials(logger *zap.Logger, logLevel *aws.LogLevelType, c client.ConfigProvider, roleARN string, region string) *credentials.Credentials {
regional := &stscreds.AssumeRoleProvider{
Client: newStsClient(c, &aws.Config{
Region: aws.String(region),
STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
HTTPClient: &http.Client{Timeout: 1 * time.Minute},
LogLevel: logLevel,
Logger: &zapAWSLogger{logger},
}),
RoleARN: roleARN,
Duration: stscreds.DefaultDuration,
Expand All @@ -484,6 +511,8 @@ func newStsCredentials(c client.ConfigProvider, roleARN string, region string) *
Endpoint: aws.String(getFallbackEndpoint(fallbackRegion)),
STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
HTTPClient: &http.Client{Timeout: 1 * time.Minute},
LogLevel: logLevel,
Logger: &zapAWSLogger{logger},
}),
RoleARN: roleARN,
Duration: stscreds.DefaultDuration,
Expand Down
4 changes: 2 additions & 2 deletions internal/aws/awsutil/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type mockConn struct {
sn *session.Session
}

func (c *mockConn) getEC2Region(_ *session.Session, _ int) (string, error) {
func (c *mockConn) getEC2Region(_ *zap.Logger, _ *aws.LogLevelType, _ *session.Session, _ int) (string, error) {
args := c.Called(nil)
errorStr := args.String(0)
var err error
Expand Down Expand Up @@ -130,7 +130,7 @@ func TestNewAWSSessionWithErr(t *testing.T) {
Region: aws.String("us-east-1"),
})
assert.NotNil(t, se)
_, err = conn.getEC2Region(se, aWSSessionSettings.IMDSRetries)
_, err = conn.getEC2Region(logger, nil, se, aWSSessionSettings.IMDSRetries)
assert.Error(t, err)
}

Expand Down
Loading