|
4 | 4 | "bytes" |
5 | 5 | "context" |
6 | 6 | "fmt" |
| 7 | + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" |
7 | 8 | "io" |
8 | 9 | "io/ioutil" |
9 | 10 | "os" |
@@ -42,14 +43,20 @@ const ( |
42 | 43 |
|
43 | 44 | awsCustomCABundleEnvVar = "AWS_CA_BUNDLE" |
44 | 45 |
|
45 | | - awsWebIdentityTokenFilePathEnvKey = "AWS_WEB_IDENTITY_TOKEN_FILE" |
| 46 | + awsWebIdentityTokenFilePathEnvVar = "AWS_WEB_IDENTITY_TOKEN_FILE" |
46 | 47 |
|
47 | | - awsRoleARNEnvKey = "AWS_ROLE_ARN" |
48 | | - awsRoleSessionNameEnvKey = "AWS_ROLE_SESSION_NAME" |
| 48 | + awsRoleARNEnvVar = "AWS_ROLE_ARN" |
| 49 | + awsRoleSessionNameEnvVar = "AWS_ROLE_SESSION_NAME" |
49 | 50 |
|
50 | | - awsEnableEndpointDiscoveryEnvKey = "AWS_ENABLE_ENDPOINT_DISCOVERY" |
| 51 | + awsEnableEndpointDiscoveryEnvVar = "AWS_ENABLE_ENDPOINT_DISCOVERY" |
51 | 52 |
|
52 | 53 | awsS3UseARNRegionEnvVar = "AWS_S3_USE_ARN_REGION" |
| 54 | + |
| 55 | + awsEc2MetadataServiceEndpointModeEnvVar = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE" |
| 56 | + |
| 57 | + awsEc2MetadataServiceEndpointEnvVar = "AWS_EC2_METADATA_SERVICE_ENDPOINT" |
| 58 | + |
| 59 | + awsEc2MetadataDisabled = "AWS_EC2_METADATA_DISABLED" |
53 | 60 | ) |
54 | 61 |
|
55 | 62 | var ( |
@@ -180,6 +187,21 @@ type EnvConfig struct { |
180 | 187 | // |
181 | 188 | // AWS_S3_USE_ARN_REGION=true |
182 | 189 | S3UseARNRegion *bool |
| 190 | + |
| 191 | + // Specifies if the EC2 IMDS service client is enabled. |
| 192 | + // |
| 193 | + // AWS_EC2_METADATA_DISABLED=true |
| 194 | + EC2IMDSClientEnableState imds.ClientEnableState |
| 195 | + |
| 196 | + // Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6) |
| 197 | + // |
| 198 | + // AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6 |
| 199 | + EC2IMDSEndpointMode imds.EndpointModeState |
| 200 | + |
| 201 | + // Specifies the EC2 Instance Metadata Service endpoint to use. If specified it overrides EC2IMDSEndpointMode. |
| 202 | + // |
| 203 | + // AWS_EC2_METADATA_SERVICE_ENDPOINT=http://fd00:ec2::254 |
| 204 | + EC2IMDSEndpoint string |
183 | 205 | } |
184 | 206 |
|
185 | 207 | // loadEnvConfig reads configuration values from the OS's environment variables. |
@@ -215,22 +237,59 @@ func NewEnvConfig() (EnvConfig, error) { |
215 | 237 |
|
216 | 238 | cfg.CustomCABundle = os.Getenv(awsCustomCABundleEnvVar) |
217 | 239 |
|
218 | | - cfg.WebIdentityTokenFilePath = os.Getenv(awsWebIdentityTokenFilePathEnvKey) |
| 240 | + cfg.WebIdentityTokenFilePath = os.Getenv(awsWebIdentityTokenFilePathEnvVar) |
219 | 241 |
|
220 | | - cfg.RoleARN = os.Getenv(awsRoleARNEnvKey) |
221 | | - cfg.RoleSessionName = os.Getenv(awsRoleSessionNameEnvKey) |
| 242 | + cfg.RoleARN = os.Getenv(awsRoleARNEnvVar) |
| 243 | + cfg.RoleSessionName = os.Getenv(awsRoleSessionNameEnvVar) |
222 | 244 |
|
223 | | - if err := setEndpointDiscoveryTypeFromEnvVal(&cfg.EnableEndpointDiscovery, []string{awsEnableEndpointDiscoveryEnvKey}); err != nil { |
| 245 | + if err := setEndpointDiscoveryTypeFromEnvVal(&cfg.EnableEndpointDiscovery, []string{awsEnableEndpointDiscoveryEnvVar}); err != nil { |
224 | 246 | return cfg, err |
225 | 247 | } |
226 | 248 |
|
227 | 249 | if err := setBoolPtrFromEnvVal(&cfg.S3UseARNRegion, []string{awsS3UseARNRegionEnvVar}); err != nil { |
228 | 250 | return cfg, err |
229 | 251 | } |
230 | 252 |
|
| 253 | + setEC2IMDSClientEnableState(&cfg.EC2IMDSClientEnableState, []string{awsEc2MetadataDisabled}) |
| 254 | + if err := setEC2IMDSEndpointMode(&cfg.EC2IMDSEndpointMode, []string{awsEc2MetadataServiceEndpointModeEnvVar}); err != nil { |
| 255 | + return cfg, err |
| 256 | + } |
| 257 | + cfg.EC2IMDSEndpoint = os.Getenv(awsEc2MetadataServiceEndpointEnvVar) |
| 258 | + |
231 | 259 | return cfg, nil |
232 | 260 | } |
233 | 261 |
|
| 262 | +func setEC2IMDSClientEnableState(state *imds.ClientEnableState, keys []string) { |
| 263 | + for _, k := range keys { |
| 264 | + value := os.Getenv(k) |
| 265 | + if len(value) == 0 { |
| 266 | + continue |
| 267 | + } |
| 268 | + switch { |
| 269 | + case strings.EqualFold(value, "true"): |
| 270 | + *state = imds.ClientDisabled |
| 271 | + case strings.EqualFold(value, "false"): |
| 272 | + *state = imds.ClientEnabled |
| 273 | + default: |
| 274 | + continue |
| 275 | + } |
| 276 | + break |
| 277 | + } |
| 278 | +} |
| 279 | + |
| 280 | +func setEC2IMDSEndpointMode(mode *imds.EndpointModeState, keys []string) error { |
| 281 | + for _, k := range keys { |
| 282 | + value := os.Getenv(k) |
| 283 | + if len(value) == 0 { |
| 284 | + continue |
| 285 | + } |
| 286 | + if err := mode.SetFromString(value); err != nil { |
| 287 | + return fmt.Errorf("invalid value for environment variable, %s=%s, %v", k, value, err) |
| 288 | + } |
| 289 | + } |
| 290 | + return nil |
| 291 | +} |
| 292 | + |
234 | 293 | // GetRegion returns the AWS Region if set in the environment. Returns an empty |
235 | 294 | // string if not set. |
236 | 295 | func (c EnvConfig) getRegion(ctx context.Context) (string, bool, error) { |
@@ -371,3 +430,30 @@ func (c EnvConfig) GetEnableEndpointDiscovery(ctx context.Context) (value aws.En |
371 | 430 |
|
372 | 431 | return c.EnableEndpointDiscovery, true, nil |
373 | 432 | } |
| 433 | + |
| 434 | +// GetEC2IMDSClientEnableState implements a EC2IMDSClientEnableState options resolver interface. |
| 435 | +func (c EnvConfig) GetEC2IMDSClientEnableState() (imds.ClientEnableState, bool, error) { |
| 436 | + if c.EC2IMDSClientEnableState == imds.ClientDefaultEnableState { |
| 437 | + return imds.ClientDefaultEnableState, false, nil |
| 438 | + } |
| 439 | + |
| 440 | + return c.EC2IMDSClientEnableState, true, nil |
| 441 | +} |
| 442 | + |
| 443 | +// GetEC2IMDSEndpointMode implements a EC2IMDSEndpointMode option resolver interface. |
| 444 | +func (c EnvConfig) GetEC2IMDSEndpointMode() (imds.EndpointModeState, bool, error) { |
| 445 | + if c.EC2IMDSEndpointMode == imds.EndpointModeStateUnset { |
| 446 | + return imds.EndpointModeStateUnset, false, nil |
| 447 | + } |
| 448 | + |
| 449 | + return c.EC2IMDSEndpointMode, true, nil |
| 450 | +} |
| 451 | + |
| 452 | +// GetEC2IMDSEndpoint implements a EC2IMDSEndpoint option resolver interface. |
| 453 | +func (c EnvConfig) GetEC2IMDSEndpoint() (string, bool, error) { |
| 454 | + if len(c.EC2IMDSEndpoint) == 0 { |
| 455 | + return "", false, nil |
| 456 | + } |
| 457 | + |
| 458 | + return c.EC2IMDSEndpoint, true, nil |
| 459 | +} |
0 commit comments