diff --git a/THIRD-PARTY-LICENSES b/THIRD-PARTY-LICENSES index f8565727ee..605217e687 100644 --- a/THIRD-PARTY-LICENSES +++ b/THIRD-PARTY-LICENSES @@ -180,7 +180,7 @@ Copyright © 2015 Steve Francia ----- -** aws/aws-sdk-go; version 1.15.7 -- https://github.com/aws/aws-sdk-go/ +** aws/aws-sdk-go-v2; version 1.24.4 -- https://github.com/aws/aws-sdk-go-v2/ ** Etcd; version v3.1.0-alpha.1 -- https://github.com/coreos/etcd/tree/v3.1.0-alpha.1 ** github.com/coreos/go-semver; version 0.2 -- https://github.com/coreos/go-semver ** github.com/coreos/go-systemd/; version 10 -- https://github.com/coreos/go-systemd/ @@ -412,9 +412,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -* For aws/aws-sdk-go see also this required NOTICE: +* For aws/aws-sdk-go-v2 see also this required NOTICE: AWS SDK for Go -Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. Copyright 2014-2015 Stripe, Inc. * For Etcd see also this required NOTICE: CoreOS Project diff --git a/go.mod b/go.mod index fdcb6247f9..3e383b8041 100644 --- a/go.mod +++ b/go.mod @@ -2,13 +2,21 @@ module k8s.io/cloud-provider-aws go 1.22.7 +toolchain go1.24.3 + require ( - github.com/aws/aws-sdk-go v1.54.6 + github.com/aws/aws-sdk-go v1.55.7 + github.com/aws/aws-sdk-go-v2 v1.36.5 + github.com/aws/aws-sdk-go-v2/config v1.29.17 + github.com/aws/aws-sdk-go-v2/service/autoscaling v1.54.0 + github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.29.3 + github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.45.2 + github.com/aws/aws-sdk-go-v2/service/kms v1.41.0 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 - github.com/spf13/cobra v1.7.0 + github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.10.0 golang.org/x/time v0.3.0 gopkg.in/gcfg.v1 v1.2.3 k8s.io/api v0.28.10 @@ -19,17 +27,34 @@ require ( k8s.io/component-base v0.28.10 k8s.io/controller-manager v0.28.10 k8s.io/csi-translation-lib v0.28.10 - k8s.io/klog/v2 v2.100.1 + k8s.io/klog/v2 v2.130.1 k8s.io/kubelet v0.28.10 k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 sigs.k8s.io/yaml v1.3.0 ) +require ( + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.3 // indirect + github.com/onsi/ginkgo/v2 v2.23.0 // indirect + github.com/onsi/gomega v1.36.2 // indirect +) + require ( github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.70 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.32 + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.36 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36 // indirect + github.com/aws/aws-sdk-go-v2/service/ec2 v1.218.0 + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.17 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.34.0 + github.com/aws/smithy-go v1.22.4 github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect @@ -41,7 +66,7 @@ require ( github.com/evanphx/json-patch v5.6.0+incompatible // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect - github.com/go-logr/logr v1.3.0 // indirect + github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect @@ -73,7 +98,7 @@ require ( github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.10.1 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect - github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect go.etcd.io/etcd/api/v3 v3.5.9 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.9 // indirect go.etcd.io/etcd/client/v3 v3.5.9 // indirect @@ -89,22 +114,22 @@ require ( go.uber.org/atomic v1.10.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.24.0 // indirect - golang.org/x/crypto v0.21.0 // indirect + golang.org/x/crypto v0.33.0 // indirect golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect - golang.org/x/mod v0.14.0 // indirect - golang.org/x/net v0.23.0 // indirect + golang.org/x/mod v0.23.0 // indirect + golang.org/x/net v0.35.0 // indirect golang.org/x/oauth2 v0.11.0 // indirect - golang.org/x/sync v0.5.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/term v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.16.1 // indirect + golang.org/x/sync v0.11.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/term v0.29.0 // indirect + golang.org/x/text v0.22.0 // indirect + golang.org/x/tools v0.30.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20230822172742-b8732ec3820d // indirect google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect google.golang.org/grpc v1.59.0 // indirect - google.golang.org/protobuf v1.33.0 // indirect + google.golang.org/protobuf v1.36.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect diff --git a/go.sum b/go.sum index 23bc4b1ac1..bf7866a560 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,44 @@ github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1 h github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= -github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= -github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go v1.55.7 h1:UJrkFq7es5CShfBwlWAC8DA077vp8PyVbQd3lqLiztE= +github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.36.5 h1:0OF9RiEMEdDdZEMqF9MRjevyxAQcf6gY+E7vwBILFj0= +github.com/aws/aws-sdk-go-v2 v1.36.5/go.mod h1:EYrzvCCN9CMUTa5+6lf6MM4tq3Zjp8UhSGR/cBsjai0= +github.com/aws/aws-sdk-go-v2/config v1.29.17 h1:jSuiQ5jEe4SAMH6lLRMY9OVC+TqJLP5655pBGjmnjr0= +github.com/aws/aws-sdk-go-v2/config v1.29.17/go.mod h1:9P4wwACpbeXs9Pm9w1QTh6BwWwJjwYvJ1iCt5QbCXh8= +github.com/aws/aws-sdk-go-v2/credentials v1.17.70 h1:ONnH5CM16RTXRkS8Z1qg7/s2eDOhHhaXVd72mmyv4/0= +github.com/aws/aws-sdk-go-v2/credentials v1.17.70/go.mod h1:M+lWhhmomVGgtuPOhO85u4pEa3SmssPTdcYpP/5J/xc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.32 h1:KAXP9JSHO1vKGCr5f4O6WmlVKLFFXgWYAGoJosorxzU= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.32/go.mod h1:h4Sg6FQdexC1yYG9RDnOvLbW1a/P986++/Y/a+GyEM8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.36 h1:SsytQyTMHMDPspp+spo7XwXTP44aJZZAC7fBV2C5+5s= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.36/go.mod h1:Q1lnJArKRXkenyog6+Y+zr7WDpk4e6XlR6gs20bbeNo= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36 h1:i2vNHQiXUvKhs3quBR6aqlgJaiaexz/aNvdCktW/kAM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.36/go.mod h1:UdyGa7Q91id/sdyHPwth+043HhmP6yP9MBHgbZM0xo8= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.54.0 h1:0BmpSm5x2rpB9D2K2OAoOc1cZTUJpw1OiQj86ZT8RTg= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.54.0/go.mod h1:6U/Xm5bBkZGCTxH3NE9+hPKEpCFCothGn/gwytsr1Mk= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.218.0 h1:QPYsTfcPpPhkF+37pxLcl3xbQz2SRxsShQNB6VCkvLo= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.218.0/go.mod h1:ouvGEfHbLaIlWwpDpOVWPWR+YwO0HDv3vm5tYLq8ImY= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.29.3 h1:DpyV8LeDf0y7iDaGZ3h1Y+Nh5IaBOR+xj44vVgEEegY= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.29.3/go.mod h1:H232HdqVlSUoqy0cMJYW1TKjcxvGFGFZ20xQG8fOAPw= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.45.2 h1:vX70Z4lNSr7XsioU0uJq5yvxgI50sB66MvD+V/3buS4= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.45.2/go.mod h1:xnCC3vFBfOKpU6PcsCKL2ktgBTZfOwTGxj6V8/X3IS4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.4 h1:CXV68E2dNqhuynZJPB80bhPQwAKqBWVer887figW6Jc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.4/go.mod h1:/xFi9KtvBXP97ppCz1TAEvU1Uf66qvid89rbem3wCzQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.17 h1:t0E6FzREdtCsiLIoLCWsYliNsRBgyGD/MCK571qk4MI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.17/go.mod h1:ygpklyoaypuyDvOM5ujWGrYWpAK3h7ugnmKCU/76Ys4= +github.com/aws/aws-sdk-go-v2/service/kms v1.41.0 h1:2jKyib9msVrAVn+lngwlSplG13RpUZmzVte2yDao5nc= +github.com/aws/aws-sdk-go-v2/service/kms v1.41.0/go.mod h1:RyhzxkWGcfixlkieewzpO3D4P4fTMxhIDqDZWsh0u/4= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.5 h1:AIRJ3lfb2w/1/8wOOSqYb9fUKGwQbtysJ2H1MofRUPg= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.5/go.mod h1:b7SiVprpU+iGazDUqvRSLf5XmCdn+JtT1on7uNL6Ipc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.3 h1:BpOxT3yhLwSJ77qIY3DoHAQjZsc4HEGfMCE4NGy3uFg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.3/go.mod h1:vq/GQR1gOFLquZMSrxUK/cpvKCNVYibNyJ1m7JrU88E= +github.com/aws/aws-sdk-go-v2/service/sts v1.34.0 h1:NFOJ/NXEGV4Rq//71Hs1jC/NvPs1ezajK+yQmkwnPV0= +github.com/aws/aws-sdk-go-v2/service/sts v1.34.0/go.mod h1:7ph2tGpfQvwzgistp2+zga9f+bCjlQJPkPUmMgDSD7w= +github.com/aws/smithy-go v1.22.4 h1:uqXzVZNuNexwc/xrh6Tb56u89WDlJY6HS+KC0S4QSjw= +github.com/aws/smithy-go v1.22.4/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -29,7 +65,7 @@ github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= @@ -49,10 +85,9 @@ github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSw github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= -github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= -github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/zapr v1.2.3 h1:a9vnzlIBPQBBkeaR9IuMUfmVOrQlkoC4YfPoFkX3T7A= @@ -64,7 +99,8 @@ github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= @@ -93,8 +129,8 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= -github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= @@ -144,10 +180,10 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE= -github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM= -github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= -github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= +github.com/onsi/ginkgo/v2 v2.23.0 h1:FA1xjp8ieYDzlgS5ABTpdUDB7wtngggONc8a7ku2NqQ= +github.com/onsi/ginkgo/v2 v2.23.0/go.mod h1:zXTP6xIp3U8aVuXN8ENK9IXRaTjFnpVB9mGmaSRvxnM= +github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= +github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -167,22 +203,23 @@ github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0 github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stoewer/go-strcase v1.3.0 h1:g0eASXYtp+yvN9fK8sH94oCIk0fau9uV1/ZdJ0AVEzs= github.com/stoewer/go-strcase v1.3.0/go.mod h1:fAH5hQ5pehh+j3nZfvwdk2RgEgQjAoM8wodgtPmh1xo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75 h1:6fotK7otjonDflCTK0BCfls4SPy3NcCVb5dqqmbRknE= github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75/go.mod h1:KO6IkyS8Y3j8OdNO85qEYBsRPuteD+YciPomcXdrMnk= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8= @@ -242,8 +279,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= +golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -252,8 +289,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/oauth2 v0.11.0 h1:vPL4xzxBM4niKCW6g9whtaWVXTJf1U5e4aZxxFx/gbU= golang.org/x/oauth2 v0.11.0/go.mod h1:LdF7O/8bLR/qWK9DrpXmbHLTouvRHK0SgJl0GmDBchk= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -263,8 +300,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -278,23 +315,24 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= +golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -305,8 +343,8 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= -golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -321,8 +359,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= +google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= @@ -364,8 +402,8 @@ k8s.io/csi-translation-lib v0.28.10/go.mod h1:KLe1tGLhunZ7HJ41a+hbf4b5sZ7vj5NEj3 k8s.io/gengo v0.0.0-20220902162205-c0856e24416d h1:U9tB195lKdzwqicbJvyJeOXV7Klv+wNAWENRnXEGi08= k8s.io/gengo v0.0.0-20220902162205-c0856e24416d/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E= k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= -k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg= -k8s.io/klog/v2 v2.100.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= +k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= +k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= k8s.io/kms v0.28.10 h1:rjP2HWXeMF6p5dnSosVGiYhhMYkVkCE2wVRxkqiUiAM= k8s.io/kms v0.28.10/go.mod h1:RoSWftot8nvpJft3LXsx1G5Gd/RO/VxD2AWu7cMxqb0= k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9 h1:LyMgNKD2P8Wn1iAwQU5OhxCKlKJy0sHc+PcDwFB24dQ= diff --git a/hack/e2e/run.sh b/hack/e2e/run.sh index feb25d116a..892c389bb9 100755 --- a/hack/e2e/run.sh +++ b/hack/e2e/run.sh @@ -47,7 +47,7 @@ UP="${UP:-yes}" # if DOWN==yes, delete cluster after test DOWN="${DOWN:-yes}" -KUBERNETES_VERSION="${KUBERNETES_VERSION:-v1.26.0}" +KUBERNETES_VERSION="${KUBERNETES_VERSION:-v1.32.0}" CLUSTER_NAME="${CLUSTER_NAME:-test-cluster-${test_run_id}.k8s}" KOPS_STATE_STORE="${KOPS_STATE_STORE:-}" REGION="${AWS_REGION:-us-west-2}" diff --git a/pkg/controllers/tagging/tagging_controller.go b/pkg/controllers/tagging/tagging_controller.go index 01f9ab7546..306faa5030 100644 --- a/pkg/controllers/tagging/tagging_controller.go +++ b/pkg/controllers/tagging/tagging_controller.go @@ -14,6 +14,7 @@ limitations under the License. package tagging import ( + "context" "crypto/md5" "fmt" "golang.org/x/time/rate" @@ -37,7 +38,7 @@ import ( // workItem contains the node and an action for that node type workItem struct { node *v1.Node - action func(node *v1.Node) error + action func(ctx context.Context, node *v1.Node) error requeuingCount int enqueueTime time.Time } @@ -169,33 +170,33 @@ func NewTaggingController( // Run will start the controller to tag resources attached to the cluster // and untag resources detached from the cluster. -func (tc *Controller) Run(stopCh <-chan struct{}) { +func (tc *Controller) Run(ctx context.Context) { defer utilruntime.HandleCrash() defer tc.workqueue.ShutDown() // Wait for the caches to be synced before starting workers klog.Info("Waiting for informer caches to sync") - if ok := cache.WaitForCacheSync(stopCh, tc.nodesSynced); !ok { + if ok := cache.WaitForCacheSync(ctx.Done(), tc.nodesSynced); !ok { klog.Errorf("failed to wait for caches to sync") return } klog.Infof("Starting the tagging controller") - go wait.Until(tc.work, tc.nodeMonitorPeriod, stopCh) + go wait.UntilWithContext(ctx, func(ctx context.Context) { tc.work(ctx) }, tc.nodeMonitorPeriod) - <-stopCh + <-ctx.Done() } // work is a long-running function that continuously // call process() for each message on the workqueue -func (tc *Controller) work() { - for tc.process() { +func (tc *Controller) work(ctx context.Context) { + for tc.process(ctx) { } } // process reads each message in the queue and performs either // tag or untag function on the Node object -func (tc *Controller) process() bool { +func (tc *Controller) process(ctx context.Context) bool { obj, shutdown := tc.workqueue.Get() if shutdown { return false @@ -216,7 +217,7 @@ func (tc *Controller) process() bool { timeTaken := time.Since(workItem.enqueueTime).Seconds() recordWorkItemLatencyMetrics(workItemDequeuingTimeWorkItemMetric, timeTaken) - klog.Infof("Dequeuing latency %s", timeTaken) + klog.Infof("Dequeuing latency %f", timeTaken) instanceID, err := awsv1.KubernetesInstanceID(workItem.node.Spec.ProviderID).MapToAWSInstanceID() if err != nil { @@ -232,7 +233,7 @@ func (tc *Controller) process() bool { return nil } - err = workItem.action(workItem.node) + err = workItem.action(ctx, workItem.node) if err != nil { if workItem.requeuingCount < maxRequeuingCount { @@ -250,7 +251,7 @@ func (tc *Controller) process() bool { klog.Infof("Finished processing %s", workItem) timeTaken = time.Since(workItem.enqueueTime).Seconds() recordWorkItemLatencyMetrics(workItemProcessingTimeWorkItemMetric, timeTaken) - klog.Infof("Processing latency %s", timeTaken) + klog.Infof("Processing latency %f", timeTaken) } tc.workqueue.Forget(obj) @@ -267,11 +268,11 @@ func (tc *Controller) process() bool { // tagNodesResources tag node resources // If we want to tag more resources, modify this function appropriately -func (tc *Controller) tagNodesResources(node *v1.Node) error { +func (tc *Controller) tagNodesResources(ctx context.Context, node *v1.Node) error { for _, resource := range tc.resources { switch resource { case opt.Instance: - err := tc.tagEc2Instance(node) + err := tc.tagEc2Instance(ctx, node) if err != nil { return err } @@ -283,7 +284,7 @@ func (tc *Controller) tagNodesResources(node *v1.Node) error { // tagEc2Instances applies the provided tags to each EC2 instance in // the cluster. -func (tc *Controller) tagEc2Instance(node *v1.Node) error { +func (tc *Controller) tagEc2Instance(ctx context.Context, node *v1.Node) error { if !tc.isTaggingRequired(node) { klog.Infof("Skip tagging node %s since it was already tagged earlier.", node.GetName()) return nil @@ -291,7 +292,7 @@ func (tc *Controller) tagEc2Instance(node *v1.Node) error { instanceID, _ := awsv1.KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() - err := tc.cloud.TagResource(string(instanceID), tc.tags) + err := tc.cloud.TagResource(ctx, string(instanceID), tc.tags) if err != nil { if awsv1.IsAWSErrorInstanceNotFound(err) { @@ -324,11 +325,11 @@ func (tc *Controller) tagEc2Instance(node *v1.Node) error { // untagNodeResources untag node resources // If we want to untag more resources, modify this function appropriately -func (tc *Controller) untagNodeResources(node *v1.Node) error { +func (tc *Controller) untagNodeResources(ctx context.Context, node *v1.Node) error { for _, resource := range tc.resources { switch resource { case opt.Instance: - err := tc.untagEc2Instance(node) + err := tc.untagEc2Instance(ctx, node) if err != nil { return err } @@ -340,10 +341,10 @@ func (tc *Controller) untagNodeResources(node *v1.Node) error { // untagEc2Instances deletes the provided tags to each EC2 instances in // the cluster. -func (tc *Controller) untagEc2Instance(node *v1.Node) error { +func (tc *Controller) untagEc2Instance(ctx context.Context, node *v1.Node) error { instanceID, _ := awsv1.KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() - err := tc.cloud.UntagResource(string(instanceID), tc.tags) + err := tc.cloud.UntagResource(ctx, string(instanceID), tc.tags) if err != nil { klog.Errorf("Error in untagging EC2 instance %s for node %s, error: %v", instanceID, node.GetName(), err) @@ -357,7 +358,7 @@ func (tc *Controller) untagEc2Instance(node *v1.Node) error { // enqueueNode takes in the object and an // action for the object for a workitem and enqueue to the workqueue -func (tc *Controller) enqueueNode(node *v1.Node, action func(node *v1.Node) error) { +func (tc *Controller) enqueueNode(node *v1.Node, action func(ctx context.Context, node *v1.Node) error) { item := &workItem{ node: node, action: action, diff --git a/pkg/controllers/tagging/tagging_controller_test.go b/pkg/controllers/tagging/tagging_controller_test.go index a96b4135f5..117dbf6441 100644 --- a/pkg/controllers/tagging/tagging_controller_test.go +++ b/pkg/controllers/tagging/tagging_controller_test.go @@ -194,7 +194,7 @@ func Test_NodesJoiningAndLeaving(t *testing.T) { } awsServices := awsv1.NewFakeAWSServices(TestClusterID) - fakeAws, _ := awsv1.NewAWSCloud(awsv1.CloudConfig{}, awsServices) + fakeAws, _ := awsv1.NewAWSCloud(awsv1.CloudConfig{}, awsServices, nil) for _, testcase := range testcases { t.Run(testcase.name, func(t *testing.T) { @@ -236,7 +236,7 @@ func Test_NodesJoiningAndLeaving(t *testing.T) { } for tc.workqueue.Len() > 0 { - tc.process() + tc.process(context.TODO()) // sleep briefly because of exponential backoff when requeueing failed workitem // resulting in workqueue to be empty if checked immediately diff --git a/pkg/controllers/tagging/tagging_controller_wrapper.go b/pkg/controllers/tagging/tagging_controller_wrapper.go index e44181e168..1e0ee7167d 100644 --- a/pkg/controllers/tagging/tagging_controller_wrapper.go +++ b/pkg/controllers/tagging/tagging_controller_wrapper.go @@ -55,7 +55,7 @@ func (tc *ControllerWrapper) startTaggingController(ctx context.Context, initCon return nil, false, nil } - go taggingcontroller.Run(ctx.Done()) + go taggingcontroller.Run(ctx) return nil, true, nil } diff --git a/pkg/providers/v1/aws.go b/pkg/providers/v1/aws.go index 3e6236a29d..71c9e75433 100644 --- a/pkg/providers/v1/aws.go +++ b/pkg/providers/v1/aws.go @@ -29,21 +29,21 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + awsConfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/smithy-go" "gopkg.in/gcfg.v1" v1 "k8s.io/api/core/v1" @@ -57,10 +57,10 @@ import ( clientset "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/scheme" v1core "k8s.io/client-go/kubernetes/typed/core/v1" - "k8s.io/client-go/pkg/version" "k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/record" cloudprovider "k8s.io/cloud-provider" + "k8s.io/cloud-provider-aws/pkg/services" nodehelpers "k8s.io/cloud-provider/node/helpers" servicehelpers "k8s.io/cloud-provider/service/helpers" cloudvolume "k8s.io/cloud-provider/volume" @@ -330,16 +330,44 @@ const MaxReadThenCreateRetries = 30 // DefaultVolumeType specifies which storage to use for newly created Volumes // TODO: Remove when user/admin can configure volume types and thus we don't // need hardcoded defaults. -const DefaultVolumeType = "gp2" +const DefaultVolumeType = ec2types.VolumeTypeGp2 // Services is an abstraction over AWS, to allow mocking/other implementations type Services interface { - Compute(region string) (EC2, error) - LoadBalancing(region string) (ELB, error) - LoadBalancingV2(region string) (ELBV2, error) - Autoscaling(region string) (ASG, error) - Metadata() (EC2Metadata, error) - KeyManagement(region string) (KMS, error) + Compute(ctx context.Context, region string, assumeRoleProvider *stscreds.AssumeRoleProvider) (EC2, error) + LoadBalancing(ctx context.Context, regionName string, assumeRoleProvider *stscreds.AssumeRoleProvider) (ELB, error) + LoadBalancingV2(ctx context.Context, regionName string, assumeRoleProvider *stscreds.AssumeRoleProvider) (ELBV2, error) + Autoscaling(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ASG, error) + Metadata(ctx context.Context) (EC2Metadata, error) + KeyManagement(ctx context.Context, regionName string, assumeRoleProvider *stscreds.AssumeRoleProvider) (KMS, error) +} + +// EC2API is an interface to satisfy the ec2.Client API. +// More details about this pattern: https://docs.aws.amazon.com/sdk-for-go/v2/developer-guide/unit-testing.html +type EC2API interface { + AuthorizeSecurityGroupIngress(ctx context.Context, params *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) + AttachVolume(ctx context.Context, params *ec2.AttachVolumeInput, optFns ...func(*ec2.Options)) (*ec2.AttachVolumeOutput, error) + CreateRoute(ctx context.Context, params *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) + CreateSecurityGroup(ctx context.Context, params *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) + CreateTags(ctx context.Context, params *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) + CreateVolume(ctx context.Context, params *ec2.CreateVolumeInput, optFns ...func(*ec2.Options)) (*ec2.CreateVolumeOutput, error) + DeleteRoute(ctx context.Context, params *ec2.DeleteRouteInput, optFns ...func(*ec2.Options)) (*ec2.DeleteRouteOutput, error) + DeleteSecurityGroup(ctx context.Context, params *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) + DeleteTags(ctx context.Context, params *ec2.DeleteTagsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTagsOutput, error) + DeleteVolume(ctx context.Context, params *ec2.DeleteVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVolumeOutput, error) + DescribeAvailabilityZones(ctx context.Context, params *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) + DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFuns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) + DescribeNetworkInterfaces(ctx context.Context, params *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) + DescribeRouteTables(ctx context.Context, params *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) + DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) + DescribeSubnets(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) + DescribeVolumes(ctx context.Context, params *ec2.DescribeVolumesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVolumesOutput, error) + DescribeVolumesModifications(ctx context.Context, params *ec2.DescribeVolumesModificationsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVolumesModificationsOutput, error) + DescribeVpcs(ctx context.Context, params *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) + DetachVolume(ctx context.Context, params *ec2.DetachVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DetachVolumeOutput, error) + ModifyInstanceAttribute(ctx context.Context, params *ec2.ModifyInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyInstanceAttributeOutput, error) + ModifyVolume(ctx context.Context, params *ec2.ModifyVolumeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyVolumeOutput, error) + RevokeSecurityGroupIngress(ctx context.Context, params *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) } // EC2 is an abstraction over AWS', to allow mocking/other implementations @@ -347,126 +375,124 @@ type Services interface { // TODO: Should we rename this to AWS (EBS & ELB are not technically part of EC2) type EC2 interface { // Query EC2 for instances matching the filter - DescribeInstances(request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) + DescribeInstances(ctx context.Context, request *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) ([]ec2types.Instance, error) // Attach a volume to an instance - AttachVolume(*ec2.AttachVolumeInput) (*ec2.VolumeAttachment, error) + AttachVolume(ctx context.Context, request *ec2.AttachVolumeInput, optFns ...func(*ec2.Options)) (*ec2.AttachVolumeOutput, error) // Detach a volume from an instance it is attached to - DetachVolume(request *ec2.DetachVolumeInput) (resp *ec2.VolumeAttachment, err error) + DetachVolume(ctx context.Context, request *ec2.DetachVolumeInput, optFns ...func(*ec2.Options)) (resp *ec2.DetachVolumeOutput, err error) // Lists volumes - DescribeVolumes(request *ec2.DescribeVolumesInput) ([]*ec2.Volume, error) + DescribeVolumes(ctx context.Context, request *ec2.DescribeVolumesInput, optFns ...func(*ec2.Options)) ([]ec2types.Volume, error) // Create an EBS volume - CreateVolume(request *ec2.CreateVolumeInput) (resp *ec2.Volume, err error) + CreateVolume(ctx context.Context, request *ec2.CreateVolumeInput, optFns ...func(*ec2.Options)) (resp *ec2.CreateVolumeOutput, err error) // Delete an EBS volume - DeleteVolume(*ec2.DeleteVolumeInput) (*ec2.DeleteVolumeOutput, error) + DeleteVolume(ctx context.Context, request *ec2.DeleteVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVolumeOutput, error) - ModifyVolume(*ec2.ModifyVolumeInput) (*ec2.ModifyVolumeOutput, error) + ModifyVolume(ctx context.Context, request *ec2.ModifyVolumeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyVolumeOutput, error) - DescribeVolumeModifications(*ec2.DescribeVolumesModificationsInput) ([]*ec2.VolumeModification, error) + DescribeVolumeModifications(ctx context.Context, request *ec2.DescribeVolumesModificationsInput, optFns ...func(*ec2.Options)) ([]ec2types.VolumeModification, error) - DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) + DescribeSecurityGroups(ctx context.Context, request *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) ([]ec2types.SecurityGroup, error) - CreateSecurityGroup(*ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) - DeleteSecurityGroup(request *ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) + CreateSecurityGroup(ctx context.Context, request *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) + DeleteSecurityGroup(ctx context.Context, request *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) - AuthorizeSecurityGroupIngress(*ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) - RevokeSecurityGroupIngress(*ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) + AuthorizeSecurityGroupIngress(ctx context.Context, request *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) + RevokeSecurityGroupIngress(ctx context.Context, request *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) - DescribeSubnets(*ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) + DescribeSubnets(ctx context.Context, request *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) ([]ec2types.Subnet, error) - DescribeAvailabilityZones(request *ec2.DescribeAvailabilityZonesInput) ([]*ec2.AvailabilityZone, error) + DescribeAvailabilityZones(ctx context.Context, request *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) ([]ec2types.AvailabilityZone, error) - CreateTags(*ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) - DeleteTags(input *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) + CreateTags(ctx context.Context, request *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) + DeleteTags(ctx context.Context, input *ec2.DeleteTagsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTagsOutput, error) - DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ([]*ec2.RouteTable, error) - CreateRoute(request *ec2.CreateRouteInput) (*ec2.CreateRouteOutput, error) - DeleteRoute(request *ec2.DeleteRouteInput) (*ec2.DeleteRouteOutput, error) + DescribeRouteTables(ctx context.Context, request *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) ([]ec2types.RouteTable, error) + CreateRoute(ctx context.Context, request *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) + DeleteRoute(ctx context.Context, request *ec2.DeleteRouteInput, optFns ...func(*ec2.Options)) (*ec2.DeleteRouteOutput, error) - ModifyInstanceAttribute(request *ec2.ModifyInstanceAttributeInput) (*ec2.ModifyInstanceAttributeOutput, error) + ModifyInstanceAttribute(ctx context.Context, request *ec2.ModifyInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyInstanceAttributeOutput, error) - DescribeVpcs(input *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) + DescribeVpcs(ctx context.Context, input *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) - DescribeNetworkInterfaces(input *ec2.DescribeNetworkInterfacesInput) (*ec2.DescribeNetworkInterfacesOutput, error) + DescribeNetworkInterfaces(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) } // ELB is a simple pass-through of AWS' ELB client interface, which allows for testing type ELB interface { - CreateLoadBalancer(*elb.CreateLoadBalancerInput) (*elb.CreateLoadBalancerOutput, error) - DeleteLoadBalancer(*elb.DeleteLoadBalancerInput) (*elb.DeleteLoadBalancerOutput, error) - DescribeLoadBalancers(*elb.DescribeLoadBalancersInput) (*elb.DescribeLoadBalancersOutput, error) - AddTags(*elb.AddTagsInput) (*elb.AddTagsOutput, error) - RegisterInstancesWithLoadBalancer(*elb.RegisterInstancesWithLoadBalancerInput) (*elb.RegisterInstancesWithLoadBalancerOutput, error) - DeregisterInstancesFromLoadBalancer(*elb.DeregisterInstancesFromLoadBalancerInput) (*elb.DeregisterInstancesFromLoadBalancerOutput, error) - CreateLoadBalancerPolicy(*elb.CreateLoadBalancerPolicyInput) (*elb.CreateLoadBalancerPolicyOutput, error) - SetLoadBalancerPoliciesForBackendServer(*elb.SetLoadBalancerPoliciesForBackendServerInput) (*elb.SetLoadBalancerPoliciesForBackendServerOutput, error) - SetLoadBalancerPoliciesOfListener(input *elb.SetLoadBalancerPoliciesOfListenerInput) (*elb.SetLoadBalancerPoliciesOfListenerOutput, error) - DescribeLoadBalancerPolicies(input *elb.DescribeLoadBalancerPoliciesInput) (*elb.DescribeLoadBalancerPoliciesOutput, error) + CreateLoadBalancer(ctx context.Context, input *elb.CreateLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.CreateLoadBalancerOutput, error) + DeleteLoadBalancer(ctx context.Context, input *elb.DeleteLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.DeleteLoadBalancerOutput, error) + DescribeLoadBalancers(ctx context.Context, input *elb.DescribeLoadBalancersInput, optFns ...func(*elb.Options)) (*elb.DescribeLoadBalancersOutput, error) + AddTags(ctx context.Context, input *elb.AddTagsInput, optFns ...func(*elb.Options)) (*elb.AddTagsOutput, error) + RegisterInstancesWithLoadBalancer(ctx context.Context, input *elb.RegisterInstancesWithLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.RegisterInstancesWithLoadBalancerOutput, error) + DeregisterInstancesFromLoadBalancer(ctx context.Context, input *elb.DeregisterInstancesFromLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.DeregisterInstancesFromLoadBalancerOutput, error) + CreateLoadBalancerPolicy(ctx context.Context, input *elb.CreateLoadBalancerPolicyInput, optFns ...func(*elb.Options)) (*elb.CreateLoadBalancerPolicyOutput, error) + SetLoadBalancerPoliciesForBackendServer(ctx context.Context, input *elb.SetLoadBalancerPoliciesForBackendServerInput, optFns ...func(*elb.Options)) (*elb.SetLoadBalancerPoliciesForBackendServerOutput, error) + SetLoadBalancerPoliciesOfListener(ctx context.Context, input *elb.SetLoadBalancerPoliciesOfListenerInput, optFns ...func(*elb.Options)) (*elb.SetLoadBalancerPoliciesOfListenerOutput, error) + DescribeLoadBalancerPolicies(ctx context.Context, input *elb.DescribeLoadBalancerPoliciesInput, optFns ...func(*elb.Options)) (*elb.DescribeLoadBalancerPoliciesOutput, error) - DetachLoadBalancerFromSubnets(*elb.DetachLoadBalancerFromSubnetsInput) (*elb.DetachLoadBalancerFromSubnetsOutput, error) - AttachLoadBalancerToSubnets(*elb.AttachLoadBalancerToSubnetsInput) (*elb.AttachLoadBalancerToSubnetsOutput, error) + DetachLoadBalancerFromSubnets(ctx context.Context, input *elb.DetachLoadBalancerFromSubnetsInput, optFns ...func(*elb.Options)) (*elb.DetachLoadBalancerFromSubnetsOutput, error) + AttachLoadBalancerToSubnets(ctx context.Context, input *elb.AttachLoadBalancerToSubnetsInput, optFns ...func(*elb.Options)) (*elb.AttachLoadBalancerToSubnetsOutput, error) - CreateLoadBalancerListeners(*elb.CreateLoadBalancerListenersInput) (*elb.CreateLoadBalancerListenersOutput, error) - DeleteLoadBalancerListeners(*elb.DeleteLoadBalancerListenersInput) (*elb.DeleteLoadBalancerListenersOutput, error) + CreateLoadBalancerListeners(ctx context.Context, input *elb.CreateLoadBalancerListenersInput, optFns ...func(*elb.Options)) (*elb.CreateLoadBalancerListenersOutput, error) + DeleteLoadBalancerListeners(ctx context.Context, input *elb.DeleteLoadBalancerListenersInput, optFns ...func(*elb.Options)) (*elb.DeleteLoadBalancerListenersOutput, error) - ApplySecurityGroupsToLoadBalancer(*elb.ApplySecurityGroupsToLoadBalancerInput) (*elb.ApplySecurityGroupsToLoadBalancerOutput, error) + ApplySecurityGroupsToLoadBalancer(ctx context.Context, input *elb.ApplySecurityGroupsToLoadBalancerInput, optFns ...func(*elb.Options)) (*elb.ApplySecurityGroupsToLoadBalancerOutput, error) - ConfigureHealthCheck(*elb.ConfigureHealthCheckInput) (*elb.ConfigureHealthCheckOutput, error) + ConfigureHealthCheck(ctx context.Context, input *elb.ConfigureHealthCheckInput, optFns ...func(*elb.Options)) (*elb.ConfigureHealthCheckOutput, error) - DescribeLoadBalancerAttributes(*elb.DescribeLoadBalancerAttributesInput) (*elb.DescribeLoadBalancerAttributesOutput, error) - ModifyLoadBalancerAttributes(*elb.ModifyLoadBalancerAttributesInput) (*elb.ModifyLoadBalancerAttributesOutput, error) + DescribeLoadBalancerAttributes(ctx context.Context, input *elb.DescribeLoadBalancerAttributesInput, optFns ...func(*elb.Options)) (*elb.DescribeLoadBalancerAttributesOutput, error) + ModifyLoadBalancerAttributes(ctx context.Context, input *elb.ModifyLoadBalancerAttributesInput, optFns ...func(*elb.Options)) (*elb.ModifyLoadBalancerAttributesOutput, error) } // ELBV2 is a simple pass-through of AWS' ELBV2 client interface, which allows for testing type ELBV2 interface { - AddTags(input *elbv2.AddTagsInput) (*elbv2.AddTagsOutput, error) + AddTags(ctx context.Context, input *elbv2.AddTagsInput, optFns ...func(*elbv2.Options)) (*elbv2.AddTagsOutput, error) - CreateLoadBalancer(*elbv2.CreateLoadBalancerInput) (*elbv2.CreateLoadBalancerOutput, error) - DescribeLoadBalancers(*elbv2.DescribeLoadBalancersInput) (*elbv2.DescribeLoadBalancersOutput, error) - DeleteLoadBalancer(*elbv2.DeleteLoadBalancerInput) (*elbv2.DeleteLoadBalancerOutput, error) + CreateLoadBalancer(ctx context.Context, input *elbv2.CreateLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateLoadBalancerOutput, error) + DescribeLoadBalancers(ctx context.Context, input *elbv2.DescribeLoadBalancersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancersOutput, error) + DeleteLoadBalancer(ctx context.Context, input *elbv2.DeleteLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteLoadBalancerOutput, error) - ModifyLoadBalancerAttributes(*elbv2.ModifyLoadBalancerAttributesInput) (*elbv2.ModifyLoadBalancerAttributesOutput, error) - DescribeLoadBalancerAttributes(*elbv2.DescribeLoadBalancerAttributesInput) (*elbv2.DescribeLoadBalancerAttributesOutput, error) + ModifyLoadBalancerAttributes(ctx context.Context, input *elbv2.ModifyLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyLoadBalancerAttributesOutput, error) + DescribeLoadBalancerAttributes(ctx context.Context, input *elbv2.DescribeLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancerAttributesOutput, error) - CreateTargetGroup(*elbv2.CreateTargetGroupInput) (*elbv2.CreateTargetGroupOutput, error) - DescribeTargetGroups(*elbv2.DescribeTargetGroupsInput) (*elbv2.DescribeTargetGroupsOutput, error) - ModifyTargetGroup(*elbv2.ModifyTargetGroupInput) (*elbv2.ModifyTargetGroupOutput, error) - DeleteTargetGroup(*elbv2.DeleteTargetGroupInput) (*elbv2.DeleteTargetGroupOutput, error) + CreateTargetGroup(ctx context.Context, input *elbv2.CreateTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateTargetGroupOutput, error) + DescribeTargetGroups(ctx context.Context, input *elbv2.DescribeTargetGroupsInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupsOutput, error) + ModifyTargetGroup(ctx context.Context, input *elbv2.ModifyTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupOutput, error) + DeleteTargetGroup(ctx context.Context, input *elbv2.DeleteTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteTargetGroupOutput, error) - DescribeTargetHealth(input *elbv2.DescribeTargetHealthInput) (*elbv2.DescribeTargetHealthOutput, error) + DescribeTargetHealth(ctx context.Context, input *elbv2.DescribeTargetHealthInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) - DescribeTargetGroupAttributes(*elbv2.DescribeTargetGroupAttributesInput) (*elbv2.DescribeTargetGroupAttributesOutput, error) - ModifyTargetGroupAttributes(*elbv2.ModifyTargetGroupAttributesInput) (*elbv2.ModifyTargetGroupAttributesOutput, error) + DescribeTargetGroupAttributes(ctx context.Context, input *elbv2.DescribeTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupAttributesOutput, error) + ModifyTargetGroupAttributes(ctx context.Context, input *elbv2.ModifyTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupAttributesOutput, error) - RegisterTargets(*elbv2.RegisterTargetsInput) (*elbv2.RegisterTargetsOutput, error) - DeregisterTargets(*elbv2.DeregisterTargetsInput) (*elbv2.DeregisterTargetsOutput, error) + RegisterTargets(ctx context.Context, input *elbv2.RegisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) + DeregisterTargets(ctx context.Context, input *elbv2.DeregisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) - CreateListener(*elbv2.CreateListenerInput) (*elbv2.CreateListenerOutput, error) - DescribeListeners(*elbv2.DescribeListenersInput) (*elbv2.DescribeListenersOutput, error) - DeleteListener(*elbv2.DeleteListenerInput) (*elbv2.DeleteListenerOutput, error) - ModifyListener(*elbv2.ModifyListenerInput) (*elbv2.ModifyListenerOutput, error) - - WaitUntilLoadBalancersDeleted(*elbv2.DescribeLoadBalancersInput) error + CreateListener(ctx context.Context, input *elbv2.CreateListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateListenerOutput, error) + DescribeListeners(ctx context.Context, input *elbv2.DescribeListenersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeListenersOutput, error) + DeleteListener(ctx context.Context, input *elbv2.DeleteListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteListenerOutput, error) + ModifyListener(ctx context.Context, input *elbv2.ModifyListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyListenerOutput, error) } // ASG is a simple pass-through of the Autoscaling client interface, which // allows for testing. type ASG interface { - UpdateAutoScalingGroup(*autoscaling.UpdateAutoScalingGroupInput) (*autoscaling.UpdateAutoScalingGroupOutput, error) - DescribeAutoScalingGroups(*autoscaling.DescribeAutoScalingGroupsInput) (*autoscaling.DescribeAutoScalingGroupsOutput, error) + UpdateAutoScalingGroup(ctx context.Context, input *autoscaling.UpdateAutoScalingGroupInput, optFns ...func(*autoscaling.Options)) (*autoscaling.UpdateAutoScalingGroupOutput, error) + DescribeAutoScalingGroups(ctx context.Context, input *autoscaling.DescribeAutoScalingGroupsInput, optFns ...func(*autoscaling.Options)) (*autoscaling.DescribeAutoScalingGroupsOutput, error) } // KMS is a simple pass-through of the Key Management Service client interface, // which allows for testing. type KMS interface { - DescribeKey(*kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) + DescribeKey(ctx context.Context, input *kms.DescribeKeyInput, optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) } // EC2Metadata is an abstraction over the AWS metadata service. type EC2Metadata interface { // Query the EC2 metadata service (used to discover instance-id etc) - GetMetadata(path string) (string, error) - Region() (string, error) + GetMetadata(ctx context.Context, params *imds.GetMetadataInput, optFns ...func(*imds.Options)) (*imds.GetMetadataOutput, error) + GetRegion(ctx context.Context, params *imds.GetRegionInput, optFns ...func(*imds.Options)) (*imds.GetRegionOutput, error) } // AWS volume types @@ -492,7 +518,7 @@ const ( type VolumeOptions struct { CapacityGB int Tags map[string]string - VolumeType string + VolumeType ec2types.VolumeType AvailabilityZone string // IOPSPerGB x CapacityGB will give total IOPS of the volume to create. // Calculated total IOPS will be capped at MaxTotalIOPS. @@ -509,43 +535,43 @@ type Volumes interface { // Attach the disk to the node with the specified NodeName // nodeName can be empty to mean "the instance on which we are running" // Returns the device (e.g. /dev/xvdf) where we attached the volume - AttachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) (string, error) + AttachDisk(ctx context.Context, diskName KubernetesVolumeID, nodeName types.NodeName) (string, error) // Detach the disk from the node with the specified NodeName // nodeName can be empty to mean "the instance on which we are running" // Returns the device where the volume was attached - DetachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) (string, error) + DetachDisk(ctx context.Context, diskName KubernetesVolumeID, nodeName types.NodeName) (string, error) // Create a volume with the specified options - CreateDisk(volumeOptions *VolumeOptions) (volumeName KubernetesVolumeID, err error) + CreateDisk(ctx context.Context, volumeOptions *VolumeOptions) (volumeName KubernetesVolumeID, err error) // Delete the specified volume // Returns true iff the volume was deleted // If the was not found, returns (false, nil) - DeleteDisk(volumeName KubernetesVolumeID) (bool, error) + DeleteDisk(ctx context.Context, volumeName KubernetesVolumeID) (bool, error) // Get labels to apply to volume on creation - GetVolumeLabels(volumeName KubernetesVolumeID) (map[string]string, error) + GetVolumeLabels(ctx context.Context, volumeName KubernetesVolumeID) (map[string]string, error) // Get volume's disk path from volume name // return the device path where the volume is attached - GetDiskPath(volumeName KubernetesVolumeID) (string, error) + GetDiskPath(ctx context.Context, volumeName KubernetesVolumeID) (string, error) // Check if the volume is already attached to the node with the specified NodeName - DiskIsAttached(diskName KubernetesVolumeID, nodeName types.NodeName) (bool, error) + DiskIsAttached(ctx context.Context, diskName KubernetesVolumeID, nodeName types.NodeName) (bool, error) // Check if disks specified in argument map are still attached to their respective nodes. - DisksAreAttached(map[types.NodeName][]KubernetesVolumeID) (map[types.NodeName]map[KubernetesVolumeID]bool, error) + DisksAreAttached(ctx context.Context, nodeVolumes map[types.NodeName][]KubernetesVolumeID) (map[types.NodeName]map[KubernetesVolumeID]bool, error) // Expand the disk to new size - ResizeDisk(diskName KubernetesVolumeID, oldSize resource.Quantity, newSize resource.Quantity) (resource.Quantity, error) + ResizeDisk(ctx context.Context, diskName KubernetesVolumeID, oldSize resource.Quantity, newSize resource.Quantity) (resource.Quantity, error) } // InstanceGroups is an interface for managing cloud-managed instance groups / autoscaling instance groups // TODO: Allow other clouds to implement this type InstanceGroups interface { // Set the size to the fixed size - ResizeInstanceGroup(instanceGroupName string, size int) error + ResizeInstanceGroup(ctx context.Context, instanceGroupName string, size int) error // Queries the cloud provider for information about the specified instance group - DescribeInstanceGroup(instanceGroupName string) (InstanceGroupInfo, error) + DescribeInstanceGroup(ctx context.Context, instanceGroupName string) (InstanceGroupInfo, error) } // InstanceGroupInfo is returned by InstanceGroups.Describe, and exposes information about the group. @@ -681,19 +707,19 @@ type CloudConfig struct { // GetRegion returns the AWS region from the config, if set, or gets it from the metadata // service if unset and sets in config -func (cfg *CloudConfig) GetRegion(metadata EC2Metadata) (string, error) { +func (cfg *CloudConfig) GetRegion(ctx context.Context, metadata EC2Metadata) (string, error) { if cfg.Global.Region != "" { return cfg.Global.Region, nil } klog.Info("Loading region from metadata service") - region, err := metadata.Region() + region, err := metadata.GetRegion(ctx, &imds.GetRegionInput{}) if err != nil { return "", err } - cfg.Global.Region = region - return region, nil + cfg.Global.Region = region.Region + return region.Region, nil } func (cfg *CloudConfig) validateOverrides() error { @@ -734,51 +760,35 @@ func (cfg *CloudConfig) validateOverrides() error { return nil } -func (cfg *CloudConfig) getResolver() endpoints.ResolverFunc { - defaultResolver := endpoints.DefaultResolver() - defaultResolverFn := func(service, region string, - optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - return defaultResolver.EndpointFor(service, region, optFns...) - } - if len(cfg.ServiceOverride) == 0 { - return defaultResolverFn - } - - return func(service, region string, - optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - for _, override := range cfg.ServiceOverride { - if override.Service == service && override.Region == region { - return endpoints.ResolvedEndpoint{ - URL: override.URL, - SigningRegion: override.SigningRegion, - SigningMethod: override.SigningMethod, - SigningName: override.SigningName, - }, nil - } - } - return defaultResolver.EndpointFor(service, region, optFns...) - } -} - // awsSdkEC2 is an implementation of the EC2 interface, backed by aws-sdk-go type awsSdkEC2 struct { - ec2 ec2iface.EC2API + ec2 EC2API } // Interface to make the CloudConfig immutable for awsSDKProvider type awsCloudConfigProvider interface { - getResolver() endpoints.ResolverFunc + GetEC2EndpointOpts(region string) []func(*ec2.Options) + GetCustomEC2Resolver() ec2.EndpointResolverV2 + GetELBEndpointOpts(region string) []func(*elb.Options) + GetCustomELBResolver() elb.EndpointResolverV2 + GetELBV2EndpointOpts(region string) []func(*elbv2.Options) + GetCustomELBV2Resolver() elbv2.EndpointResolverV2 + GetKMSEndpointOpts(region string) []func(*kms.Options) + GetCustomKMSResolver() kms.EndpointResolverV2 + GetIMDSEndpointOpts() []func(*imds.Options) + GetAutoscalingEndpointOpts(region string) []func(*autoscaling.Options) + GetCustomAutoscalingResolver() autoscaling.EndpointResolverV2 } type awsSDKProvider struct { - creds *credentials.Credentials + creds aws.CredentialsProvider cfg awsCloudConfigProvider mutex sync.Mutex regionDelayers map[string]*CrossRequestRetryDelay } -func newAWSSDKProvider(creds *credentials.Credentials, cfg *CloudConfig) *awsSDKProvider { +func newAWSSDKProvider(creds aws.CredentialsProvider, cfg *CloudConfig) *awsSDKProvider { return &awsSDKProvider{ creds: creds, cfg: cfg, @@ -786,45 +796,6 @@ func newAWSSDKProvider(creds *credentials.Credentials, cfg *CloudConfig) *awsSDK } } -func (p *awsSDKProvider) addHandlers(regionName string, h *request.Handlers) { - h.Build.PushFrontNamed(request.NamedHandler{ - Name: "k8s/user-agent", - Fn: request.MakeAddToUserAgentHandler("kubernetes", version.Get().String()), - }) - - h.Sign.PushFrontNamed(request.NamedHandler{ - Name: "k8s/logger", - Fn: awsHandlerLogger, - }) - - delayer := p.getCrossRequestRetryDelay(regionName) - if delayer != nil { - h.Sign.PushFrontNamed(request.NamedHandler{ - Name: "k8s/delay-presign", - Fn: delayer.BeforeSign, - }) - - h.AfterRetry.PushFrontNamed(request.NamedHandler{ - Name: "k8s/delay-afterretry", - Fn: delayer.AfterRetry, - }) - } - - p.addAPILoggingHandlers(h) -} - -func (p *awsSDKProvider) addAPILoggingHandlers(h *request.Handlers) { - h.Send.PushBackNamed(request.NamedHandler{ - Name: "k8s/api-request", - Fn: awsSendHandlerLogger, - }) - - h.ValidateResponse.PushFrontNamed(request.NamedHandler{ - Name: "k8s/api-validate-response", - Fn: awsValidateResponseHandlerLogger, - }) -} - // Get a CrossRequestRetryDelay, scoped to the region, not to the request. // This means that when we hit a limit on a call, we will delay _all_ calls to the API. // We do this to protect the AWS account from becoming overloaded and effectively locked. @@ -876,132 +847,158 @@ func (c *Cloud) SetInformers(informerFactory informers.SharedInformerFactory) { }) } -func (p *awsSDKProvider) Compute(regionName string) (EC2, error) { - awsConfig := &aws.Config{ - Region: ®ionName, - Credentials: p.creds, +func (p *awsSDKProvider) Compute(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (EC2, error) { + cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion), + awsConfig.WithRegion(regionName), + ) + if assumeRoleProvider != nil { + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - service := ec2.New(sess) - p.addHandlers(regionName, &service.Handlers) + p.AddMiddleware(ctx, regionName, &cfg) + var opts []func(*ec2.Options) = p.cfg.GetEC2EndpointOpts(regionName) + opts = append(opts, func(o *ec2.Options) { + o.Retryer = &customRetryer{ + retry.NewStandard(), + } + o.EndpointResolverV2 = p.cfg.GetCustomEC2Resolver() + }) + + ec2Client := ec2.NewFromConfig(cfg, opts...) ec2 := &awsSdkEC2{ - ec2: service, + ec2: ec2Client, } return ec2, nil } -func (p *awsSDKProvider) LoadBalancing(regionName string) (ELB, error) { - awsConfig := &aws.Config{ - Region: ®ionName, - Credentials: p.creds, +func (p *awsSDKProvider) LoadBalancing(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ELB, error) { + cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion), + awsConfig.WithRegion(regionName), + ) + if assumeRoleProvider != nil { + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - elbClient := elb.New(sess) - p.addHandlers(regionName, &elbClient.Handlers) + + p.AddMiddleware(ctx, regionName, &cfg) + var opts []func(*elb.Options) = p.cfg.GetELBEndpointOpts(regionName) + opts = append(opts, func(o *elb.Options) { + o.Retryer = &customRetryer{ + retry.NewStandard(), + } + o.EndpointResolverV2 = p.cfg.GetCustomELBResolver() + }) + + elbClient := elb.NewFromConfig(cfg, opts...) return elbClient, nil } -func (p *awsSDKProvider) LoadBalancingV2(regionName string) (ELBV2, error) { - awsConfig := &aws.Config{ - Region: ®ionName, - Credentials: p.creds, +func (p *awsSDKProvider) LoadBalancingV2(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ELBV2, error) { + cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion), + awsConfig.WithRegion(regionName), + ) + if assumeRoleProvider != nil { + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - elbClient := elbv2.New(sess) - p.addHandlers(regionName, &elbClient.Handlers) + p.AddMiddleware(ctx, regionName, &cfg) + var opts []func(*elbv2.Options) = p.cfg.GetELBV2EndpointOpts(regionName) + opts = append(opts, func(o *elbv2.Options) { + o.Retryer = &customRetryer{ + retry.NewStandard(), + } + o.EndpointResolverV2 = p.cfg.GetCustomELBV2Resolver() + }) - return elbClient, nil + elbv2Client := elbv2.NewFromConfig(cfg, opts...) + + return elbv2Client, nil } -func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { - awsConfig := &aws.Config{ - Region: ®ionName, - Credentials: p.creds, +func (p *awsSDKProvider) Autoscaling(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ASG, error) { + cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion), + awsConfig.WithRegion(regionName), + ) + if assumeRoleProvider != nil { + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - client := autoscaling.New(sess) - p.addHandlers(regionName, &client.Handlers) + p.AddMiddleware(ctx, regionName, &cfg) + var opts []func(*autoscaling.Options) = p.cfg.GetAutoscalingEndpointOpts(regionName) + opts = append(opts, func(o *autoscaling.Options) { + o.Retryer = &customRetryer{ + retry.NewStandard(), + } + o.EndpointResolverV2 = p.cfg.GetCustomAutoscalingResolver() + }) - return client, nil + autoscalingClient := autoscaling.NewFromConfig(cfg, opts...) + + return autoscalingClient, nil } -func (p *awsSDKProvider) Metadata() (EC2Metadata, error) { - sess, err := session.NewSession(&aws.Config{ - EndpointResolver: p.cfg.getResolver(), - }) +func (p *awsSDKProvider) Metadata(ctx context.Context) (EC2Metadata, error) { + cfg, err := awsConfig.LoadDefaultConfig(context.TODO(), awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion)) if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - client := ec2metadata.New(sess) - p.addAPILoggingHandlers(&client.Handlers) - return client, nil + + p.addAPILoggingMiddleware(&cfg) + + // Unlike other SDK clients, the IMDS client does not support signing, so any overrides of the signing region and name + // from awsSDKProvider.cfg will not be recognized. + // Standard SDK clients use SigV4: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html + // But IMDS uses a different request pattern: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html + var opts []func(*imds.Options) = p.cfg.GetIMDSEndpointOpts() + imdsClient := imds.NewFromConfig(cfg, opts...) + + return imdsClient, nil } -func (p *awsSDKProvider) KeyManagement(regionName string) (KMS, error) { - awsConfig := &aws.Config{ - Region: ®ionName, - Credentials: p.creds, +func (p *awsSDKProvider) KeyManagement(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (KMS, error) { + cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithDefaultsMode(aws.DefaultsModeInRegion), + awsConfig.WithRegion(regionName), + ) + if assumeRoleProvider != nil { + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + return nil, fmt.Errorf("unable to initialize AWS config: %v", err) } - kmsClient := kms.New(sess) - p.addHandlers(regionName, &kmsClient.Handlers) + p.AddMiddleware(ctx, regionName, &cfg) + var opts []func(*kms.Options) = p.cfg.GetKMSEndpointOpts(regionName) + opts = append(opts, func(o *kms.Options) { + o.Retryer = &customRetryer{ + retry.NewStandard(), + } + o.EndpointResolverV2 = p.cfg.GetCustomKMSResolver() + }) + + kmsClient := kms.NewFromConfig(cfg, opts...) return kmsClient, nil } -func newEc2Filter(name string, values ...string) *ec2.Filter { - filter := &ec2.Filter{ +func newEc2Filter(name string, values ...string) ec2types.Filter { + filter := ec2types.Filter{ Name: aws.String(name), } for _, value := range values { - filter.Values = append(filter.Values, aws.String(value)) + filter.Values = append(filter.Values, value) } return filter } @@ -1017,20 +1014,20 @@ func (c *Cloud) CurrentNodeName(ctx context.Context, hostname string) (types.Nod } // Implementation of EC2.Instances -func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) { +func (s *awsSdkEC2) DescribeInstances(ctx context.Context, request *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) ([]ec2types.Instance, error) { // Instances are paged - results := []*ec2.Instance{} + results := []ec2types.Instance{} var nextToken *string requestTime := time.Now() if request.MaxResults == nil && len(request.InstanceIds) == 0 { // MaxResults must be set in order for pagination to work // MaxResults cannot be set with InstanceIds - request.MaxResults = aws.Int64(1000) + request.MaxResults = aws.Int32(1000) } for { - response, err := s.ec2.DescribeInstances(request) + response, err := s.ec2.DescribeInstances(ctx, request, optFns...) if err != nil { recordAWSMetric("describe_instance", 0, err) return nil, fmt.Errorf("error listing AWS instances: %q", err) @@ -1041,7 +1038,7 @@ func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*e } nextToken = response.NextToken - if aws.StringValue(nextToken) == "" { + if aws.ToString(nextToken) == "" { break } request.NextToken = nextToken @@ -1052,22 +1049,22 @@ func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*e } // DescribeNetworkInterfaces describes network interface provided in the input. -func (s *awsSdkEC2) DescribeNetworkInterfaces(input *ec2.DescribeNetworkInterfacesInput) (*ec2.DescribeNetworkInterfacesOutput, error) { +func (s *awsSdkEC2) DescribeNetworkInterfaces(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) { requestTime := time.Now() - resp, err := s.ec2.DescribeNetworkInterfaces(input) + resp, err := s.ec2.DescribeNetworkInterfaces(ctx, input, optFns...) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("describe_network_interfaces", timeTaken, err) return resp, err } // Implements EC2.DescribeSecurityGroups -func (s *awsSdkEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { +func (s *awsSdkEC2) DescribeSecurityGroups(ctx context.Context, request *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) ([]ec2types.SecurityGroup, error) { // Security groups are paged - results := []*ec2.SecurityGroup{} + results := []ec2types.SecurityGroup{} var nextToken *string requestTime := time.Now() for { - response, err := s.ec2.DescribeSecurityGroups(request) + response, err := s.ec2.DescribeSecurityGroups(ctx, request, optFns...) if err != nil { recordAWSMetric("describe_security_groups", 0, err) return nil, fmt.Errorf("error listing AWS security groups: %q", err) @@ -1076,7 +1073,7 @@ func (s *awsSdkEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsIn results = append(results, response.SecurityGroups...) nextToken = response.NextToken - if aws.StringValue(nextToken) == "" { + if aws.ToString(nextToken) == "" { break } request.NextToken = nextToken @@ -1086,29 +1083,29 @@ func (s *awsSdkEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsIn return results, nil } -func (s *awsSdkEC2) AttachVolume(request *ec2.AttachVolumeInput) (*ec2.VolumeAttachment, error) { +func (s *awsSdkEC2) AttachVolume(ctx context.Context, request *ec2.AttachVolumeInput, optFns ...func(*ec2.Options)) (*ec2.AttachVolumeOutput, error) { requestTime := time.Now() - resp, err := s.ec2.AttachVolume(request) + resp, err := s.ec2.AttachVolume(ctx, request, optFns...) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("attach_volume", timeTaken, err) return resp, err } -func (s *awsSdkEC2) DetachVolume(request *ec2.DetachVolumeInput) (*ec2.VolumeAttachment, error) { +func (s *awsSdkEC2) DetachVolume(ctx context.Context, request *ec2.DetachVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DetachVolumeOutput, error) { requestTime := time.Now() - resp, err := s.ec2.DetachVolume(request) + resp, err := s.ec2.DetachVolume(ctx, request, optFns...) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("detach_volume", timeTaken, err) return resp, err } -func (s *awsSdkEC2) DescribeVolumes(request *ec2.DescribeVolumesInput) ([]*ec2.Volume, error) { +func (s *awsSdkEC2) DescribeVolumes(ctx context.Context, request *ec2.DescribeVolumesInput, optFns ...func(*ec2.Options)) ([]ec2types.Volume, error) { // Volumes are paged - results := []*ec2.Volume{} + results := []ec2types.Volume{} var nextToken *string requestTime := time.Now() for { - response, err := s.ec2.DescribeVolumes(request) + response, err := s.ec2.DescribeVolumes(ctx, request, optFns...) if err != nil { recordAWSMetric("describe_volume", 0, err) @@ -1118,7 +1115,7 @@ func (s *awsSdkEC2) DescribeVolumes(request *ec2.DescribeVolumesInput) ([]*ec2.V results = append(results, response.Volumes...) nextToken = response.NextToken - if aws.StringValue(nextToken) == "" { + if aws.ToString(nextToken) == "" { break } request.NextToken = nextToken @@ -1128,43 +1125,43 @@ func (s *awsSdkEC2) DescribeVolumes(request *ec2.DescribeVolumesInput) ([]*ec2.V return results, nil } -func (s *awsSdkEC2) CreateVolume(request *ec2.CreateVolumeInput) (*ec2.Volume, error) { +func (s *awsSdkEC2) CreateVolume(ctx context.Context, request *ec2.CreateVolumeInput, optFns ...func(*ec2.Options)) (*ec2.CreateVolumeOutput, error) { requestTime := time.Now() - resp, err := s.ec2.CreateVolume(request) + resp, err := s.ec2.CreateVolume(ctx, request, optFns...) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("create_volume", timeTaken, err) return resp, err } -func (s *awsSdkEC2) DeleteVolume(request *ec2.DeleteVolumeInput) (*ec2.DeleteVolumeOutput, error) { +func (s *awsSdkEC2) DeleteVolume(ctx context.Context, request *ec2.DeleteVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVolumeOutput, error) { requestTime := time.Now() - resp, err := s.ec2.DeleteVolume(request) + resp, err := s.ec2.DeleteVolume(ctx, request, optFns...) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("delete_volume", timeTaken, err) return resp, err } -func (s *awsSdkEC2) ModifyVolume(request *ec2.ModifyVolumeInput) (*ec2.ModifyVolumeOutput, error) { +func (s *awsSdkEC2) ModifyVolume(ctx context.Context, request *ec2.ModifyVolumeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyVolumeOutput, error) { requestTime := time.Now() - resp, err := s.ec2.ModifyVolume(request) + resp, err := s.ec2.ModifyVolume(ctx, request, optFns...) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("modify_volume", timeTaken, err) return resp, err } -func (s *awsSdkEC2) DescribeVolumeModifications(request *ec2.DescribeVolumesModificationsInput) ([]*ec2.VolumeModification, error) { +func (s *awsSdkEC2) DescribeVolumeModifications(ctx context.Context, request *ec2.DescribeVolumesModificationsInput, optFns ...func(*ec2.Options)) ([]ec2types.VolumeModification, error) { requestTime := time.Now() - results := []*ec2.VolumeModification{} + results := []ec2types.VolumeModification{} var nextToken *string for { - resp, err := s.ec2.DescribeVolumesModifications(request) + resp, err := s.ec2.DescribeVolumesModifications(ctx, request, optFns...) if err != nil { recordAWSMetric("describe_volume_modification", 0, err) return nil, fmt.Errorf("error listing volume modifictions : %v", err) } results = append(results, resp.VolumesModifications...) nextToken = resp.NextToken - if aws.StringValue(nextToken) == "" { + if aws.ToString(nextToken) == "" { break } request.NextToken = nextToken @@ -1174,62 +1171,62 @@ func (s *awsSdkEC2) DescribeVolumeModifications(request *ec2.DescribeVolumesModi return results, nil } -func (s *awsSdkEC2) DescribeSubnets(request *ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) { +func (s *awsSdkEC2) DescribeSubnets(ctx context.Context, request *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) ([]ec2types.Subnet, error) { // Subnets are not paged - response, err := s.ec2.DescribeSubnets(request) + response, err := s.ec2.DescribeSubnets(ctx, request, optFns...) if err != nil { return nil, fmt.Errorf("error listing AWS subnets: %q", err) } return response.Subnets, nil } -func (s *awsSdkEC2) DescribeAvailabilityZones(request *ec2.DescribeAvailabilityZonesInput) ([]*ec2.AvailabilityZone, error) { +func (s *awsSdkEC2) DescribeAvailabilityZones(ctx context.Context, request *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) ([]ec2types.AvailabilityZone, error) { // AZs are not paged - response, err := s.ec2.DescribeAvailabilityZones(request) + response, err := s.ec2.DescribeAvailabilityZones(ctx, request, optFns...) if err != nil { return nil, fmt.Errorf("error listing AWS availability zones: %q", err) } return response.AvailabilityZones, err } -func (s *awsSdkEC2) CreateSecurityGroup(request *ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) { - return s.ec2.CreateSecurityGroup(request) +func (s *awsSdkEC2) CreateSecurityGroup(ctx context.Context, request *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) { + return s.ec2.CreateSecurityGroup(ctx, request, optFns...) } -func (s *awsSdkEC2) DeleteSecurityGroup(request *ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) { - return s.ec2.DeleteSecurityGroup(request) +func (s *awsSdkEC2) DeleteSecurityGroup(ctx context.Context, request *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) { + return s.ec2.DeleteSecurityGroup(ctx, request, optFns...) } -func (s *awsSdkEC2) AuthorizeSecurityGroupIngress(request *ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { - return s.ec2.AuthorizeSecurityGroupIngress(request) +func (s *awsSdkEC2) AuthorizeSecurityGroupIngress(ctx context.Context, request *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { + return s.ec2.AuthorizeSecurityGroupIngress(ctx, request, optFns...) } -func (s *awsSdkEC2) RevokeSecurityGroupIngress(request *ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) { - return s.ec2.RevokeSecurityGroupIngress(request) +func (s *awsSdkEC2) RevokeSecurityGroupIngress(ctx context.Context, request *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) { + return s.ec2.RevokeSecurityGroupIngress(ctx, request, optFns...) } -func (s *awsSdkEC2) CreateTags(request *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) { +func (s *awsSdkEC2) CreateTags(ctx context.Context, request *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) { requestTime := time.Now() - resp, err := s.ec2.CreateTags(request) + resp, err := s.ec2.CreateTags(ctx, request, optFns...) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("create_tags", timeTaken, err) return resp, err } -func (s *awsSdkEC2) DeleteTags(request *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) { +func (s *awsSdkEC2) DeleteTags(ctx context.Context, request *ec2.DeleteTagsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTagsOutput, error) { requestTime := time.Now() - resp, err := s.ec2.DeleteTags(request) + resp, err := s.ec2.DeleteTags(ctx, request, optFns...) timeTaken := time.Since(requestTime).Seconds() recordAWSMetric("delete_tags", timeTaken, err) return resp, err } -func (s *awsSdkEC2) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ([]*ec2.RouteTable, error) { - results := []*ec2.RouteTable{} +func (s *awsSdkEC2) DescribeRouteTables(ctx context.Context, request *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) ([]ec2types.RouteTable, error) { + results := []ec2types.RouteTable{} var nextToken *string requestTime := time.Now() for { - response, err := s.ec2.DescribeRouteTables(request) + response, err := s.ec2.DescribeRouteTables(ctx, request, optFns...) if err != nil { recordAWSMetric("describe_route_tables", 0, err) return nil, fmt.Errorf("error listing AWS route tables: %q", err) @@ -1238,7 +1235,7 @@ func (s *awsSdkEC2) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ( results = append(results, response.RouteTables...) nextToken = response.NextToken - if aws.StringValue(nextToken) == "" { + if aws.ToString(nextToken) == "" { break } request.NextToken = nextToken @@ -1248,25 +1245,26 @@ func (s *awsSdkEC2) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ( return results, nil } -func (s *awsSdkEC2) CreateRoute(request *ec2.CreateRouteInput) (*ec2.CreateRouteOutput, error) { - return s.ec2.CreateRoute(request) +func (s *awsSdkEC2) CreateRoute(ctx context.Context, request *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) { + return s.ec2.CreateRoute(ctx, request, optFns...) } -func (s *awsSdkEC2) DeleteRoute(request *ec2.DeleteRouteInput) (*ec2.DeleteRouteOutput, error) { - return s.ec2.DeleteRoute(request) +func (s *awsSdkEC2) DeleteRoute(ctx context.Context, request *ec2.DeleteRouteInput, optFns ...func(*ec2.Options)) (*ec2.DeleteRouteOutput, error) { + return s.ec2.DeleteRoute(ctx, request, optFns...) } -func (s *awsSdkEC2) ModifyInstanceAttribute(request *ec2.ModifyInstanceAttributeInput) (*ec2.ModifyInstanceAttributeOutput, error) { - return s.ec2.ModifyInstanceAttribute(request) +func (s *awsSdkEC2) ModifyInstanceAttribute(ctx context.Context, request *ec2.ModifyInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyInstanceAttributeOutput, error) { + return s.ec2.ModifyInstanceAttribute(ctx, request, optFns...) } -func (s *awsSdkEC2) DescribeVpcs(request *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { - return s.ec2.DescribeVpcs(request) +func (s *awsSdkEC2) DescribeVpcs(ctx context.Context, request *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) { + return s.ec2.DescribeVpcs(ctx, request, optFns...) } func init() { registerMetrics() cloudprovider.RegisterCloudProvider(ProviderName, func(config io.Reader) (cloudprovider.Interface, error) { + ctx := context.Background() cfg, err := readAWSCloudConfig(config) if err != nil { return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) @@ -1276,65 +1274,30 @@ func init() { return nil, fmt.Errorf("unable to validate custom endpoint overrides: %v", err) } - metadata, err := newAWSSDKProvider(nil, cfg).Metadata() + metadata, err := newAWSSDKProvider(nil, cfg).Metadata(ctx) if err != nil { return nil, fmt.Errorf("error creating AWS metadata client: %q", err) } - regionName, err := getRegionFromMetadata(*cfg, metadata) + regionName, err := getRegionFromMetadata(ctx, *cfg, metadata) if err != nil { return nil, err } - sess, err := session.NewSessionWithOptions(session.Options{ - Config: *aws.NewConfig().WithRegion(regionName).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint), - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, fmt.Errorf("unable to initialize AWS session: %v", err) - } - - var creds *credentials.Credentials + var creds *stscreds.AssumeRoleProvider if cfg.Global.RoleARN != "" { - stsClient, err := getSTSClient(sess, cfg.Global.RoleARN, cfg.Global.SourceARN) + stsClient, err := services.NewStsClient(ctx, regionName, cfg.Global.RoleARN, cfg.Global.SourceARN) if err != nil { - return nil, fmt.Errorf("unable to create sts client, %v", err) + return nil, fmt.Errorf("unable to create sts v2 client: %v", err) } - creds = credentials.NewChainCredentials( - []credentials.Provider{ - &credentials.EnvProvider{}, - assumeRoleProvider(&stscreds.AssumeRoleProvider{ - Client: stsClient, - RoleARN: cfg.Global.RoleARN, - }), - }) + creds = stscreds.NewAssumeRoleProvider(stsClient, cfg.Global.RoleARN) } aws := newAWSSDKProvider(creds, cfg) - return newAWSCloud(*cfg, aws) + return newAWSCloud(*cfg, aws, creds) }) } -func getSTSClient(sess *session.Session, roleARN, sourceARN string) (*sts.STS, error) { - klog.Infof("Using AWS assumed role %v", roleARN) - stsClient := sts.New(sess) - sourceAcct, err := GetSourceAccount(roleARN) - if err != nil { - return nil, err - } - reqHeaders := map[string]string{ - headerSourceAccount: sourceAcct, - } - if sourceARN != "" { - reqHeaders[headerSourceArn] = sourceARN - } - stsClient.Handlers.Sign.PushFront(func(s *request.Request) { - s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) - }) - klog.V(4).Infof("configuring STS client with extra headers, %v", reqHeaders) - return stsClient, nil -} - // readAWSCloudConfig reads an instance of AWSCloudConfig from config reader. func readAWSCloudConfig(config io.Reader) (*CloudConfig, error) { var cfg CloudConfig @@ -1368,42 +1331,43 @@ func azToRegion(az string) (string, error) { // newAWSCloud creates a new instance of AWSCloud. // AWSProvider and instanceId are primarily for tests -func newAWSCloud(cfg CloudConfig, awsServices Services) (*Cloud, error) { +func newAWSCloud(cfg CloudConfig, awsServices Services, credentials *stscreds.AssumeRoleProvider) (*Cloud, error) { + ctx := context.Background() // We have some state in the Cloud object - in particular the attaching map // Log so that if we are building multiple Cloud objects, it is obvious! klog.Infof("Building AWS cloudprovider") - metadata, err := awsServices.Metadata() + metadata, err := awsServices.Metadata(ctx) if err != nil { return nil, fmt.Errorf("error creating AWS metadata client: %q", err) } - regionName, err := getRegionFromMetadata(cfg, metadata) + regionName, err := getRegionFromMetadata(ctx, cfg, metadata) if err != nil { return nil, err } - ec2, err := awsServices.Compute(regionName) + ec2, err := awsServices.Compute(ctx, regionName, credentials) if err != nil { return nil, fmt.Errorf("error creating AWS EC2 client: %v", err) } - elb, err := awsServices.LoadBalancing(regionName) + elb, err := awsServices.LoadBalancing(ctx, regionName, credentials) if err != nil { return nil, fmt.Errorf("error creating AWS ELB client: %v", err) } - elbv2, err := awsServices.LoadBalancingV2(regionName) + elbv2, err := awsServices.LoadBalancingV2(ctx, regionName, credentials) if err != nil { return nil, fmt.Errorf("error creating AWS ELBV2 client: %v", err) } - asg, err := awsServices.Autoscaling(regionName) + asg, err := awsServices.Autoscaling(ctx, regionName, credentials) if err != nil { return nil, fmt.Errorf("error creating AWS autoscaling client: %v", err) } - kms, err := awsServices.KeyManagement(regionName) + kms, err := awsServices.KeyManagement(ctx, regionName, credentials) if err != nil { return nil, fmt.Errorf("error creating AWS key management client: %v", err) } @@ -1435,7 +1399,7 @@ func newAWSCloud(cfg CloudConfig, awsServices Services) (*Cloud, error) { } awsCloud.vpcID = cfg.Global.VPC } else { - selfAWSInstance, err := awsCloud.buildSelfAWSInstance() + selfAWSInstance, err := awsCloud.buildSelfAWSInstance(ctx) if err != nil { return nil, err } @@ -1449,7 +1413,7 @@ func newAWSCloud(cfg CloudConfig, awsServices Services) (*Cloud, error) { } } else { // TODO: Clean up double-API query - info, err := awsCloud.selfAWSInstance.describeInstance() + info, err := awsCloud.selfAWSInstance.describeInstance(ctx) if err != nil { return nil, err } @@ -1467,8 +1431,8 @@ func newAWSCloud(cfg CloudConfig, awsServices Services) (*Cloud, error) { } // NewAWSCloud calls and return new aws cloud from newAWSCloud with the supplied configuration -func NewAWSCloud(cfg CloudConfig, awsServices Services) (*Cloud, error) { - return newAWSCloud(cfg, awsServices) +func NewAWSCloud(cfg CloudConfig, awsServices Services, credentials *stscreds.AssumeRoleProvider) (*Cloud, error) { + return newAWSCloud(cfg, awsServices, credentials) } // Initialize passes a Kubernetes clientBuilder interface to the cloud provider @@ -1533,7 +1497,7 @@ func (c *Cloud) NodeAddresses(ctx context.Context, name types.NodeName) ([]v1.No // extractIPv4NodeAddresses maps the instance information from EC2 to an array of NodeAddresses. // This function will extract private and public IP addresses and their corresponding DNS names. -func extractIPv4NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) { +func extractIPv4NodeAddresses(instance *ec2types.Instance) ([]v1.NodeAddress, error) { // Not clear if the order matters here, but we might as well indicate a sensible preference order if instance == nil { @@ -1552,21 +1516,21 @@ func extractIPv4NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) return true } - return aws.Int64Value(instance.NetworkInterfaces[i].Attachment.DeviceIndex) < aws.Int64Value(instance.NetworkInterfaces[j].Attachment.DeviceIndex) + return aws.ToInt32(instance.NetworkInterfaces[i].Attachment.DeviceIndex) < aws.ToInt32(instance.NetworkInterfaces[j].Attachment.DeviceIndex) }) // handle internal network interfaces for _, networkInterface := range instance.NetworkInterfaces { // skip network interfaces that are not currently in use - if aws.StringValue(networkInterface.Status) != ec2.NetworkInterfaceStatusInUse { + if networkInterface.Status != ec2types.NetworkInterfaceStatusInUse { continue } for _, internalIP := range networkInterface.PrivateIpAddresses { - if ipAddress := aws.StringValue(internalIP.PrivateIpAddress); ipAddress != "" { + if ipAddress := aws.ToString(internalIP.PrivateIpAddress); ipAddress != "" { ip := netutils.ParseIPSloppy(ipAddress) if ip == nil { - return nil, fmt.Errorf("EC2 instance had invalid private address: %s (%q)", aws.StringValue(instance.InstanceId), ipAddress) + return nil, fmt.Errorf("EC2 instance had invalid private address: %s (%q)", aws.ToString(instance.InstanceId), ipAddress) } addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()}) } @@ -1574,22 +1538,22 @@ func extractIPv4NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) } // TODO: Other IP addresses (multiple ips)? - publicIPAddress := aws.StringValue(instance.PublicIpAddress) + publicIPAddress := aws.ToString(instance.PublicIpAddress) if publicIPAddress != "" { ip := netutils.ParseIPSloppy(publicIPAddress) if ip == nil { - return nil, fmt.Errorf("EC2 instance had invalid public address: %s (%s)", aws.StringValue(instance.InstanceId), publicIPAddress) + return nil, fmt.Errorf("EC2 instance had invalid public address: %s (%s)", aws.ToString(instance.InstanceId), publicIPAddress) } addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()}) } - privateDNSName := aws.StringValue(instance.PrivateDnsName) + privateDNSName := aws.ToString(instance.PrivateDnsName) if privateDNSName != "" { addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalDNS, Address: privateDNSName}) addresses = append(addresses, v1.NodeAddress{Type: v1.NodeHostName, Address: privateDNSName}) } - publicDNSName := aws.StringValue(instance.PublicDnsName) + publicDNSName := aws.ToString(instance.PublicDnsName) if publicDNSName != "" { addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalDNS, Address: publicDNSName}) } @@ -1599,7 +1563,7 @@ func extractIPv4NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) // extractIPv6NodeAddresses maps the instance information from EC2 to an array of NodeAddresses // All IPv6 addresses are considered internal even if they are publicly routable. There are no instance DNS names associated with IPv6. -func extractIPv6NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) { +func extractIPv6NodeAddresses(instance *ec2types.Instance) ([]v1.NodeAddress, error) { // Not clear if the order matters here, but we might as well indicate a sensible preference order if instance == nil { @@ -1611,15 +1575,15 @@ func extractIPv6NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) // handle internal network interfaces with IPv6 addresses for _, networkInterface := range instance.NetworkInterfaces { // skip network interfaces that are not currently in use - if aws.StringValue(networkInterface.Status) != ec2.NetworkInterfaceStatusInUse || len(networkInterface.Ipv6Addresses) == 0 { + if networkInterface.Status != ec2types.NetworkInterfaceStatusInUse || len(networkInterface.Ipv6Addresses) == 0 { continue } // return only the "first" address for each ENI - internalIPv6 := aws.StringValue(networkInterface.Ipv6Addresses[0].Ipv6Address) + internalIPv6 := aws.ToString(networkInterface.Ipv6Addresses[0].Ipv6Address) ip := net.ParseIP(internalIPv6) if ip == nil { - return nil, fmt.Errorf("EC2 instance had invalid IPv6 address: %s (%q)", aws.StringValue(instance.InstanceId), internalIPv6) + return nil, fmt.Errorf("EC2 instance had invalid IPv6 address: %s (%q)", aws.ToString(instance.InstanceId), internalIPv6) } addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()}) } @@ -1647,7 +1611,7 @@ func (c *Cloud) NodeAddressesByProviderID(ctx context.Context, providerID string } if IsFargateNode(string(instanceID)) { - eni, err := c.describeNetworkInterfaces(string(instanceID)) + eni, err := c.describeNetworkInterfaces(ctx, string(instanceID)) if eni == nil || err != nil { return nil, err } @@ -1658,7 +1622,7 @@ func (c *Cloud) NodeAddressesByProviderID(ctx context.Context, providerID string for _, family := range c.cfg.Global.NodeIPFamilies { switch family { case "ipv4": - nodeAddresses := getNodeAddressesForFargateNode(aws.StringValue(eni.PrivateDnsName), aws.StringValue(eni.PrivateIpAddress)) + nodeAddresses := getNodeAddressesForFargateNode(aws.ToString(eni.PrivateDnsName), aws.ToString(eni.PrivateIpAddress)) addresses = append(addresses, nodeAddresses...) case "ipv6": if eni.Ipv6Addresses == nil || len(eni.Ipv6Addresses) == 0 { @@ -1666,14 +1630,14 @@ func (c *Cloud) NodeAddressesByProviderID(ctx context.Context, providerID string continue } internalIPv6Address := eni.Ipv6Addresses[0].Ipv6Address - nodeAddresses := getNodeAddressesForFargateNode(aws.StringValue(eni.PrivateDnsName), aws.StringValue(internalIPv6Address)) + nodeAddresses := getNodeAddressesForFargateNode(aws.ToString(eni.PrivateDnsName), aws.ToString(internalIPv6Address)) addresses = append(addresses, nodeAddresses...) } } return addresses, nil } - instance, err := describeInstance(c.ec2, instanceID) + instance, err := describeInstance(ctx, c.ec2, instanceID) if err != nil { return nil, err } @@ -1709,15 +1673,15 @@ func (c *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID strin } if IsFargateNode(string(instanceID)) { - eni, err := c.describeNetworkInterfaces(string(instanceID)) + eni, err := c.describeNetworkInterfaces(ctx, string(instanceID)) return eni != nil, err } request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []string{string(instanceID)}, } - instances, err := c.ec2.DescribeInstances(request) + instances, err := c.ec2.DescribeInstances(ctx, request) if err != nil { // if err is InstanceNotFound, return false with no error if IsAWSErrorInstanceNotFound(err) { @@ -1733,7 +1697,7 @@ func (c *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID strin } state := instances[0].State.Name - if *state == ec2.InstanceStateNameTerminated { + if state == ec2types.InstanceStateNameTerminated { klog.Warningf("the instance %s is terminated", instanceID) return false, nil } @@ -1749,15 +1713,15 @@ func (c *Cloud) InstanceShutdownByProviderID(ctx context.Context, providerID str } if IsFargateNode(string(instanceID)) { - eni, err := c.describeNetworkInterfaces(string(instanceID)) + eni, err := c.describeNetworkInterfaces(ctx, string(instanceID)) return eni != nil, err } request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []string{string(instanceID)}, } - instances, err := c.ec2.DescribeInstances(request) + instances, err := c.ec2.DescribeInstances(ctx, request) if err != nil { return false, err } @@ -1773,9 +1737,9 @@ func (c *Cloud) InstanceShutdownByProviderID(ctx context.Context, providerID str instance := instances[0] if instance.State != nil { - state := aws.StringValue(instance.State.Name) + state := instance.State.Name // valid state for detaching volumes - if state == ec2.InstanceStateNameStopped { + if state == ec2types.InstanceStateNameStopped { return true, nil } } @@ -1789,7 +1753,7 @@ func (c *Cloud) InstanceID(ctx context.Context, nodeName types.NodeName) (string if c.selfAWSInstance.nodeName == nodeName { return "/" + c.selfAWSInstance.availabilityZone + "/" + c.selfAWSInstance.awsID, nil } - inst, err := c.getInstanceByNodeName(nodeName) + inst, err := c.getInstanceByNodeName(ctx, nodeName) if err != nil { if err == cloudprovider.InstanceNotFound { // The Instances interface requires that we return InstanceNotFound (without wrapping) @@ -1797,7 +1761,7 @@ func (c *Cloud) InstanceID(ctx context.Context, nodeName types.NodeName) (string } return "", fmt.Errorf("getInstanceByNodeName failed for %q with %q", nodeName, err) } - return "/" + aws.StringValue(inst.Placement.AvailabilityZone) + "/" + aws.StringValue(inst.InstanceId), nil + return "/" + aws.ToString(inst.Placement.AvailabilityZone) + "/" + aws.ToString(inst.InstanceId), nil } // InstanceTypeByProviderID returns the cloudprovider instance type of the node with the specified unique providerID @@ -1813,12 +1777,12 @@ func (c *Cloud) InstanceTypeByProviderID(ctx context.Context, providerID string) return "", nil } - instance, err := describeInstance(c.ec2, instanceID) + instance, err := describeInstance(ctx, c.ec2, instanceID) if err != nil { return "", err } - return aws.StringValue(instance.InstanceType), nil + return string(instance.InstanceType), nil } // InstanceType returns the type of the node with the specified nodeName. @@ -1826,16 +1790,16 @@ func (c *Cloud) InstanceType(ctx context.Context, nodeName types.NodeName) (stri if c.selfAWSInstance.nodeName == nodeName { return c.selfAWSInstance.instanceType, nil } - inst, err := c.getInstanceByNodeName(nodeName) + inst, err := c.getInstanceByNodeName(ctx, nodeName) if err != nil { return "", fmt.Errorf("getInstanceByNodeName failed for %q with %q", nodeName, err) } - return aws.StringValue(inst.InstanceType), nil + return string(inst.InstanceType), nil } // GetCandidateZonesForDynamicVolume retrieves a list of all the zones in which nodes are running // It currently involves querying all instances -func (c *Cloud) GetCandidateZonesForDynamicVolume() (sets.String, error) { +func (c *Cloud) GetCandidateZonesForDynamicVolume(ctx context.Context) (sets.String, error) { // We don't currently cache this; it is currently used only in volume // creation which is expected to be a comparatively rare occurrence. @@ -1846,12 +1810,12 @@ func (c *Cloud) GetCandidateZonesForDynamicVolume() (sets.String, error) { // filters than to call it once with a tag filter that results in a logical // OR. For really large clusters the logical OR will result in EC2 API rate // limiting. - instances := []*ec2.Instance{} + instances := []*ec2types.Instance{} - baseFilters := []*ec2.Filter{newEc2Filter("instance-state-name", "running")} + baseFilters := []ec2types.Filter{newEc2Filter("instance-state-name", "running")} filters := c.tagging.addFilters(baseFilters) - di, err := c.describeInstances(filters) + di, err := c.describeInstances(ctx, filters) if err != nil { return nil, err } @@ -1860,7 +1824,7 @@ func (c *Cloud) GetCandidateZonesForDynamicVolume() (sets.String, error) { if c.tagging.usesLegacyTags { filters = c.tagging.addLegacyFilters(baseFilters) - di, err = c.describeInstances(filters) + di, err = c.describeInstances(ctx, filters) if err != nil { return nil, err } @@ -1880,19 +1844,19 @@ func (c *Cloud) GetCandidateZonesForDynamicVolume() (sets.String, error) { // This is a short-term workaround until the scheduler takes care of zone selection master := false for _, tag := range instance.Tags { - tagKey := aws.StringValue(tag.Key) + tagKey := aws.ToString(tag.Key) if awsTagNameMasterRoles.Has(tagKey) { master = true } } if master { - klog.V(4).Infof("Ignoring master instance %q in zone discovery", aws.StringValue(instance.InstanceId)) + klog.V(4).Infof("Ignoring master instance %q in zone discovery", aws.ToString(instance.InstanceId)) continue } if instance.Placement != nil { - zone := aws.StringValue(instance.Placement.AvailabilityZone) + zone := aws.ToString(instance.Placement.AvailabilityZone) zones.Insert(zone) } } @@ -1919,7 +1883,7 @@ func (c *Cloud) GetZoneByProviderID(ctx context.Context, providerID string) (clo } if IsFargateNode(string(instanceID)) { - eni, err := c.describeNetworkInterfaces(string(instanceID)) + eni, err := c.describeNetworkInterfaces(ctx, string(instanceID)) if eni == nil || err != nil { return cloudprovider.Zone{}, err } @@ -1929,7 +1893,7 @@ func (c *Cloud) GetZoneByProviderID(ctx context.Context, providerID string) (clo }, nil } - instance, err := c.getInstanceByID(string(instanceID)) + instance, err := c.getInstanceByID(ctx, string(instanceID)) if err != nil { return cloudprovider.Zone{}, err } @@ -1946,7 +1910,7 @@ func (c *Cloud) GetZoneByProviderID(ctx context.Context, providerID string) (clo // This is particularly useful in external cloud providers where the kubelet // does not initialize node data. func (c *Cloud) GetZoneByNodeName(ctx context.Context, nodeName types.NodeName) (cloudprovider.Zone, error) { - instance, err := c.getInstanceByNodeName(nodeName) + instance, err := c.getInstanceByNodeName(ctx, nodeName) if err != nil { return cloudprovider.Zone{}, err } @@ -1965,11 +1929,10 @@ func IsAWSErrorInstanceNotFound(err error) bool { return false } - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == ec2.UnsuccessfulInstanceCreditSpecificationErrorCodeInvalidInstanceIdNotFound { - return true - } - } else if strings.Contains(err.Error(), ec2.UnsuccessfulInstanceCreditSpecificationErrorCodeInvalidInstanceIdNotFound) { + var ae smithy.APIError + if errors.As(err, &ae) { + return ae.ErrorCode() == string(ec2types.UnsuccessfulInstanceCreditSpecificationErrorCodeInstanceNotFound) + } else if strings.Contains(err.Error(), string(ec2types.UnsuccessfulInstanceCreditSpecificationErrorCodeInstanceNotFound)) { // In places like https://github.com/kubernetes/cloud-provider-aws/blob/1c6194aad0122ab44504de64187e3d1a7415b198/pkg/providers/v1/aws.go#L1007, // the error has been transformed into something else so check the error string to see if it contains the error code we're looking for. return true @@ -2005,27 +1968,27 @@ type awsInstance struct { } // newAWSInstance creates a new awsInstance object -func newAWSInstance(ec2Service EC2, instance *ec2.Instance) *awsInstance { +func newAWSInstance(ec2Service EC2, instance *ec2types.Instance) *awsInstance { az := "" if instance.Placement != nil { - az = aws.StringValue(instance.Placement.AvailabilityZone) + az = aws.ToString(instance.Placement.AvailabilityZone) } self := &awsInstance{ ec2: ec2Service, - awsID: aws.StringValue(instance.InstanceId), + awsID: aws.ToString(instance.InstanceId), nodeName: mapInstanceToNodeName(instance), availabilityZone: az, - instanceType: aws.StringValue(instance.InstanceType), - vpcID: aws.StringValue(instance.VpcId), - subnetID: aws.StringValue(instance.SubnetId), + instanceType: string(instance.InstanceType), + vpcID: aws.ToString(instance.VpcId), + subnetID: aws.ToString(instance.SubnetId), } return self } // Gets the full information about this instance from the EC2 API -func (i *awsInstance) describeInstance() (*ec2.Instance, error) { - return describeInstance(i.ec2, InstanceID(i.awsID)) +func (i *awsInstance) describeInstance(ctx context.Context) (*ec2types.Instance, error) { + return describeInstance(ctx, i.ec2, InstanceID(i.awsID)) } // Gets the mountDevice already assigned to the volume, or assigns an unused mountDevice. @@ -2033,24 +1996,24 @@ func (i *awsInstance) describeInstance() (*ec2.Instance, error) { // Otherwise the mountDevice is assigned by finding the first available mountDevice, and it is returned with alreadyAttached=false. func (c *Cloud) getMountDevice( i *awsInstance, - info *ec2.Instance, + info *ec2types.Instance, volumeID EBSVolumeID, assign bool) (assigned mountDevice, alreadyAttached bool, err error) { deviceMappings := map[mountDevice]EBSVolumeID{} volumeStatus := map[EBSVolumeID]string{} // for better logging of volume status for _, blockDevice := range info.BlockDeviceMappings { - name := aws.StringValue(blockDevice.DeviceName) + name := aws.ToString(blockDevice.DeviceName) name = strings.TrimPrefix(name, "/dev/sd") name = strings.TrimPrefix(name, "/dev/xvd") if len(name) < 1 || len(name) > 2 { - klog.Warningf("Unexpected EBS DeviceName: %q", aws.StringValue(blockDevice.DeviceName)) + klog.Warningf("Unexpected EBS DeviceName: %q", aws.ToString(blockDevice.DeviceName)) } if blockDevice.Ebs != nil && blockDevice.Ebs.VolumeId != nil { - volumeStatus[EBSVolumeID(*blockDevice.Ebs.VolumeId)] = aws.StringValue(blockDevice.Ebs.Status) + volumeStatus[EBSVolumeID(*blockDevice.Ebs.VolumeId)] = string(blockDevice.Ebs.Status) } - deviceMappings[mountDevice(name)] = EBSVolumeID(aws.StringValue(blockDevice.Ebs.VolumeId)) + deviceMappings[mountDevice(name)] = EBSVolumeID(aws.ToString(blockDevice.Ebs.VolumeId)) } // We lock to prevent concurrent mounts from conflicting @@ -2159,9 +2122,10 @@ func newAWSDisk(aws *Cloud, name KubernetesVolumeID) (*awsDisk, error) { // and returns true in case the AWS error is "InvalidVolume.NotFound", false otherwise func isAWSErrorVolumeNotFound(err error) bool { if err != nil { - if awsError, ok := err.(awserr.Error); ok { + var ae smithy.APIError + if errors.As(err, &ae) { // https://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html - if awsError.Code() == "InvalidVolume.NotFound" { + if ae.ErrorCode() == "InvalidVolume.NotFound" { return true } } @@ -2170,32 +2134,32 @@ func isAWSErrorVolumeNotFound(err error) bool { } // Gets the full information about this volume from the EC2 API -func (d *awsDisk) describeVolume() (*ec2.Volume, error) { +func (d *awsDisk) describeVolume(ctx context.Context) (*ec2types.Volume, error) { volumeID := d.awsID request := &ec2.DescribeVolumesInput{ - VolumeIds: []*string{volumeID.awsString()}, + VolumeIds: []string{string(volumeID)}, } - volumes, err := d.ec2.DescribeVolumes(request) + volumes, err := d.ec2.DescribeVolumes(ctx, request) if err != nil { - return nil, err + return &ec2types.Volume{}, err } if len(volumes) == 0 { - return nil, fmt.Errorf("no volumes found") + return &ec2types.Volume{}, fmt.Errorf("no volumes found") } if len(volumes) > 1 { - return nil, fmt.Errorf("multiple volumes found") + return &ec2types.Volume{}, fmt.Errorf("multiple volumes found") } - return volumes[0], nil + return &volumes[0], nil } -func (d *awsDisk) describeVolumeModification() (*ec2.VolumeModification, error) { +func (d *awsDisk) describeVolumeModification(ctx context.Context) (*ec2types.VolumeModification, error) { volumeID := d.awsID request := &ec2.DescribeVolumesModificationsInput{ - VolumeIds: []*string{volumeID.awsString()}, + VolumeIds: []string{string(volumeID)}, } - volumeMods, err := d.ec2.DescribeVolumeModifications(request) + volumeMods, err := d.ec2.DescribeVolumeModifications(ctx, request) if err != nil { return nil, fmt.Errorf("error describing volume modification %s with %v", volumeID, err) @@ -2205,17 +2169,17 @@ func (d *awsDisk) describeVolumeModification() (*ec2.VolumeModification, error) return nil, fmt.Errorf("no volume modifications found for %s", volumeID) } lastIndex := len(volumeMods) - 1 - return volumeMods[lastIndex], nil + return &volumeMods[lastIndex], nil } -func (d *awsDisk) modifyVolume(requestGiB int64) (int64, error) { +func (d *awsDisk) modifyVolume(ctx context.Context, requestGiB int64) (int64, error) { volumeID := d.awsID request := &ec2.ModifyVolumeInput{ VolumeId: volumeID.awsString(), - Size: aws.Int64(requestGiB), + Size: aws.Int32(int32(requestGiB)), } - output, err := d.ec2.ModifyVolume(request) + output, err := d.ec2.ModifyVolume(ctx, request) if err != nil { modifyError := fmt.Errorf("AWS modifyVolume failed for %s with %v", volumeID, err) return requestGiB, modifyError @@ -2223,8 +2187,8 @@ func (d *awsDisk) modifyVolume(requestGiB int64) (int64, error) { volumeModification := output.VolumeModification - if aws.StringValue(volumeModification.ModificationState) == ec2.VolumeModificationStateCompleted { - return aws.Int64Value(volumeModification.TargetSize), nil + if volumeModification.ModificationState == ec2types.VolumeModificationStateCompleted { + return int64(aws.ToInt32(volumeModification.TargetSize)), nil } backoff := wait.Backoff{ @@ -2234,7 +2198,7 @@ func (d *awsDisk) modifyVolume(requestGiB int64) (int64, error) { } checkForResize := func() (bool, error) { - volumeModification, err := d.describeVolumeModification() + volumeModification, err := d.describeVolumeModification(ctx) if err != nil { return false, err @@ -2242,7 +2206,7 @@ func (d *awsDisk) modifyVolume(requestGiB int64) (int64, error) { // According to https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/monitoring_mods.html // Size changes usually take a few seconds to complete and take effect after a volume is in the Optimizing state. - if aws.StringValue(volumeModification.ModificationState) == ec2.VolumeModificationStateOptimizing { + if volumeModification.ModificationState == ec2types.VolumeModificationStateOptimizing { return true, nil } return false, nil @@ -2275,7 +2239,7 @@ func (c *Cloud) applyUnSchedulableTaint(nodeName types.NodeName, reason string) // waitForAttachmentStatus polls until the attachment status is the expected value // On success, it returns the last attachment state. -func (d *awsDisk) waitForAttachmentStatus(status string, expectedInstance, expectedDevice string, alreadyAttached bool) (*ec2.VolumeAttachment, error) { +func (d *awsDisk) waitForAttachmentStatus(ctx context.Context, status string, expectedInstance, expectedDevice string, alreadyAttached bool) (*ec2types.VolumeAttachment, error) { backoff := wait.Backoff{ Duration: volumeAttachmentStatusPollDelay, Factor: volumeAttachmentStatusFactor, @@ -2293,19 +2257,19 @@ func (d *awsDisk) waitForAttachmentStatus(status string, expectedInstance, expec // process the request. time.Sleep(getInitialAttachDetachDelay(status)) - var attachment *ec2.VolumeAttachment + var attachment *ec2types.VolumeAttachment err := wait.ExponentialBackoff(backoff, func() (bool, error) { - info, err := d.describeVolume() + info, err := d.describeVolume(ctx) if err != nil { // The VolumeNotFound error is special -- we don't need to wait for it to repeat if isAWSErrorVolumeNotFound(err) { if status == volumeDetachedStatus { // The disk doesn't exist, assume it's detached, log warning and stop waiting klog.Warningf("Waiting for volume %q to be detached but the volume does not exist", d.awsID) - stateStr := "detached" - attachment = &ec2.VolumeAttachment{ - State: &stateStr, + stateStr := ec2types.VolumeAttachmentStateDetached + attachment = &ec2types.VolumeAttachment{ + State: stateStr, } return true, nil } @@ -2335,13 +2299,8 @@ func (d *awsDisk) waitForAttachmentStatus(status string, expectedInstance, expec // Shouldn't happen; log so we know if it is klog.Warningf("Found multiple attachments for volume %q: %v", d.awsID, info) } - if a.State != nil { - attachment = a - attachmentStatus = *a.State - } else { - // Shouldn't happen; log so we know if it is - klog.Warningf("Ignoring nil attachment state for volume %q: %v", d.awsID, a) - } + attachment = &a + attachmentStatus = string(a.State) } if attachmentStatus == "" { attachmentStatus = volumeDetachedStatus @@ -2351,7 +2310,7 @@ func (d *awsDisk) waitForAttachmentStatus(status string, expectedInstance, expec // For example, we're waiting for a volume to be attached as /dev/xvdba, but AWS can tell us it's // attached as /dev/xvdbb, where it was attached before and it was already detached. // Retry couple of times, hoping AWS starts reporting the right status. - device := aws.StringValue(attachment.Device) + device := aws.ToString(attachment.Device) if expectedDevice != "" && device != "" && device != expectedDevice { klog.Warningf("Expected device %s %s for volume %s, but found device %s %s", expectedDevice, status, d.name, device, attachmentStatus) errorCount++ @@ -2361,7 +2320,7 @@ func (d *awsDisk) waitForAttachmentStatus(status string, expectedInstance, expec } return false, nil } - instanceID := aws.StringValue(attachment.InstanceId) + instanceID := aws.ToString(attachment.InstanceId) if expectedInstance != "" && instanceID != "" && instanceID != expectedInstance { klog.Warningf("Expected instance %s/%s for volume %s, but found instance %s/%s", expectedInstance, status, d.name, instanceID, attachmentStatus) errorCount++ @@ -2393,15 +2352,16 @@ func (d *awsDisk) waitForAttachmentStatus(status string, expectedInstance, expec } // Deletes the EBS disk -func (d *awsDisk) deleteVolume() (bool, error) { +func (d *awsDisk) deleteVolume(ctx context.Context) (bool, error) { request := &ec2.DeleteVolumeInput{VolumeId: d.awsID.awsString()} - _, err := d.ec2.DeleteVolume(request) + _, err := d.ec2.DeleteVolume(ctx, request) if err != nil { if isAWSErrorVolumeNotFound(err) { return false, nil } - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == "VolumeInUse" { + var ae smithy.APIError + if errors.As(err, &ae) { + if ae.ErrorCode() == "VolumeInUse" { return false, volerr.NewDeletedVolumeInUseError(err.Error()) } } @@ -2412,11 +2372,11 @@ func (d *awsDisk) deleteVolume() (bool, error) { // Builds the awsInstance for the EC2 instance on which we are running. // This is called when the AWSCloud is initialized, and should not be called otherwise (because the awsInstance for the local instance is a singleton with drive mapping state) -func (c *Cloud) buildSelfAWSInstance() (*awsInstance, error) { +func (c *Cloud) buildSelfAWSInstance(ctx context.Context) (*awsInstance, error) { if c.selfAWSInstance != nil { panic("do not call buildSelfAWSInstance directly") } - instanceID, err := c.metadata.GetMetadata("instance-id") + instanceIDMetadata, err := c.metadata.GetMetadata(ctx, &imds.GetMetadataInput{Path: "instance-id"}) if err != nil { return nil, fmt.Errorf("error fetching instance-id from ec2 metadata service: %q", err) } @@ -2429,27 +2389,34 @@ func (c *Cloud) buildSelfAWSInstance() (*awsInstance, error) { // information from the instance returned by the EC2 API - it is a // single API call to get all the information, and it means we don't // have two code paths. - instance, err := c.getInstanceByID(instanceID) + instanceIDBytes, err := io.ReadAll(instanceIDMetadata.Content) + if err != nil { + return nil, fmt.Errorf("unable to parse instance id: %q", err) + } + defer instanceIDMetadata.Content.Close() + + instance, err := c.getInstanceByID(ctx, string(instanceIDBytes)) if err != nil { - return nil, fmt.Errorf("error finding instance %s: %q", instanceID, err) + return nil, fmt.Errorf("error finding instance %s: %q", string(instanceIDBytes), err) } return newAWSInstance(c.ec2, instance), nil } // wrapAttachError wraps the error returned by an AttachVolume request with // additional information, if needed and possible. -func wrapAttachError(err error, disk *awsDisk, instance string) error { - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == "VolumeInUse" { - info, err := disk.describeVolume() +func wrapAttachError(ctx context.Context, err error, disk *awsDisk, instance string) error { + var ae smithy.APIError + if errors.As(err, &ae) { + if ae.ErrorCode() == "VolumeInUse" { + info, err := disk.describeVolume(ctx) if err != nil { klog.Errorf("Error describing volume %q: %q", disk.awsID, err) } else { for _, a := range info.Attachments { - if disk.awsID != EBSVolumeID(aws.StringValue(a.VolumeId)) { - klog.Warningf("Expected to get attachment info of volume %q but instead got info of %q", disk.awsID, aws.StringValue(a.VolumeId)) - } else if aws.StringValue(a.State) == "attached" { - return fmt.Errorf("error attaching EBS volume %q to instance %q: %q. The volume is currently attached to instance %q", disk.awsID, instance, awsError, aws.StringValue(a.InstanceId)) + if disk.awsID != EBSVolumeID(aws.ToString(a.VolumeId)) { + klog.Warningf("Expected to get attachment info of volume %q but instead got info of %q", disk.awsID, aws.ToString(a.VolumeId)) + } else if a.State == ec2types.VolumeAttachmentStateAttached { + return fmt.Errorf("error attaching EBS volume %q to instance %q: %q. The volume is currently attached to instance %q", disk.awsID, instance, ae, aws.ToString(a.InstanceId)) } } } @@ -2459,13 +2426,13 @@ func wrapAttachError(err error, disk *awsDisk, instance string) error { } // AttachDisk implements Volumes.AttachDisk -func (c *Cloud) AttachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) (string, error) { +func (c *Cloud) AttachDisk(ctx context.Context, diskName KubernetesVolumeID, nodeName types.NodeName) (string, error) { disk, err := newAWSDisk(c, diskName) if err != nil { return "", err } - awsInstance, info, err := c.getFullInstance(nodeName) + awsInstance, info, err := c.getFullInstance(ctx, nodeName) if err != nil { return "", fmt.Errorf("error finding instance %s: %q", nodeName, err) } @@ -2498,7 +2465,7 @@ func (c *Cloud) AttachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) ec2Device := "/dev/xvd" + string(mountDevice) if !alreadyAttached { - available, err := c.checkIfAvailable(disk, "attaching", awsInstance.awsID) + available, err := c.checkIfAvailable(ctx, disk, "attaching", awsInstance.awsID) if err != nil { klog.Error(err) } @@ -2513,11 +2480,11 @@ func (c *Cloud) AttachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) VolumeId: disk.awsID.awsString(), } - attachResponse, err := c.ec2.AttachVolume(request) + attachResponse, err := c.ec2.AttachVolume(ctx, request) if err != nil { attachEnded = true // TODO: Check if the volume was concurrently attached? - return "", wrapAttachError(err, disk, awsInstance.awsID) + return "", wrapAttachError(ctx, err, disk, awsInstance.awsID) } if da, ok := c.deviceAllocators[awsInstance.nodeName]; ok { da.Deprioritize(mountDevice) @@ -2525,7 +2492,7 @@ func (c *Cloud) AttachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) klog.V(2).Infof("AttachVolume volume=%q instance=%q request returned %v", disk.awsID, awsInstance.awsID, attachResponse) } - attachment, err := disk.waitForAttachmentStatus("attached", awsInstance.awsID, ec2Device, alreadyAttached) + attachment, err := disk.waitForAttachmentStatus(ctx, "attached", awsInstance.awsID, ec2Device, alreadyAttached) if err != nil { if err == wait.ErrWaitTimeout { @@ -2544,20 +2511,20 @@ func (c *Cloud) AttachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) // Impossible? return "", fmt.Errorf("unexpected state: attachment nil after attached %q to %q", diskName, nodeName) } - if ec2Device != aws.StringValue(attachment.Device) { + if ec2Device != aws.ToString(attachment.Device) { // Already checked in waitForAttachmentStatus(), but just to be sure... - return "", fmt.Errorf("disk attachment of %q to %q failed: requested device %q but found %q", diskName, nodeName, ec2Device, aws.StringValue(attachment.Device)) + return "", fmt.Errorf("disk attachment of %q to %q failed: requested device %q but found %q", diskName, nodeName, ec2Device, aws.ToString(attachment.Device)) } - if awsInstance.awsID != aws.StringValue(attachment.InstanceId) { - return "", fmt.Errorf("disk attachment of %q to %q failed: requested instance %q but found %q", diskName, nodeName, awsInstance.awsID, aws.StringValue(attachment.InstanceId)) + if awsInstance.awsID != aws.ToString(attachment.InstanceId) { + return "", fmt.Errorf("disk attachment of %q to %q failed: requested instance %q but found %q", diskName, nodeName, awsInstance.awsID, aws.ToString(attachment.InstanceId)) } return hostDevice, nil } // DetachDisk implements Volumes.DetachDisk -func (c *Cloud) DetachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) (string, error) { - diskInfo, attached, err := c.checkIfAttachedToNode(diskName, nodeName) +func (c *Cloud) DetachDisk(ctx context.Context, diskName KubernetesVolumeID, nodeName types.NodeName) (string, error) { + diskInfo, attached, err := c.checkIfAttachedToNode(ctx, diskName, nodeName) if err != nil { if isAWSErrorVolumeNotFound(err) { // Someone deleted the volume being detached; complain, but do nothing else and return success @@ -2594,7 +2561,7 @@ func (c *Cloud) DetachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) VolumeId: diskInfo.disk.awsID.awsString(), } - response, err := c.ec2.DetachVolume(&request) + response, err := c.ec2.DetachVolume(ctx, &request) if err != nil { return "", fmt.Errorf("error detaching EBS volume %q from %q: %q", diskInfo.disk.awsID, awsInstance.awsID, err) } @@ -2603,7 +2570,7 @@ func (c *Cloud) DetachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) return "", errors.New("no response from DetachVolume") } - attachment, err := diskInfo.disk.waitForAttachmentStatus("detached", awsInstance.awsID, "", false) + attachment, err := diskInfo.disk.waitForAttachmentStatus(ctx, "detached", awsInstance.awsID, "", false) if err != nil { return "", err } @@ -2626,9 +2593,9 @@ func (c *Cloud) DetachDisk(diskName KubernetesVolumeID, nodeName types.NodeName) } // CreateDisk implements Volumes.CreateDisk -func (c *Cloud) CreateDisk(volumeOptions *VolumeOptions) (KubernetesVolumeID, error) { - var createType string - var iops int64 +func (c *Cloud) CreateDisk(ctx context.Context, volumeOptions *VolumeOptions) (KubernetesVolumeID, error) { + var createType ec2types.VolumeType + var iops int32 switch volumeOptions.VolumeType { case VolumeTypeGP2, VolumeTypeSC1, VolumeTypeST1: createType = volumeOptions.VolumeType @@ -2638,7 +2605,7 @@ func (c *Cloud) CreateDisk(volumeOptions *VolumeOptions) (KubernetesVolumeID, er // for IOPS constraints. AWS will throw an error if IOPS per GB gets out // of supported bounds, no need to check it here. createType = volumeOptions.VolumeType - iops = int64(volumeOptions.CapacityGB * volumeOptions.IOPSPerGB) + iops = int32(volumeOptions.CapacityGB * volumeOptions.IOPSPerGB) // Cap at min/max total IOPS, AWS would throw an error if it gets too // low/high. @@ -2658,43 +2625,43 @@ func (c *Cloud) CreateDisk(volumeOptions *VolumeOptions) (KubernetesVolumeID, er request := &ec2.CreateVolumeInput{} request.AvailabilityZone = aws.String(volumeOptions.AvailabilityZone) - request.Size = aws.Int64(int64(volumeOptions.CapacityGB)) - request.VolumeType = aws.String(createType) + request.Size = aws.Int32(int32(volumeOptions.CapacityGB)) + request.VolumeType = createType request.Encrypted = aws.Bool(volumeOptions.Encrypted) if len(volumeOptions.KmsKeyID) > 0 { request.KmsKeyId = aws.String(volumeOptions.KmsKeyID) request.Encrypted = aws.Bool(true) } if iops > 0 { - request.Iops = aws.Int64(iops) + request.Iops = aws.Int32(iops) } tags := volumeOptions.Tags tags = c.tagging.buildTags(ResourceLifecycleOwned, tags) - var tagList []*ec2.Tag + var tagList []ec2types.Tag for k, v := range tags { - tagList = append(tagList, &ec2.Tag{ + tagList = append(tagList, ec2types.Tag{ Key: aws.String(k), Value: aws.String(v), }) } - request.TagSpecifications = append(request.TagSpecifications, &ec2.TagSpecification{ + request.TagSpecifications = append(request.TagSpecifications, ec2types.TagSpecification{ Tags: tagList, - ResourceType: aws.String(ec2.ResourceTypeVolume), + ResourceType: ec2types.ResourceTypeVolume, }) - response, err := c.ec2.CreateVolume(request) + response, err := c.ec2.CreateVolume(ctx, request) if err != nil { return KubernetesVolumeID(""), err } - awsID := EBSVolumeID(aws.StringValue(response.VolumeId)) + awsID := EBSVolumeID(aws.ToString(response.VolumeId)) if awsID == "" { return KubernetesVolumeID(""), fmt.Errorf("VolumeID was not returned by CreateVolume") } - volumeName := KubernetesVolumeID("aws://" + aws.StringValue(response.AvailabilityZone) + "/" + string(awsID)) + volumeName := KubernetesVolumeID("aws://" + aws.ToString(response.AvailabilityZone) + "/" + string(awsID)) - err = c.waitUntilVolumeAvailable(volumeName) + err = c.waitUntilVolumeAvailable(ctx, volumeName) if err != nil { // AWS has a bad habbit of reporting success when creating a volume with // encryption keys that either don't exists or have wrong permissions. @@ -2711,7 +2678,7 @@ func (c *Cloud) CreateDisk(volumeOptions *VolumeOptions) (KubernetesVolumeID, er if newDiskError != nil { klog.Errorf("Failed to delete the volume %v due to error: %v", volumeName, newDiskError) } else { - if _, deleteVolumeError := awsDisk.deleteVolume(); deleteVolumeError != nil { + if _, deleteVolumeError := awsDisk.deleteVolume(ctx); deleteVolumeError != nil { klog.Errorf("Failed to delete the volume %v due to error: %v", volumeName, deleteVolumeError) } else { klog.V(5).Infof("%v is deleted because it is not in desired state after waiting", volumeName) @@ -2724,7 +2691,7 @@ func (c *Cloud) CreateDisk(volumeOptions *VolumeOptions) (KubernetesVolumeID, er return volumeName, nil } -func (c *Cloud) waitUntilVolumeAvailable(volumeName KubernetesVolumeID) error { +func (c *Cloud) waitUntilVolumeAvailable(ctx context.Context, volumeName KubernetesVolumeID) error { disk, err := newAWSDisk(c, volumeName) if err != nil { // Unreachable code @@ -2737,33 +2704,30 @@ func (c *Cloud) waitUntilVolumeAvailable(volumeName KubernetesVolumeID) error { Steps: volumeCreateBackoffSteps, } err = wait.ExponentialBackoff(backoff, func() (done bool, err error) { - vol, err := disk.describeVolume() + vol, err := disk.describeVolume(ctx) if err != nil { return true, err } - if vol.State != nil { - switch *vol.State { - case "available": - // The volume is Available, it won't be deleted now. - return true, nil - case "creating": - return false, nil - default: - return true, fmt.Errorf("unexpected State of newly created AWS EBS volume %s: %q", volumeName, *vol.State) - } + switch vol.State { + case ec2types.VolumeStateAvailable: + // The volume is Available, it won't be deleted now. + return true, nil + case ec2types.VolumeStateCreating: + return false, nil + default: + return true, fmt.Errorf("unexpected State of newly created AWS EBS volume %s: %q", volumeName, vol.State) } - return false, nil }) return err } // DeleteDisk implements Volumes.DeleteDisk -func (c *Cloud) DeleteDisk(volumeName KubernetesVolumeID) (bool, error) { +func (c *Cloud) DeleteDisk(ctx context.Context, volumeName KubernetesVolumeID) (bool, error) { awsDisk, err := newAWSDisk(c, volumeName) if err != nil { return false, err } - available, err := c.checkIfAvailable(awsDisk, "deleting", "") + available, err := c.checkIfAvailable(ctx, awsDisk, "deleting", "") if err != nil { if isAWSErrorVolumeNotFound(err) { klog.V(2).Infof("Volume %s not found when deleting it, assuming it's deleted", awsDisk.awsID) @@ -2780,11 +2744,11 @@ func (c *Cloud) DeleteDisk(volumeName KubernetesVolumeID) (bool, error) { return false, err } - return awsDisk.deleteVolume() + return awsDisk.deleteVolume(ctx) } -func (c *Cloud) checkIfAvailable(disk *awsDisk, opName string, instance string) (bool, error) { - info, err := disk.describeVolume() +func (c *Cloud) checkIfAvailable(ctx context.Context, disk *awsDisk, opName string, instance string) (bool, error) { + info, err := disk.describeVolume(ctx) if err != nil { klog.Errorf("Error describing volume %q: %q", disk.awsID, err) @@ -2792,7 +2756,7 @@ func (c *Cloud) checkIfAvailable(disk *awsDisk, opName string, instance string) return false, err } - volumeState := aws.StringValue(info.State) + volumeState := string(info.State) opError := fmt.Sprintf("error %s EBS volume %q", opName, disk.awsID) if len(instance) != 0 { opError = fmt.Sprintf("%q to instance %q", opError, instance) @@ -2803,14 +2767,14 @@ func (c *Cloud) checkIfAvailable(disk *awsDisk, opName string, instance string) // Volume is attached somewhere else and we can not attach it here if len(info.Attachments) > 0 { attachment := info.Attachments[0] - instanceID := aws.StringValue(attachment.InstanceId) - attachedInstance, ierr := c.getInstanceByID(instanceID) + instanceID := aws.ToString(attachment.InstanceId) + attachedInstance, ierr := c.getInstanceByID(ctx, instanceID) attachErr := fmt.Sprintf("%s since volume is currently attached to %q", opError, instanceID) if ierr != nil { klog.Error(attachErr) return false, errors.New(attachErr) } - devicePath := aws.StringValue(attachment.Device) + devicePath := aws.ToString(attachment.Device) nodeName := mapInstanceToNodeName(attachedInstance) danglingErr := volerr.NewDanglingError(attachErr, nodeName, devicePath) @@ -2837,7 +2801,7 @@ func (c *Cloud) GetLabelsForVolume(ctx context.Context, pv *v1.PersistentVolume) } spec := KubernetesVolumeID(pv.Spec.AWSElasticBlockStore.VolumeID) - labels, err := c.GetVolumeLabels(spec) + labels, err := c.GetVolumeLabels(ctx, spec) if err != nil { return nil, err } @@ -2846,19 +2810,19 @@ func (c *Cloud) GetLabelsForVolume(ctx context.Context, pv *v1.PersistentVolume) } // GetVolumeLabels implements Volumes.GetVolumeLabels -func (c *Cloud) GetVolumeLabels(volumeName KubernetesVolumeID) (map[string]string, error) { +func (c *Cloud) GetVolumeLabels(ctx context.Context, volumeName KubernetesVolumeID) (map[string]string, error) { awsDisk, err := newAWSDisk(c, volumeName) if err != nil { return nil, err } - info, err := awsDisk.describeVolume() + info, err := awsDisk.describeVolume(ctx) if err != nil { return nil, err } labels := make(map[string]string) - az := aws.StringValue(info.AvailabilityZone) + az := aws.ToString(info.AvailabilityZone) if az == "" { - return nil, fmt.Errorf("volume did not have AZ information: %q", aws.StringValue(info.VolumeId)) + return nil, fmt.Errorf("volume did not have AZ information: %q", aws.ToString(info.VolumeId)) } labels[v1.LabelTopologyZone] = az @@ -2872,24 +2836,24 @@ func (c *Cloud) GetVolumeLabels(volumeName KubernetesVolumeID) (map[string]strin } // GetDiskPath implements Volumes.GetDiskPath -func (c *Cloud) GetDiskPath(volumeName KubernetesVolumeID) (string, error) { +func (c *Cloud) GetDiskPath(ctx context.Context, volumeName KubernetesVolumeID) (string, error) { awsDisk, err := newAWSDisk(c, volumeName) if err != nil { return "", err } - info, err := awsDisk.describeVolume() + info, err := awsDisk.describeVolume(ctx) if err != nil { return "", err } if len(info.Attachments) == 0 { return "", fmt.Errorf("No attachment to volume %s", volumeName) } - return aws.StringValue(info.Attachments[0].Device), nil + return aws.ToString(info.Attachments[0].Device), nil } // DiskIsAttached implements Volumes.DiskIsAttached -func (c *Cloud) DiskIsAttached(diskName KubernetesVolumeID, nodeName types.NodeName) (bool, error) { - _, attached, err := c.checkIfAttachedToNode(diskName, nodeName) +func (c *Cloud) DiskIsAttached(ctx context.Context, diskName KubernetesVolumeID, nodeName types.NodeName) (bool, error) { + _, attached, err := c.checkIfAttachedToNode(ctx, diskName, nodeName) if err != nil { if isAWSErrorVolumeNotFound(err) { // The disk doesn't exist, can't be attached @@ -2905,7 +2869,7 @@ func (c *Cloud) DiskIsAttached(diskName KubernetesVolumeID, nodeName types.NodeN // DisksAreAttached returns a map of nodes and Kubernetes volume IDs indicating // if the volumes are attached to the node -func (c *Cloud) DisksAreAttached(nodeDisks map[types.NodeName][]KubernetesVolumeID) (map[types.NodeName]map[KubernetesVolumeID]bool, error) { +func (c *Cloud) DisksAreAttached(ctx context.Context, nodeDisks map[types.NodeName][]KubernetesVolumeID) (map[types.NodeName]map[KubernetesVolumeID]bool, error) { attached := make(map[types.NodeName]map[KubernetesVolumeID]bool) if len(nodeDisks) == 0 { @@ -2922,7 +2886,7 @@ func (c *Cloud) DisksAreAttached(nodeDisks map[types.NodeName][]KubernetesVolume // Note that we get instances regardless of state. // This means there might be multiple nodes with the same node names. - awsInstances, err := c.getInstancesByNodeNames(nodeNames) + awsInstances, err := c.getInstancesByNodeNames(ctx, nodeNames) if err != nil { // When there is an error fetching instance information // it is safer to return nil and let volume information not be touched. @@ -2945,7 +2909,7 @@ func (c *Cloud) DisksAreAttached(nodeDisks map[types.NodeName][]KubernetesVolume awsInstanceState := "" if awsInstance != nil && awsInstance.State != nil { - awsInstanceState = aws.StringValue(awsInstance.State.Name) + awsInstanceState = string(awsInstance.State.Name) } if awsInstanceState == "terminated" { // Instance is terminated, safe to assume volumes not attached @@ -2963,7 +2927,7 @@ func (c *Cloud) DisksAreAttached(nodeDisks map[types.NodeName][]KubernetesVolume } for _, blockDevice := range awsInstance.BlockDeviceMappings { - volumeID := EBSVolumeID(aws.StringValue(blockDevice.Ebs.VolumeId)) + volumeID := EBSVolumeID(aws.ToString(blockDevice.Ebs.VolumeId)) diskName, found := idToDiskName[volumeID] if found { // Disk is still attached to node @@ -2977,7 +2941,7 @@ func (c *Cloud) DisksAreAttached(nodeDisks map[types.NodeName][]KubernetesVolume // ResizeDisk resizes an EBS volume in GiB increments, it will round up to the // next GiB if arguments are not provided in even GiB increments -func (c *Cloud) ResizeDisk( +func (c *Cloud) ResizeDisk(ctx context.Context, diskName KubernetesVolumeID, oldSize resource.Quantity, newSize resource.Quantity) (resource.Quantity, error) { @@ -2986,7 +2950,7 @@ func (c *Cloud) ResizeDisk( return oldSize, err } - volumeInfo, err := awsDisk.describeVolume() + volumeInfo, err := awsDisk.describeVolume(ctx) if err != nil { descErr := fmt.Errorf("AWS.ResizeDisk Error describing volume %s with %v", diskName, err) return oldSize, descErr @@ -2999,10 +2963,10 @@ func (c *Cloud) ResizeDisk( newSizeQuant := resource.MustParse(fmt.Sprintf("%dGi", requestGiB)) // If disk already if of greater or equal size than requested we return - if aws.Int64Value(volumeInfo.Size) >= requestGiB { + if aws.ToInt32(volumeInfo.Size) >= int32(requestGiB) { return newSizeQuant, nil } - _, err = awsDisk.modifyVolume(requestGiB) + _, err = awsDisk.modifyVolume(ctx, requestGiB) if err != nil { return oldSize, err @@ -3011,34 +2975,37 @@ func (c *Cloud) ResizeDisk( } // Gets the current load balancer state -func (c *Cloud) describeLoadBalancer(name string) (*elb.LoadBalancerDescription, error) { +func (c *Cloud) describeLoadBalancer(ctx context.Context, name string) (*elbtypes.LoadBalancerDescription, error) { request := &elb.DescribeLoadBalancersInput{} - request.LoadBalancerNames = []*string{&name} + request.LoadBalancerNames = []string{name} + + response, err := c.elb.DescribeLoadBalancers(ctx, request) - response, err := c.elb.DescribeLoadBalancers(request) if err != nil { - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == "LoadBalancerNotFound" { + var ae smithy.APIError + if errors.As(err, &ae) { + if ae.ErrorCode() == "LoadBalancerNotFound" { return nil, nil } } + return nil, err } - var ret *elb.LoadBalancerDescription + var ret *elbtypes.LoadBalancerDescription for _, loadBalancer := range response.LoadBalancerDescriptions { if ret != nil { klog.Errorf("Found multiple load balancers with name: %s", name) } - ret = loadBalancer + ret = &loadBalancer } return ret, nil } -func (c *Cloud) addLoadBalancerTags(loadBalancerName string, requested map[string]string) error { - var tags []*elb.Tag +func (c *Cloud) addLoadBalancerTags(ctx context.Context, loadBalancerName string, requested map[string]string) error { + var tags []elbtypes.Tag for k, v := range requested { - tag := &elb.Tag{ + tag := elbtypes.Tag{ Key: aws.String(k), Value: aws.String(v), } @@ -3046,10 +3013,10 @@ func (c *Cloud) addLoadBalancerTags(loadBalancerName string, requested map[strin } request := &elb.AddTagsInput{} - request.LoadBalancerNames = []*string{&loadBalancerName} + request.LoadBalancerNames = []string{loadBalancerName} request.Tags = tags - _, err := c.elb.AddTags(request) + _, err := c.elb.AddTags(ctx, request) if err != nil { return fmt.Errorf("error adding tags to load balancer: %v", err) } @@ -3057,25 +3024,24 @@ func (c *Cloud) addLoadBalancerTags(loadBalancerName string, requested map[strin } // Gets the current load balancer state -func (c *Cloud) describeLoadBalancerv2(name string) (*elbv2.LoadBalancer, error) { +func (c *Cloud) describeLoadBalancerv2(ctx context.Context, name string) (*elbv2types.LoadBalancer, error) { request := &elbv2.DescribeLoadBalancersInput{ - Names: []*string{aws.String(name)}, + Names: []string{name}, } - response, err := c.elbv2.DescribeLoadBalancers(request) + response, err := c.elbv2.DescribeLoadBalancers(ctx, request) if err != nil { - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == elbv2.ErrCodeLoadBalancerNotFoundException { - return nil, nil - } + var notFoundErr *elbv2types.LoadBalancerNotFoundException + if errors.As(err, ¬FoundErr) { + return nil, nil } return nil, fmt.Errorf("error describing load balancer: %q", err) } // AWS will not return 2 load balancers with the same name _and_ type. for i := range response.LoadBalancers { - if aws.StringValue(response.LoadBalancers[i].Type) == elbv2.LoadBalancerTypeEnumNetwork { - return response.LoadBalancers[i], nil + if response.LoadBalancers[i].Type == elbv2types.LoadBalancerTypeEnumNetwork { + return &response.LoadBalancers[i], nil } } @@ -3083,11 +3049,17 @@ func (c *Cloud) describeLoadBalancerv2(name string) (*elbv2.LoadBalancer, error) } // Retrieves instance's vpc id from metadata -func (c *Cloud) findVPCID() (string, error) { - macs, err := c.metadata.GetMetadata("network/interfaces/macs/") +func (c *Cloud) findVPCID(ctx context.Context) (string, error) { + macsMetadata, err := c.metadata.GetMetadata(ctx, &imds.GetMetadataInput{Path: "network/interfaces/macs/"}) if err != nil { return "", fmt.Errorf("could not list interfaces of the instance: %q", err) } + macsBytes, err := io.ReadAll(macsMetadata.Content) + if err != nil { + return "", fmt.Errorf("unable to parse macs: %q", err) + } + defer macsMetadata.Content.Close() + macs := string(macsBytes) // loop over interfaces, first vpc id returned wins for _, macPath := range strings.Split(macs, "\n") { @@ -3095,23 +3067,28 @@ func (c *Cloud) findVPCID() (string, error) { continue } url := fmt.Sprintf("network/interfaces/macs/%svpc-id", macPath) - vpcID, err := c.metadata.GetMetadata(url) + vpcIDMetadata, err := c.metadata.GetMetadata(ctx, &imds.GetMetadataInput{Path: url}) if err != nil { continue } - return vpcID, nil + vpcIDBytes, err := io.ReadAll(vpcIDMetadata.Content) + if err != nil { + continue + } + defer vpcIDMetadata.Content.Close() + return string(vpcIDBytes), nil } return "", fmt.Errorf("could not find VPC ID in instance metadata") } // Retrieves the specified security group from the AWS API, or returns nil if not found -func (c *Cloud) findSecurityGroup(securityGroupID string) (*ec2.SecurityGroup, error) { +func (c *Cloud) findSecurityGroup(ctx context.Context, securityGroupID string) (*ec2types.SecurityGroup, error) { describeSecurityGroupsRequest := &ec2.DescribeSecurityGroupsInput{ - GroupIds: []*string{&securityGroupID}, + GroupIds: []string{securityGroupID}, } // We don't apply our tag filters because we are retrieving by ID - groups, err := c.ec2.DescribeSecurityGroups(describeSecurityGroupsRequest) + groups, err := c.ec2.DescribeSecurityGroups(ctx, describeSecurityGroupsRequest) if err != nil { klog.Warningf("Error retrieving security group: %q", err) return nil, err @@ -3125,10 +3102,10 @@ func (c *Cloud) findSecurityGroup(securityGroupID string) (*ec2.SecurityGroup, e return nil, fmt.Errorf("multiple security groups found with same id %q", securityGroupID) } group := groups[0] - return group, nil + return &group, nil } -func isEqualIntPointer(l, r *int64) bool { +func isEqualIntPointer(l, r *int32) bool { if l == nil { return r == nil } @@ -3148,7 +3125,7 @@ func isEqualStringPointer(l, r *string) bool { return *l == *r } -func ipPermissionExists(newPermission, existing *ec2.IpPermission, compareGroupUserIDs bool) bool { +func ipPermissionExists(newPermission, existing *ec2types.IpPermission, compareGroupUserIDs bool) bool { if !isEqualIntPointer(newPermission.FromPort, existing.FromPort) { return false } @@ -3181,7 +3158,7 @@ func ipPermissionExists(newPermission, existing *ec2.IpPermission, compareGroupU for _, leftPair := range newPermission.UserIdGroupPairs { found := false for _, rightPair := range existing.UserIdGroupPairs { - if isEqualUserGroupPair(leftPair, rightPair, compareGroupUserIDs) { + if isEqualUserGroupPair(&leftPair, &rightPair, compareGroupUserIDs) { found = true break } @@ -3194,7 +3171,7 @@ func ipPermissionExists(newPermission, existing *ec2.IpPermission, compareGroupU return true } -func isEqualUserGroupPair(l, r *ec2.UserIdGroupPair, compareGroupUserIDs bool) bool { +func isEqualUserGroupPair(l, r *ec2types.UserIdGroupPair, compareGroupUserIDs bool) bool { klog.V(2).Infof("Comparing %v to %v", *l.GroupId, *r.GroupId) if isEqualStringPointer(l.GroupId, r.GroupId) { if compareGroupUserIDs { @@ -3212,8 +3189,8 @@ func isEqualUserGroupPair(l, r *ec2.UserIdGroupPair, compareGroupUserIDs bool) b // Makes sure the security group ingress is exactly the specified permissions // Returns true if and only if changes were made // The security group must already exist -func (c *Cloud) setSecurityGroupIngress(securityGroupID string, permissions IPPermissionSet) (bool, error) { - group, err := c.findSecurityGroup(securityGroupID) +func (c *Cloud) setSecurityGroupIngress(ctx context.Context, securityGroupID string, permissions IPPermissionSet) (bool, error) { + group, err := c.findSecurityGroup(ctx, securityGroupID) if err != nil { klog.Warningf("Error retrieving security group %q", err) return false, err @@ -3259,7 +3236,7 @@ func (c *Cloud) setSecurityGroupIngress(securityGroupID string, permissions IPPe request := &ec2.AuthorizeSecurityGroupIngressInput{} request.GroupId = &securityGroupID request.IpPermissions = add.List() - _, err = c.ec2.AuthorizeSecurityGroupIngress(request) + _, err = c.ec2.AuthorizeSecurityGroupIngress(ctx, request) if err != nil { return false, fmt.Errorf("error authorizing security group ingress: %q", err) } @@ -3270,7 +3247,7 @@ func (c *Cloud) setSecurityGroupIngress(securityGroupID string, permissions IPPe request := &ec2.RevokeSecurityGroupIngressInput{} request.GroupId = &securityGroupID request.IpPermissions = remove.List() - _, err = c.ec2.RevokeSecurityGroupIngress(request) + _, err = c.ec2.RevokeSecurityGroupIngress(ctx, request) if err != nil { return false, fmt.Errorf("error revoking security group ingress: %q", err) } @@ -3282,13 +3259,13 @@ func (c *Cloud) setSecurityGroupIngress(securityGroupID string, permissions IPPe // Makes sure the security group includes the specified permissions // Returns true if and only if changes were made // The security group must already exist -func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions []*ec2.IpPermission) (bool, error) { +func (c *Cloud) addSecurityGroupIngress(ctx context.Context, securityGroupID string, addPermissions []ec2types.IpPermission) (bool, error) { // We do not want to make changes to the Global defined SG if securityGroupID == c.cfg.Global.ElbSecurityGroup { return false, nil } - group, err := c.findSecurityGroup(securityGroupID) + group, err := c.findSecurityGroup(ctx, securityGroupID) if err != nil { klog.Warningf("Error retrieving security group: %q", err) return false, err @@ -3300,7 +3277,7 @@ func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions [ klog.V(2).Infof("Existing security group ingress: %s %v", securityGroupID, group.IpPermissions) - changes := []*ec2.IpPermission{} + changes := []ec2types.IpPermission{} for _, addPermission := range addPermissions { hasUserID := false for i := range addPermission.UserIdGroupPairs { @@ -3311,7 +3288,7 @@ func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions [ found := false for _, groupPermission := range group.IpPermissions { - if ipPermissionExists(addPermission, groupPermission, hasUserID) { + if ipPermissionExists(&addPermission, &groupPermission, hasUserID) { found = true break } @@ -3331,7 +3308,7 @@ func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions [ request := &ec2.AuthorizeSecurityGroupIngressInput{} request.GroupId = &securityGroupID request.IpPermissions = changes - _, err = c.ec2.AuthorizeSecurityGroupIngress(request) + _, err = c.ec2.AuthorizeSecurityGroupIngress(ctx, request) if err != nil { klog.Warningf("Error authorizing security group ingress %q", err) return false, fmt.Errorf("error authorizing security group ingress: %q", err) @@ -3343,13 +3320,13 @@ func (c *Cloud) addSecurityGroupIngress(securityGroupID string, addPermissions [ // Makes sure the security group no longer includes the specified permissions // Returns true if and only if changes were made // If the security group no longer exists, will return (false, nil) -func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermissions []*ec2.IpPermission) (bool, error) { +func (c *Cloud) removeSecurityGroupIngress(ctx context.Context, securityGroupID string, removePermissions []ec2types.IpPermission) (bool, error) { // We do not want to make changes to the Global defined SG if securityGroupID == c.cfg.Global.ElbSecurityGroup { return false, nil } - group, err := c.findSecurityGroup(securityGroupID) + group, err := c.findSecurityGroup(ctx, securityGroupID) if err != nil { klog.Warningf("Error retrieving security group: %q", err) return false, err @@ -3360,7 +3337,7 @@ func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermiss return false, nil } - changes := []*ec2.IpPermission{} + changes := []ec2types.IpPermission{} for _, removePermission := range removePermissions { hasUserID := false for i := range removePermission.UserIdGroupPairs { @@ -3369,16 +3346,16 @@ func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermiss } } - var found *ec2.IpPermission + var found *ec2types.IpPermission for _, groupPermission := range group.IpPermissions { - if ipPermissionExists(removePermission, groupPermission, hasUserID) { - found = removePermission + if ipPermissionExists(&removePermission, &groupPermission, hasUserID) { + found = &removePermission break } } if found != nil { - changes = append(changes, found) + changes = append(changes, *found) } } @@ -3391,7 +3368,7 @@ func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermiss request := &ec2.RevokeSecurityGroupIngressInput{} request.GroupId = &securityGroupID request.IpPermissions = changes - _, err = c.ec2.RevokeSecurityGroupIngress(request) + _, err = c.ec2.RevokeSecurityGroupIngress(ctx, request) if err != nil { klog.Warningf("Error revoking security group ingress: %q", err) return false, err @@ -3404,7 +3381,7 @@ func (c *Cloud) removeSecurityGroupIngress(securityGroupID string, removePermiss // For multi-cluster isolation, name must be globally unique, for example derived from the service UUID. // Additional tags can be specified // Returns the security group id or error -func (c *Cloud) ensureSecurityGroup(name string, description string, additionalTags map[string]string) (string, error) { +func (c *Cloud) ensureSecurityGroup(ctx context.Context, name string, description string, additionalTags map[string]string) (string, error) { groupID := "" attempt := 0 for { @@ -3416,12 +3393,12 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT // If it has a different cluster's tags, that is an error. // This shouldn't happen because name is expected to be globally unique (UUID derived) request := &ec2.DescribeSecurityGroupsInput{} - request.Filters = []*ec2.Filter{ + request.Filters = []ec2types.Filter{ newEc2Filter("group-name", name), newEc2Filter("vpc-id", c.vpcID), } - securityGroups, err := c.ec2.DescribeSecurityGroups(request) + securityGroups, err := c.ec2.DescribeSecurityGroups(ctx, request) if err != nil { return "", err } @@ -3430,14 +3407,14 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT if len(securityGroups) > 1 { klog.Warningf("Found multiple security groups with name: %q", name) } - err := c.tagging.readRepairClusterTags( - c.ec2, aws.StringValue(securityGroups[0].GroupId), + err := c.tagging.readRepairClusterTags(ctx, + c.ec2, aws.ToString(securityGroups[0].GroupId), ResourceLifecycleOwned, nil, securityGroups[0].Tags) if err != nil { return "", err } - return aws.StringValue(securityGroups[0].GroupId), nil + return aws.ToString(securityGroups[0].GroupId), nil } createRequest := &ec2.CreateSecurityGroupInput{} @@ -3445,27 +3422,27 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT createRequest.GroupName = &name createRequest.Description = &description tags := c.tagging.buildTags(ResourceLifecycleOwned, additionalTags) - var awsTags []*ec2.Tag + var awsTags []ec2types.Tag for k, v := range tags { - tag := &ec2.Tag{ + tag := ec2types.Tag{ Key: aws.String(k), Value: aws.String(v), } awsTags = append(awsTags, tag) } - createRequest.TagSpecifications = []*ec2.TagSpecification{ + createRequest.TagSpecifications = []ec2types.TagSpecification{ { - ResourceType: aws.String(ec2.ResourceTypeSecurityGroup), + ResourceType: ec2types.ResourceTypeSecurityGroup, Tags: awsTags, }, } - createResponse, err := c.ec2.CreateSecurityGroup(createRequest) + createResponse, err := c.ec2.CreateSecurityGroup(ctx, createRequest) if err != nil { ignore := false - switch err := err.(type) { - case awserr.Error: - if err.Code() == "InvalidGroup.Duplicate" && attempt < MaxReadThenCreateRetries { + var ae smithy.APIError + if errors.As(err, &ae) { + if ae.ErrorCode() == "InvalidGroup.Duplicate" && attempt < MaxReadThenCreateRetries { klog.V(2).Infof("Got InvalidGroup.Duplicate while creating security group (race?); will retry") ignore = true } @@ -3476,7 +3453,7 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT } time.Sleep(1 * time.Second) } else { - groupID = aws.StringValue(createResponse.GroupId) + groupID = aws.ToString(createResponse.GroupId) break } } @@ -3488,10 +3465,10 @@ func (c *Cloud) ensureSecurityGroup(name string, description string, additionalT } // Finds the value for a given tag. -func findTag(tags []*ec2.Tag, key string) (string, bool) { +func findTag(tags []ec2types.Tag, key string) (string, bool) { for _, tag := range tags { - if aws.StringValue(tag.Key) == key { - return aws.StringValue(tag.Value), true + if aws.ToString(tag.Key) == key { + return aws.ToString(tag.Value), true } } return "", false @@ -3500,16 +3477,16 @@ func findTag(tags []*ec2.Tag, key string) (string, bool) { // Finds the subnets associated with the cluster, by matching cluster tags if present. // For maximal backwards compatibility, if no subnets are tagged, it will fall-back to the current subnet. // However, in future this will likely be treated as an error. -func (c *Cloud) findSubnets() ([]*ec2.Subnet, error) { +func (c *Cloud) findSubnets(ctx context.Context) ([]ec2types.Subnet, error) { request := &ec2.DescribeSubnetsInput{} - request.Filters = []*ec2.Filter{newEc2Filter("vpc-id", c.vpcID)} + request.Filters = []ec2types.Filter{newEc2Filter("vpc-id", c.vpcID)} - subnets, err := c.ec2.DescribeSubnets(request) + subnets, err := c.ec2.DescribeSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error describing subnets: %q", err) } - var matches []*ec2.Subnet + var matches []ec2types.Subnet for _, subnet := range subnets { if c.tagging.hasClusterTag(subnet.Tags) { matches = append(matches, subnet) @@ -3526,9 +3503,9 @@ func (c *Cloud) findSubnets() ([]*ec2.Subnet, error) { klog.Warningf("No tagged subnets found; will fall-back to the current subnet only. This is likely to be an error in a future version of k8s.") request = &ec2.DescribeSubnetsInput{} - request.Filters = []*ec2.Filter{newEc2Filter("subnet-id", c.selfAWSInstance.subnetID)} + request.Filters = []ec2types.Filter{newEc2Filter("subnet-id", c.selfAWSInstance.subnetID)} - subnets, err = c.ec2.DescribeSubnets(request) + subnets, err = c.ec2.DescribeSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error describing subnets: %q", err) } @@ -3538,24 +3515,24 @@ func (c *Cloud) findSubnets() ([]*ec2.Subnet, error) { // Returns a mapping between availability zone names and their types // Zone will not be included in the map in case it was not found in AWS by name -func (c *Cloud) getZoneTypesByName(azNames []string) (map[string]string, error) { +func (c *Cloud) getZoneTypesByName(ctx context.Context, azNames []string) (map[string]string, error) { if len(azNames) == 0 { // if az names slice is empty, no need to make a request, return early with empty map return map[string]string{}, nil } azFilter := newEc2Filter("zone-name", azNames...) azRequest := &ec2.DescribeAvailabilityZonesInput{} - azRequest.Filters = []*ec2.Filter{azFilter} + azRequest.Filters = []ec2types.Filter{azFilter} - azs, err := c.ec2.DescribeAvailabilityZones(azRequest) + azs, err := c.ec2.DescribeAvailabilityZones(ctx, azRequest) if err != nil { return nil, fmt.Errorf("error describe availability zones: %q", err) } azTypesMapping := make(map[string]string) for _, az := range azs { - name := aws.StringValue(az.ZoneName) - zoneType := aws.StringValue(az.ZoneType) + name := aws.ToString(az.ZoneName) + zoneType := aws.ToString(az.ZoneType) if name == "" || zoneType == "" { klog.Warningf("Ignoring zone with empty name/type: %v", az) continue @@ -3568,25 +3545,25 @@ func (c *Cloud) getZoneTypesByName(azNames []string) (map[string]string, error) // Finds the subnets to use for an ELB we are creating. // Normal (Internet-facing) ELBs must use public subnets, so we skip private subnets. // Internal ELBs can use public or private subnets, but if we have a private subnet we should prefer that. -func (c *Cloud) findELBSubnets(internalELB bool) ([]string, error) { +func (c *Cloud) findELBSubnets(ctx context.Context, internalELB bool) ([]string, error) { vpcIDFilter := newEc2Filter("vpc-id", c.vpcID) - subnets, err := c.findSubnets() + subnets, err := c.findSubnets(ctx) if err != nil { return nil, err } rRequest := &ec2.DescribeRouteTablesInput{} - rRequest.Filters = []*ec2.Filter{vpcIDFilter} - rt, err := c.ec2.DescribeRouteTables(rRequest) + rRequest.Filters = []ec2types.Filter{vpcIDFilter} + rt, err := c.ec2.DescribeRouteTables(ctx, rRequest) if err != nil { return nil, fmt.Errorf("error describe route table: %q", err) } - subnetsByAZ := make(map[string]*ec2.Subnet) + subnetsByAZ := make(map[string]ec2types.Subnet) for _, subnet := range subnets { - az := aws.StringValue(subnet.AvailabilityZone) - id := aws.StringValue(subnet.SubnetId) + az := aws.ToString(subnet.AvailabilityZone) + id := aws.ToString(subnet.SubnetId) if az == "" || id == "" { klog.Warningf("Ignoring subnet with empty az/id: %v", subnet) continue @@ -3601,8 +3578,8 @@ func (c *Cloud) findELBSubnets(internalELB bool) ([]string, error) { continue } - existing := subnetsByAZ[az] - if existing == nil { + existing, exists := subnetsByAZ[az] + if !exists { subnetsByAZ[az] = subnet continue } @@ -3653,7 +3630,7 @@ func (c *Cloud) findELBSubnets(internalELB bool) ([]string, error) { sort.Strings(azNames) - azTypesMapping, err := c.getZoneTypesByName(azNames) + azTypesMapping, err := c.getZoneTypesByName(ctx, azNames) if err != nil { return nil, fmt.Errorf("error get availability zone types: %q", err) } @@ -3667,7 +3644,7 @@ func (c *Cloud) findELBSubnets(internalELB bool) ([]string, error) { // does not support NLB/CLB for the moment, only ALB. continue } - subnetIDs = append(subnetIDs, aws.StringValue(subnetsByAZ[key].SubnetId)) + subnetIDs = append(subnetIDs, aws.ToString(subnetsByAZ[key].SubnetId)) } return subnetIDs, nil @@ -3696,15 +3673,15 @@ func parseStringSliceAnnotation(annotations map[string]string, annotation string return true } -func (c *Cloud) getLoadBalancerSubnets(service *v1.Service, internalELB bool) ([]string, error) { +func (c *Cloud) getLoadBalancerSubnets(ctx context.Context, service *v1.Service, internalELB bool) ([]string, error) { var rawSubnetNameOrIDs []string if exists := parseStringSliceAnnotation(service.Annotations, ServiceAnnotationLoadBalancerSubnets, &rawSubnetNameOrIDs); exists { - return c.resolveSubnetNameOrIDs(rawSubnetNameOrIDs) + return c.resolveSubnetNameOrIDs(ctx, rawSubnetNameOrIDs) } - return c.findELBSubnets(internalELB) + return c.findELBSubnets(ctx, internalELB) } -func (c *Cloud) resolveSubnetNameOrIDs(subnetNameOrIDs []string) ([]string, error) { +func (c *Cloud) resolveSubnetNameOrIDs(ctx context.Context, subnetNameOrIDs []string) ([]string, error) { var subnetIDs []string var subnetNames []string if len(subnetNameOrIDs) == 0 { @@ -3717,12 +3694,12 @@ func (c *Cloud) resolveSubnetNameOrIDs(subnetNameOrIDs []string) ([]string, erro subnetNames = append(subnetNames, nameOrID) } } - var resolvedSubnets []*ec2.Subnet + var resolvedSubnets []ec2types.Subnet if len(subnetIDs) > 0 { req := &ec2.DescribeSubnetsInput{ - SubnetIds: aws.StringSlice(subnetIDs), + SubnetIds: subnetIDs, } - subnets, err := c.ec2.DescribeSubnets(req) + subnets, err := c.ec2.DescribeSubnets(ctx, req) if err != nil { return []string{}, err } @@ -3730,18 +3707,18 @@ func (c *Cloud) resolveSubnetNameOrIDs(subnetNameOrIDs []string) ([]string, erro } if len(subnetNames) > 0 { req := &ec2.DescribeSubnetsInput{ - Filters: []*ec2.Filter{ + Filters: []ec2types.Filter{ { Name: aws.String("tag:Name"), - Values: aws.StringSlice(subnetNames), + Values: subnetNames, }, { Name: aws.String("vpc-id"), - Values: aws.StringSlice([]string{c.vpcID}), + Values: []string{c.vpcID}, }, }, } - subnets, err := c.ec2.DescribeSubnets(req) + subnets, err := c.ec2.DescribeSubnets(ctx, req) if err != nil { return []string{}, err } @@ -3752,17 +3729,17 @@ func (c *Cloud) resolveSubnetNameOrIDs(subnetNameOrIDs []string) ([]string, erro } var subnets []string for _, subnet := range resolvedSubnets { - subnets = append(subnets, aws.StringValue(subnet.SubnetId)) + subnets = append(subnets, aws.ToString(subnet.SubnetId)) } return subnets, nil } -func isSubnetPublic(rt []*ec2.RouteTable, subnetID string) (bool, error) { - var subnetTable *ec2.RouteTable +func isSubnetPublic(rt []ec2types.RouteTable, subnetID string) (bool, error) { + var subnetTable *ec2types.RouteTable for _, table := range rt { for _, assoc := range table.Associations { - if aws.StringValue(assoc.SubnetId) == subnetID { - subnetTable = table + if aws.ToString(assoc.SubnetId) == subnetID { + subnetTable = &table break } } @@ -3773,10 +3750,10 @@ func isSubnetPublic(rt []*ec2.RouteTable, subnetID string) (bool, error) { // associated with the VPC's main routing table. for _, table := range rt { for _, assoc := range table.Associations { - if aws.BoolValue(assoc.Main) == true { + if aws.ToBool(assoc.Main) == true { klog.V(4).Infof("Assuming implicit use of main routing table %s for %s", - aws.StringValue(table.RouteTableId), subnetID) - subnetTable = table + aws.ToString(table.RouteTableId), subnetID) + subnetTable = &table break } } @@ -3794,7 +3771,7 @@ func isSubnetPublic(rt []*ec2.RouteTable, subnetID string) (bool, error) { // from the default in-subnet route which is called "local" // or other virtual gateway (starting with vgv) // or vpc peering connections (starting with pcx). - if strings.HasPrefix(aws.StringValue(route.GatewayId), "igw") { + if strings.HasPrefix(aws.ToString(route.GatewayId), "igw") { return true, nil } } @@ -3803,8 +3780,8 @@ func isSubnetPublic(rt []*ec2.RouteTable, subnetID string) (bool, error) { } type portSets struct { - names sets.String - numbers sets.Int64 + names sets.Set[string] + numbers sets.Set[int32] } // getPortSets returns a portSets structure representing port names and numbers @@ -3813,8 +3790,8 @@ type portSets struct { func getPortSets(annotation string) (ports *portSets) { if annotation != "" && annotation != "*" { ports = &portSets{ - sets.NewString(), - sets.NewInt64(), + sets.New[string](), + sets.New[int32](), } portStringSlice := strings.Split(annotation, ",") for _, item := range portStringSlice { @@ -3822,7 +3799,7 @@ func getPortSets(annotation string) (ports *portSets) { if err != nil { ports.names.Insert(item) } else { - ports.numbers.Insert(int64(port)) + ports.numbers.Insert(int32(port)) } } } @@ -3847,7 +3824,7 @@ func getSGListFromAnnotation(annotatedSG string) []string { // Extra groups can be specified via annotation, as can extra tags for any // new groups. The annotation "ServiceAnnotationLoadBalancerSecurityGroups" allows for // setting the security groups specified. -func (c *Cloud) buildELBSecurityGroupList(serviceName types.NamespacedName, loadBalancerName string, annotations map[string]string) ([]string, bool, error) { +func (c *Cloud) buildELBSecurityGroupList(ctx context.Context, serviceName types.NamespacedName, loadBalancerName string, annotations map[string]string) ([]string, bool, error) { var err error var securityGroupID string // We do not want to make changes to a Global defined SG @@ -3863,7 +3840,7 @@ func (c *Cloud) buildELBSecurityGroupList(serviceName types.NamespacedName, load // Create a security group for the load balancer sgName := "k8s-elb-" + loadBalancerName sgDescription := fmt.Sprintf("Security group for Kubernetes ELB %s (%v)", loadBalancerName, serviceName) - securityGroupID, err = c.ensureSecurityGroup(sgName, sgDescription, getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags)) + securityGroupID, err = c.ensureSecurityGroup(ctx, sgName, sgDescription, getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags)) if err != nil { klog.Errorf("Error creating load balancer security group: %q", err) return nil, setupSg, err @@ -3912,16 +3889,16 @@ func (c *Cloud) sortELBSecurityGroupList(securityGroupIDs []string, annotations // buildListener creates a new listener from the given port, adding an SSL certificate // if indicated by the appropriate annotations. -func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts *portSets) (*elb.Listener, error) { - loadBalancerPort := int64(port.Port) +func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts *portSets) (elbtypes.Listener, error) { + loadBalancerPort := port.Port portName := strings.ToLower(port.Name) - instancePort := int64(port.NodePort) + instancePort := port.NodePort protocol := strings.ToLower(string(port.Protocol)) instanceProtocol := protocol - listener := &elb.Listener{} + listener := elbtypes.Listener{} listener.InstancePort = &instancePort - listener.LoadBalancerPort = &loadBalancerPort + listener.LoadBalancerPort = loadBalancerPort certID := annotations[ServiceAnnotationLoadBalancerCertificate] if certID != "" && (sslPorts == nil || sslPorts.numbers.Has(loadBalancerPort) || sslPorts.names.Has(portName)) { instanceProtocol = annotations[ServiceAnnotationLoadBalancerBEProtocol] @@ -3931,7 +3908,7 @@ func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts } else { protocol = backendProtocolMapping[instanceProtocol] if protocol == "" { - return nil, fmt.Errorf("Invalid backend protocol %s for %s in %s", instanceProtocol, certID, ServiceAnnotationLoadBalancerBEProtocol) + return elbtypes.Listener{}, fmt.Errorf("Invalid backend protocol %s for %s in %s", instanceProtocol, certID, ServiceAnnotationLoadBalancerBEProtocol) } } listener.SSLCertificateId = &certID @@ -3946,13 +3923,13 @@ func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts return listener, nil } -func (c *Cloud) getSubnetCidrs(subnetIDs []string) ([]string, error) { +func (c *Cloud) getSubnetCidrs(ctx context.Context, subnetIDs []string) ([]string, error) { request := &ec2.DescribeSubnetsInput{} for _, subnetID := range subnetIDs { - request.SubnetIds = append(request.SubnetIds, aws.String(subnetID)) + request.SubnetIds = append(request.SubnetIds, subnetID) } - subnets, err := c.ec2.DescribeSubnets(request) + subnets, err := c.ec2.DescribeSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error querying Subnet for ELB: %q", err) } @@ -3962,7 +3939,7 @@ func (c *Cloud) getSubnetCidrs(subnetIDs []string) ([]string, error) { cidrs := make([]string, 0, len(subnets)) for _, subnet := range subnets { - cidrs = append(cidrs, aws.StringValue(subnet.CidrBlock)) + cidrs = append(cidrs, aws.ToString(subnet.CidrBlock)) } return cidrs, nil } @@ -3975,9 +3952,10 @@ func parseStringAnnotation(annotations map[string]string, annotation string, val return false } -func parseInt64Annotation(annotations map[string]string, annotation string, value *int64) (bool, error) { +func parseInt32Annotation(annotations map[string]string, annotation string, value *int32) (bool, error) { if v, ok := annotations[annotation]; ok { - parsed, err := strconv.ParseInt(v, 10, 0) + parsed64, err := strconv.ParseInt(v, 10, 0) + parsed := int32(parsed64) if err != nil { return true, fmt.Errorf("failed to parse annotation %v=%v", annotation, v) } @@ -3991,7 +3969,7 @@ func (c *Cloud) buildNLBHealthCheckConfiguration(svc *v1.Service) (healthCheckCo hc := healthCheckConfig{ Port: defaultHealthCheckPort, Path: defaultHealthCheckPath, - Protocol: elbv2.ProtocolEnumTcp, + Protocol: elbv2types.ProtocolEnumTcp, Interval: defaultNlbHealthCheckInterval, Timeout: defaultNlbHealthCheckTimeout, HealthyThreshold: defaultNlbHealthCheckThreshold, @@ -4002,20 +3980,22 @@ func (c *Cloud) buildNLBHealthCheckConfiguration(svc *v1.Service) (healthCheckCo hc = healthCheckConfig{ Port: strconv.Itoa(int(port)), Path: path, - Protocol: elbv2.ProtocolEnumHttp, + Protocol: elbv2types.ProtocolEnumHttp, Interval: 10, Timeout: 10, HealthyThreshold: 2, UnhealthyThreshold: 2, } } - if parseStringAnnotation(svc.Annotations, ServiceAnnotationLoadBalancerHealthCheckProtocol, &hc.Protocol) { - hc.Protocol = strings.ToUpper(hc.Protocol) + + var protocolStr string = string(hc.Protocol) + if parseStringAnnotation(svc.Annotations, ServiceAnnotationLoadBalancerHealthCheckProtocol, &protocolStr) { + hc.Protocol = elbv2types.ProtocolEnum(strings.ToUpper(protocolStr)) } switch hc.Protocol { - case elbv2.ProtocolEnumHttp, elbv2.ProtocolEnumHttps: + case elbv2types.ProtocolEnumHttp, elbv2types.ProtocolEnumHttps: parseStringAnnotation(svc.Annotations, ServiceAnnotationLoadBalancerHealthCheckPath, &hc.Path) - case elbv2.ProtocolEnumTcp: + case elbv2types.ProtocolEnumTcp: hc.Path = "" default: return healthCheckConfig{}, fmt.Errorf("Unsupported health check protocol %v", hc.Protocol) @@ -4023,16 +4003,16 @@ func (c *Cloud) buildNLBHealthCheckConfiguration(svc *v1.Service) (healthCheckCo parseStringAnnotation(svc.Annotations, ServiceAnnotationLoadBalancerHealthCheckPort, &hc.Port) - if _, err := parseInt64Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCInterval, &hc.Interval); err != nil { + if _, err := parseInt32Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCInterval, &hc.Interval); err != nil { return healthCheckConfig{}, err } - if _, err := parseInt64Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCTimeout, &hc.Timeout); err != nil { + if _, err := parseInt32Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCTimeout, &hc.Timeout); err != nil { return healthCheckConfig{}, err } - if _, err := parseInt64Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCHealthyThreshold, &hc.HealthyThreshold); err != nil { + if _, err := parseInt32Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCHealthyThreshold, &hc.HealthyThreshold); err != nil { return healthCheckConfig{}, err } - if _, err := parseInt64Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCUnhealthyThreshold, &hc.UnhealthyThreshold); err != nil { + if _, err := parseInt32Annotation(svc.Annotations, ServiceAnnotationLoadBalancerHCUnhealthyThreshold, &hc.UnhealthyThreshold); err != nil { return healthCheckConfig{}, err } @@ -4065,7 +4045,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS return nil, err } // Figure out what mappings we want on the load balancer - listeners := []*elb.Listener{} + listeners := []elbtypes.Listener{} v2Mappings := []nlbPortMapping{} sslPorts := getPortSets(annotations[ServiceAnnotationLoadBalancerSSLPorts]) @@ -4081,10 +4061,10 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS if isNLB(annotations) { portMapping := nlbPortMapping{ - FrontendPort: int64(port.Port), - FrontendProtocol: string(port.Protocol), - TrafficPort: int64(port.NodePort), - TrafficProtocol: string(port.Protocol), + FrontendPort: int32(port.Port), + FrontendProtocol: elbv2types.ProtocolEnum(port.Protocol), + TrafficPort: int32(port.NodePort), + TrafficProtocol: elbv2types.ProtocolEnum(port.Protocol), } var err error if portMapping.HealthCheckConfig, err = c.buildNLBHealthCheckConfiguration(apiService); err != nil { @@ -4092,13 +4072,13 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } certificateARN := annotations[ServiceAnnotationLoadBalancerCertificate] - if port.Protocol != v1.ProtocolUDP && certificateARN != "" && (sslPorts == nil || sslPorts.numbers.Has(int64(port.Port)) || sslPorts.names.Has(port.Name)) { - portMapping.FrontendProtocol = elbv2.ProtocolEnumTls + if port.Protocol != v1.ProtocolUDP && certificateARN != "" && (sslPorts == nil || sslPorts.numbers.Has(port.Port) || sslPorts.names.Has(port.Name)) { + portMapping.FrontendProtocol = elbv2types.ProtocolEnumTls portMapping.SSLCertificateARN = certificateARN portMapping.SSLPolicy = annotations[ServiceAnnotationLoadBalancerSSLNegotiationPolicy] if backendProtocol := annotations[ServiceAnnotationLoadBalancerBEProtocol]; backendProtocol == "ssl" { - portMapping.TrafficProtocol = elbv2.ProtocolEnumTls + portMapping.TrafficProtocol = elbv2types.ProtocolEnumTls } } @@ -4116,7 +4096,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS return nil, fmt.Errorf("LoadBalancerIP cannot be specified for AWS ELB") } - instances, err := c.findInstancesForELB(nodes, annotations) + instances, err := c.findInstancesForELB(ctx, nodes, annotations) if err != nil { return nil, err } @@ -4137,7 +4117,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS if isNLB(annotations) { // Find the subnets that the ELB will live in - discoveredSubnetIDs, err := c.getLoadBalancerSubnets(apiService, internalELB) + discoveredSubnetIDs, err := c.getLoadBalancerSubnets(ctx, apiService, internalELB) if err != nil { klog.Errorf("Error listing subnets in VPC: %q", err) return nil, err @@ -4155,7 +4135,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS instanceIDs = append(instanceIDs, string(id)) } - v2LoadBalancer, err := c.ensureLoadBalancerv2( + v2LoadBalancer, err := c.ensureLoadBalancerv2(ctx, serviceName, loadBalancerName, v2Mappings, @@ -4177,7 +4157,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS if len(ensuredSubnetIDs) == 0 { return nil, fmt.Errorf("did not find ensured subnets on LB %s", loadBalancerName) } - subnetCidrs, err = c.getSubnetCidrs(ensuredSubnetIDs) + subnetCidrs, err = c.getSubnetCidrs(ctx, ensuredSubnetIDs) if err != nil { klog.Errorf("Error getting subnet cidrs: %q", err) return nil, err @@ -4191,7 +4171,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS sourceRangeCidrs = append(sourceRangeCidrs, "0.0.0.0/0") } - err = c.updateInstanceSecurityGroupsForNLB(loadBalancerName, instances, subnetCidrs, sourceRangeCidrs, v2Mappings) + err = c.updateInstanceSecurityGroupsForNLB(ctx, loadBalancerName, instances, subnetCidrs, sourceRangeCidrs, v2Mappings) if err != nil { klog.Warningf("Error opening ingress rules for the load balancer to the instances: %q", err) return nil, err @@ -4215,24 +4195,24 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } // Some load balancer attributes are required, so defaults are set. These can be overridden by annotations. - loadBalancerAttributes := &elb.LoadBalancerAttributes{ - AccessLog: &elb.AccessLog{Enabled: aws.Bool(false)}, - ConnectionDraining: &elb.ConnectionDraining{Enabled: aws.Bool(false)}, - ConnectionSettings: &elb.ConnectionSettings{IdleTimeout: aws.Int64(60)}, - CrossZoneLoadBalancing: &elb.CrossZoneLoadBalancing{Enabled: aws.Bool(false)}, + loadBalancerAttributes := &elbtypes.LoadBalancerAttributes{ + AccessLog: &elbtypes.AccessLog{Enabled: false}, + ConnectionDraining: &elbtypes.ConnectionDraining{Enabled: false}, + ConnectionSettings: &elbtypes.ConnectionSettings{IdleTimeout: aws.Int32(60)}, + CrossZoneLoadBalancing: &elbtypes.CrossZoneLoadBalancing{Enabled: false}, } // Determine if an access log emit interval has been specified accessLogEmitIntervalAnnotation := annotations[ServiceAnnotationLoadBalancerAccessLogEmitInterval] if accessLogEmitIntervalAnnotation != "" { - accessLogEmitInterval, err := strconv.ParseInt(accessLogEmitIntervalAnnotation, 10, 64) + accessLogEmitInterval, err := strconv.ParseInt(accessLogEmitIntervalAnnotation, 10, 32) if err != nil { return nil, fmt.Errorf("error parsing service annotation: %s=%s", ServiceAnnotationLoadBalancerAccessLogEmitInterval, accessLogEmitIntervalAnnotation, ) } - loadBalancerAttributes.AccessLog.EmitInterval = &accessLogEmitInterval + loadBalancerAttributes.AccessLog.EmitInterval = aws.Int32(int32(accessLogEmitInterval)) } // Determine if access log enabled/disabled has been specified @@ -4245,7 +4225,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS accessLogEnabledAnnotation, ) } - loadBalancerAttributes.AccessLog.Enabled = &accessLogEnabled + loadBalancerAttributes.AccessLog.Enabled = accessLogEnabled } // Determine if access log s3 bucket name has been specified @@ -4270,33 +4250,33 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS connectionDrainingEnabledAnnotation, ) } - loadBalancerAttributes.ConnectionDraining.Enabled = &connectionDrainingEnabled + loadBalancerAttributes.ConnectionDraining.Enabled = connectionDrainingEnabled } // Determine if connection draining timeout has been specified connectionDrainingTimeoutAnnotation := annotations[ServiceAnnotationLoadBalancerConnectionDrainingTimeout] if connectionDrainingTimeoutAnnotation != "" { - connectionDrainingTimeout, err := strconv.ParseInt(connectionDrainingTimeoutAnnotation, 10, 64) + connectionDrainingTimeout, err := strconv.ParseInt(connectionDrainingTimeoutAnnotation, 10, 32) if err != nil { return nil, fmt.Errorf("error parsing service annotation: %s=%s", ServiceAnnotationLoadBalancerConnectionDrainingTimeout, connectionDrainingTimeoutAnnotation, ) } - loadBalancerAttributes.ConnectionDraining.Timeout = &connectionDrainingTimeout + loadBalancerAttributes.ConnectionDraining.Timeout = aws.Int32(int32(connectionDrainingTimeout)) } // Determine if connection idle timeout has been specified connectionIdleTimeoutAnnotation := annotations[ServiceAnnotationLoadBalancerConnectionIdleTimeout] if connectionIdleTimeoutAnnotation != "" { - connectionIdleTimeout, err := strconv.ParseInt(connectionIdleTimeoutAnnotation, 10, 64) + connectionIdleTimeout, err := strconv.ParseInt(connectionIdleTimeoutAnnotation, 10, 32) if err != nil { return nil, fmt.Errorf("error parsing service annotation: %s=%s", ServiceAnnotationLoadBalancerConnectionIdleTimeout, connectionIdleTimeoutAnnotation, ) } - loadBalancerAttributes.ConnectionSettings.IdleTimeout = &connectionIdleTimeout + loadBalancerAttributes.ConnectionSettings.IdleTimeout = aws.Int32(int32(connectionIdleTimeout)) } // Determine if cross zone load balancing enabled/disabled has been specified @@ -4309,11 +4289,11 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS crossZoneLoadBalancingEnabledAnnotation, ) } - loadBalancerAttributes.CrossZoneLoadBalancing.Enabled = &crossZoneLoadBalancingEnabled + loadBalancerAttributes.CrossZoneLoadBalancing.Enabled = crossZoneLoadBalancingEnabled } // Find the subnets that the ELB will live in - subnetIDs, err := c.getLoadBalancerSubnets(apiService, internalELB) + subnetIDs, err := c.getLoadBalancerSubnets(ctx, apiService, internalELB) if err != nil { klog.Errorf("Error listing subnets in VPC: %q", err) return nil, err @@ -4326,7 +4306,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS loadBalancerName := c.GetLoadBalancerName(ctx, clusterName, apiService) serviceName := types.NamespacedName{Namespace: apiService.Namespace, Name: apiService.Name} - securityGroupIDs, setupSg, err := c.buildELBSecurityGroupList(serviceName, loadBalancerName, annotations) + securityGroupIDs, setupSg, err := c.buildELBSecurityGroupList(ctx, serviceName, loadBalancerName, annotations) if err != nil { return nil, err } @@ -4335,19 +4315,18 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } if setupSg { - ec2SourceRanges := []*ec2.IpRange{} + ec2SourceRanges := []ec2types.IpRange{} for _, sourceRange := range sourceRanges.StringSlice() { - ec2SourceRanges = append(ec2SourceRanges, &ec2.IpRange{CidrIp: aws.String(sourceRange)}) + ec2SourceRanges = append(ec2SourceRanges, ec2types.IpRange{CidrIp: aws.String(sourceRange)}) } permissions := NewIPPermissionSet() for _, port := range apiService.Spec.Ports { - portInt64 := int64(port.Port) protocol := strings.ToLower(string(port.Protocol)) - permission := &ec2.IpPermission{} - permission.FromPort = &portInt64 - permission.ToPort = &portInt64 + permission := ec2types.IpPermission{} + permission.FromPort = aws.Int32(port.Port) + permission.ToPort = aws.Int32(port.Port) permission.IpRanges = ec2SourceRanges permission.IpProtocol = &protocol @@ -4356,23 +4335,23 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS // Allow ICMP fragmentation packets, important for MTU discovery { - permission := &ec2.IpPermission{ + permission := ec2types.IpPermission{ IpProtocol: aws.String("icmp"), - FromPort: aws.Int64(3), - ToPort: aws.Int64(4), + FromPort: aws.Int32(3), + ToPort: aws.Int32(4), IpRanges: ec2SourceRanges, } permissions.Insert(permission) } - _, err = c.setSecurityGroupIngress(securityGroupIDs[0], permissions) + _, err = c.setSecurityGroupIngress(ctx, securityGroupIDs[0], permissions) if err != nil { return nil, err } } // Build the load balancer itself - loadBalancer, err := c.ensureLoadBalancer( + loadBalancer, err := c.ensureLoadBalancer(ctx, serviceName, loadBalancerName, listeners, @@ -4388,13 +4367,13 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } if sslPolicyName, ok := annotations[ServiceAnnotationLoadBalancerSSLNegotiationPolicy]; ok { - err := c.ensureSSLNegotiationPolicy(loadBalancer, sslPolicyName) + err := c.ensureSSLNegotiationPolicy(ctx, loadBalancer, sslPolicyName) if err != nil { return nil, err } for _, port := range c.getLoadBalancerTLSPorts(loadBalancer) { - err := c.setSSLNegotiationPolicy(loadBalancerName, sslPolicyName, port) + err := c.setSSLNegotiationPolicy(ctx, loadBalancerName, sslPolicyName, port) if err != nil { return nil, err } @@ -4415,7 +4394,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS if annotations[ServiceAnnotationLoadBalancerHealthCheckPort] == defaultHealthCheckPort { healthCheckNodePort = tcpHealthCheckPort } - err = c.ensureLoadBalancerHealthCheck(loadBalancer, "HTTP", healthCheckNodePort, path, annotations) + err = c.ensureLoadBalancerHealthCheck(ctx, loadBalancer, "HTTP", healthCheckNodePort, path, annotations) if err != nil { return nil, fmt.Errorf("Failed to ensure health check for localized service %v on node port %v: %q", loadBalancerName, healthCheckNodePort, err) } @@ -4429,25 +4408,25 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS hcProtocol = "TCP" } // there must be no path on TCP health check - err = c.ensureLoadBalancerHealthCheck(loadBalancer, hcProtocol, tcpHealthCheckPort, "", annotations) + err = c.ensureLoadBalancerHealthCheck(ctx, loadBalancer, hcProtocol, tcpHealthCheckPort, "", annotations) if err != nil { return nil, err } } - err = c.updateInstanceSecurityGroupsForLoadBalancer(loadBalancer, instances, annotations) + err = c.updateInstanceSecurityGroupsForLoadBalancer(ctx, loadBalancer, instances, annotations) if err != nil { klog.Warningf("Error opening ingress rules for the load balancer to the instances: %q", err) return nil, err } - err = c.ensureLoadBalancerInstances(aws.StringValue(loadBalancer.LoadBalancerName), loadBalancer.Instances, instances) + err = c.ensureLoadBalancerInstances(ctx, aws.ToString(loadBalancer.LoadBalancerName), loadBalancer.Instances, instances) if err != nil { klog.Warningf("Error registering instances with the load balancer: %q", err) return nil, err } - klog.V(1).Infof("Loadbalancer %s (%v) has DNS name %s", loadBalancerName, serviceName, aws.StringValue(loadBalancer.DNSName)) + klog.V(1).Infof("Loadbalancer %s (%v) has DNS name %s", loadBalancerName, serviceName, aws.ToString(loadBalancer.DNSName)) // TODO: Wait for creation? @@ -4463,7 +4442,7 @@ func (c *Cloud) GetLoadBalancer(ctx context.Context, clusterName string, service loadBalancerName := c.GetLoadBalancerName(ctx, clusterName, service) if isNLB(service.Annotations) { - lb, err := c.describeLoadBalancerv2(loadBalancerName) + lb, err := c.describeLoadBalancerv2(ctx, loadBalancerName) if err != nil { return nil, false, err } @@ -4473,7 +4452,7 @@ func (c *Cloud) GetLoadBalancer(ctx context.Context, clusterName string, service return v2toStatus(lb), true, nil } - lb, err := c.describeLoadBalancer(loadBalancerName) + lb, err := c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { return nil, false, err } @@ -4492,19 +4471,19 @@ func (c *Cloud) GetLoadBalancerName(ctx context.Context, clusterName string, ser return cloudprovider.DefaultLoadBalancerName(service) } -func toStatus(lb *elb.LoadBalancerDescription) *v1.LoadBalancerStatus { +func toStatus(lb *elbtypes.LoadBalancerDescription) *v1.LoadBalancerStatus { status := &v1.LoadBalancerStatus{} - if aws.StringValue(lb.DNSName) != "" { + if aws.ToString(lb.DNSName) != "" { var ingress v1.LoadBalancerIngress - ingress.Hostname = aws.StringValue(lb.DNSName) + ingress.Hostname = aws.ToString(lb.DNSName) status.Ingress = []v1.LoadBalancerIngress{ingress} } return status } -func v2toStatus(lb *elbv2.LoadBalancer) *v1.LoadBalancerStatus { +func v2toStatus(lb *elbv2types.LoadBalancer) *v1.LoadBalancerStatus { status := &v1.LoadBalancerStatus{} if lb == nil { klog.Error("[BUG] v2toStatus got nil input, this is a Kubernetes bug, please report") @@ -4512,10 +4491,10 @@ func v2toStatus(lb *elbv2.LoadBalancer) *v1.LoadBalancerStatus { } // We check for Active or Provisioning, the only successful statuses - if aws.StringValue(lb.DNSName) != "" && (aws.StringValue(lb.State.Code) == elbv2.LoadBalancerStateEnumActive || - aws.StringValue(lb.State.Code) == elbv2.LoadBalancerStateEnumProvisioning) { + if aws.ToString(lb.DNSName) != "" && (lb.State.Code == elbv2types.LoadBalancerStateEnumActive || + lb.State.Code == elbv2types.LoadBalancerStateEnumProvisioning) { var ingress v1.LoadBalancerIngress - ingress.Hostname = aws.StringValue(lb.DNSName) + ingress.Hostname = aws.ToString(lb.DNSName) status.Ingress = []v1.LoadBalancerIngress{ingress} } @@ -4526,13 +4505,13 @@ func v2toStatus(lb *elbv2.LoadBalancer) *v1.LoadBalancerStatus { // We only create instances with one security group, so we don't expect multiple security groups. // However, if there are multiple security groups, we will choose the one tagged with our cluster filter. // Otherwise we will return an error. -func findSecurityGroupForInstance(instance *ec2.Instance, taggedSecurityGroups map[string]*ec2.SecurityGroup) (*ec2.GroupIdentifier, error) { - instanceID := aws.StringValue(instance.InstanceId) +func findSecurityGroupForInstance(instance *ec2types.Instance, taggedSecurityGroups map[string]*ec2types.SecurityGroup) (*ec2types.GroupIdentifier, error) { + instanceID := aws.ToString(instance.InstanceId) - var tagged []*ec2.GroupIdentifier - var untagged []*ec2.GroupIdentifier + var tagged []ec2types.GroupIdentifier + var untagged []ec2types.GroupIdentifier for _, group := range instance.SecurityGroups { - groupID := aws.StringValue(group.GroupId) + groupID := aws.ToString(group.GroupId) if groupID == "" { klog.Warningf("Ignoring security group without id for instance %q: %v", instanceID, group) continue @@ -4555,7 +4534,7 @@ func findSecurityGroupForInstance(instance *ec2.Instance, taggedSecurityGroups m } return nil, fmt.Errorf("Multiple tagged security groups found for instance %s; ensure only the k8s security group is tagged; the tagged groups were %v", instanceID, taggedGroups) } - return tagged[0], nil + return &tagged[0], nil } if len(untagged) > 0 { @@ -4563,7 +4542,7 @@ func findSecurityGroupForInstance(instance *ec2.Instance, taggedSecurityGroups m if len(untagged) != 1 { return nil, fmt.Errorf("Multiple untagged security groups found for instance %s; ensure the k8s security group is tagged", instanceID) } - return untagged[0], nil + return &untagged[0], nil } klog.Warningf("No security group found for instance %q", instanceID) @@ -4571,52 +4550,52 @@ func findSecurityGroupForInstance(instance *ec2.Instance, taggedSecurityGroups m } // Return all the security groups that are tagged as being part of our cluster -func (c *Cloud) getTaggedSecurityGroups() (map[string]*ec2.SecurityGroup, error) { +func (c *Cloud) getTaggedSecurityGroups(ctx context.Context) (map[string]*ec2types.SecurityGroup, error) { request := &ec2.DescribeSecurityGroupsInput{} - groups, err := c.ec2.DescribeSecurityGroups(request) + groups, err := c.ec2.DescribeSecurityGroups(ctx, request) if err != nil { return nil, fmt.Errorf("error querying security groups: %q", err) } - m := make(map[string]*ec2.SecurityGroup) + m := make(map[string]*ec2types.SecurityGroup) for _, group := range groups { if !c.tagging.hasClusterTag(group.Tags) { continue } - id := aws.StringValue(group.GroupId) + id := aws.ToString(group.GroupId) if id == "" { klog.Warningf("Ignoring group without id: %v", group) continue } - m[id] = group + m[id] = &group } return m, nil } // Open security group ingress rules on the instances so that the load balancer can talk to them // Will also remove any security groups ingress rules for the load balancer that are _not_ needed for allInstances -func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancerDescription, instances map[InstanceID]*ec2.Instance, annotations map[string]string) error { +func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(ctx context.Context, lb *elbtypes.LoadBalancerDescription, instances map[InstanceID]*ec2types.Instance, annotations map[string]string) error { if c.cfg.Global.DisableSecurityGroupIngress { return nil } // Determine the load balancer security group id - lbSecurityGroupIDs := aws.StringValueSlice(lb.SecurityGroups) + lbSecurityGroupIDs := lb.SecurityGroups if len(lbSecurityGroupIDs) == 0 { - return fmt.Errorf("could not determine security group for load balancer: %s", aws.StringValue(lb.LoadBalancerName)) + return fmt.Errorf("could not determine security group for load balancer: %s", aws.ToString(lb.LoadBalancerName)) } c.sortELBSecurityGroupList(lbSecurityGroupIDs, annotations) loadBalancerSecurityGroupID := lbSecurityGroupIDs[0] // Get the actual list of groups that allow ingress from the load-balancer - var actualGroups []*ec2.SecurityGroup + var actualGroups []*ec2types.SecurityGroup { describeRequest := &ec2.DescribeSecurityGroupsInput{} - describeRequest.Filters = []*ec2.Filter{ + describeRequest.Filters = []ec2types.Filter{ newEc2Filter("ip-permission.group-id", loadBalancerSecurityGroupID), } - response, err := c.ec2.DescribeSecurityGroups(describeRequest) + response, err := c.ec2.DescribeSecurityGroups(ctx, describeRequest) if err != nil { return fmt.Errorf("error querying security groups for ELB: %q", err) } @@ -4624,11 +4603,11 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer if !c.tagging.hasClusterTag(sg.Tags) { continue } - actualGroups = append(actualGroups, sg) + actualGroups = append(actualGroups, &sg) } } - taggedSecurityGroups, err := c.getTaggedSecurityGroups() + taggedSecurityGroups, err := c.getTaggedSecurityGroups(ctx) if err != nil { return fmt.Errorf("error querying for tagged security groups: %q", err) } @@ -4649,10 +4628,10 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer } if securityGroup == nil { - klog.Warning("Ignoring instance without security group: ", aws.StringValue(instance.InstanceId)) + klog.Warning("Ignoring instance without security group: ", aws.ToString(instance.InstanceId)) continue } - id := aws.StringValue(securityGroup.GroupId) + id := aws.ToString(securityGroup.GroupId) if id == "" { klog.Warningf("found security group without id: %v", securityGroup) continue @@ -4663,7 +4642,7 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer // Compare to actual groups for _, actualGroup := range actualGroups { - actualGroupID := aws.StringValue(actualGroup.GroupId) + actualGroupID := aws.ToString(actualGroup.GroupId) if actualGroupID == "" { klog.Warning("Ignoring group without ID: ", actualGroup) continue @@ -4685,19 +4664,19 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer } else { klog.V(2).Infof("Removing rule for traffic from the load balancer (%s) to instance (%s)", loadBalancerSecurityGroupID, instanceSecurityGroupID) } - sourceGroupID := &ec2.UserIdGroupPair{} + sourceGroupID := ec2types.UserIdGroupPair{} sourceGroupID.GroupId = &loadBalancerSecurityGroupID allProtocols := "-1" - permission := &ec2.IpPermission{} + permission := ec2types.IpPermission{} permission.IpProtocol = &allProtocols - permission.UserIdGroupPairs = []*ec2.UserIdGroupPair{sourceGroupID} + permission.UserIdGroupPairs = []ec2types.UserIdGroupPair{sourceGroupID} - permissions := []*ec2.IpPermission{permission} + permissions := []ec2types.IpPermission{permission} if add { - changed, err := c.addSecurityGroupIngress(instanceSecurityGroupID, permissions) + changed, err := c.addSecurityGroupIngress(ctx, instanceSecurityGroupID, permissions) if err != nil { return err } @@ -4705,7 +4684,7 @@ func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancer klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID) } } else { - changed, err := c.removeSecurityGroupIngress(instanceSecurityGroupID, permissions) + changed, err := c.removeSecurityGroupIngress(ctx, instanceSecurityGroupID, permissions) if err != nil { return err } @@ -4726,7 +4705,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin loadBalancerName := c.GetLoadBalancerName(ctx, clusterName, service) if isNLB(service.Annotations) { - lb, err := c.describeLoadBalancerv2(loadBalancerName) + lb, err := c.describeLoadBalancerv2(ctx, loadBalancerName) if err != nil { return err } @@ -4747,14 +4726,14 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin // * Clean up SecurityGroupRules { - targetGroups, err := c.elbv2.DescribeTargetGroups( + targetGroups, err := c.elbv2.DescribeTargetGroups(ctx, &elbv2.DescribeTargetGroupsInput{LoadBalancerArn: lb.LoadBalancerArn}, ) if err != nil { return fmt.Errorf("error listing target groups before deleting load balancer: %q", err) } - _, err = c.elbv2.DeleteLoadBalancer( + _, err = c.elbv2.DeleteLoadBalancer(ctx, &elbv2.DeleteLoadBalancerInput{LoadBalancerArn: lb.LoadBalancerArn}, ) if err != nil { @@ -4762,7 +4741,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin } for _, group := range targetGroups.TargetGroups { - _, err := c.elbv2.DeleteTargetGroup( + _, err := c.elbv2.DeleteTargetGroup(ctx, &elbv2.DeleteTargetGroupInput{TargetGroupArn: group.TargetGroupArn}, ) if err != nil { @@ -4771,10 +4750,10 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin } } - return c.updateInstanceSecurityGroupsForNLB(loadBalancerName, nil, nil, nil, nil) + return c.updateInstanceSecurityGroupsForNLB(ctx, loadBalancerName, nil, nil, nil, nil) } - lb, err := c.describeLoadBalancer(loadBalancerName) + lb, err := c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { return err } @@ -4786,7 +4765,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin { // De-authorize the load balancer security group from the instances security group - err = c.updateInstanceSecurityGroupsForLoadBalancer(lb, nil, service.Annotations) + err = c.updateInstanceSecurityGroupsForLoadBalancer(ctx, lb, nil, service.Annotations) if err != nil { klog.Errorf("Error deregistering load balancer from instance security groups: %q", err) return err @@ -4798,7 +4777,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin request := &elb.DeleteLoadBalancerInput{} request.LoadBalancerName = lb.LoadBalancerName - _, err = c.elb.DeleteLoadBalancer(request) + _, err = c.elb.DeleteLoadBalancer(ctx, request) if err != nil { // TODO: Check if error was because load balancer was concurrently deleted klog.Errorf("Error deleting load balancer: %q", err) @@ -4811,13 +4790,13 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin // Note that this is annoying: the load balancer disappears from the API immediately, but it is still // deleting in the background. We get a DependencyViolation until the load balancer has deleted itself - var loadBalancerSGs = aws.StringValueSlice(lb.SecurityGroups) + var loadBalancerSGs = lb.SecurityGroups describeRequest := &ec2.DescribeSecurityGroupsInput{} - describeRequest.Filters = []*ec2.Filter{ + describeRequest.Filters = []ec2types.Filter{ newEc2Filter("group-id", loadBalancerSGs...), } - response, err := c.ec2.DescribeSecurityGroups(describeRequest) + response, err := c.ec2.DescribeSecurityGroups(ctx, describeRequest) if err != nil { return fmt.Errorf("error querying security groups for ELB: %q", err) } @@ -4834,7 +4813,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin } for _, sg := range response { - sgID := aws.StringValue(sg.GroupId) + sgID := aws.ToString(sg.GroupId) if sgID == c.cfg.Global.ElbSecurityGroup { //We don't want to delete a security group that was defined in the Cloud Configuration. @@ -4865,13 +4844,14 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin for securityGroupID := range securityGroupIDs { request := &ec2.DeleteSecurityGroupInput{} request.GroupId = &securityGroupID - _, err := c.ec2.DeleteSecurityGroup(request) + _, err := c.ec2.DeleteSecurityGroup(ctx, request) if err == nil { delete(securityGroupIDs, securityGroupID) } else { ignore := false - if awsError, ok := err.(awserr.Error); ok { - if awsError.Code() == "DependencyViolation" { + var ae smithy.APIError + if errors.As(err, &ae) { + if ae.ErrorCode() == "DependencyViolation" { klog.V(2).Infof("Ignoring DependencyViolation while deleting load-balancer security group (%s), assuming because LB is in process of deleting", securityGroupID) ignore = true } @@ -4910,13 +4890,13 @@ func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, serv if isLBExternal(service.Annotations) { return cloudprovider.ImplementedElsewhere } - instances, err := c.findInstancesForELB(nodes, service.Annotations) + instances, err := c.findInstancesForELB(ctx, nodes, service.Annotations) if err != nil { return err } loadBalancerName := c.GetLoadBalancerName(ctx, clusterName, service) if isNLB(service.Annotations) { - lb, err := c.describeLoadBalancerv2(loadBalancerName) + lb, err := c.describeLoadBalancerv2(ctx, loadBalancerName) if err != nil { return err } @@ -4926,7 +4906,7 @@ func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, serv _, err = c.EnsureLoadBalancer(ctx, clusterName, service, nodes) return err } - lb, err := c.describeLoadBalancer(loadBalancerName) + lb, err := c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { return err } @@ -4936,25 +4916,25 @@ func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, serv } if sslPolicyName, ok := service.Annotations[ServiceAnnotationLoadBalancerSSLNegotiationPolicy]; ok { - err := c.ensureSSLNegotiationPolicy(lb, sslPolicyName) + err := c.ensureSSLNegotiationPolicy(ctx, lb, sslPolicyName) if err != nil { return err } for _, port := range c.getLoadBalancerTLSPorts(lb) { - err := c.setSSLNegotiationPolicy(loadBalancerName, sslPolicyName, port) + err := c.setSSLNegotiationPolicy(ctx, loadBalancerName, sslPolicyName, port) if err != nil { return err } } } - err = c.ensureLoadBalancerInstances(aws.StringValue(lb.LoadBalancerName), lb.Instances, instances) + err = c.ensureLoadBalancerInstances(ctx, aws.ToString(lb.LoadBalancerName), lb.Instances, instances) if err != nil { klog.Warningf("Error registering/deregistering instances with the load balancer: %q", err) return err } - err = c.updateInstanceSecurityGroupsForLoadBalancer(lb, instances, service.Annotations) + err = c.updateInstanceSecurityGroupsForLoadBalancer(ctx, lb, instances, service.Annotations) if err != nil { return err } @@ -4963,8 +4943,8 @@ func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, serv } // Returns the instance with the specified ID -func (c *Cloud) getInstanceByID(instanceID string) (*ec2.Instance, error) { - instances, err := c.getInstancesByIDs([]*string{&instanceID}) +func (c *Cloud) getInstanceByID(ctx context.Context, instanceID string) (*ec2types.Instance, error) { + instances, err := c.getInstancesByIDs(ctx, []string{instanceID}) if err != nil { return nil, err } @@ -4979,8 +4959,8 @@ func (c *Cloud) getInstanceByID(instanceID string) (*ec2.Instance, error) { return instances[instanceID], nil } -func (c *Cloud) getInstancesByIDs(instanceIDs []*string) (map[string]*ec2.Instance, error) { - instancesByID := make(map[string]*ec2.Instance) +func (c *Cloud) getInstancesByIDs(ctx context.Context, instanceIDs []string) (map[string]*ec2types.Instance, error) { + instancesByID := make(map[string]*ec2types.Instance) if len(instanceIDs) == 0 { return instancesByID, nil } @@ -4989,46 +4969,45 @@ func (c *Cloud) getInstancesByIDs(instanceIDs []*string) (map[string]*ec2.Instan InstanceIds: instanceIDs, } - instances, err := c.ec2.DescribeInstances(request) + instances, err := c.ec2.DescribeInstances(ctx, request) if err != nil { return nil, err } for _, instance := range instances { - instanceID := aws.StringValue(instance.InstanceId) + instanceID := aws.ToString(instance.InstanceId) if instanceID == "" { continue } - instancesByID[instanceID] = instance + instancesByID[instanceID] = &instance } return instancesByID, nil } -func (c *Cloud) getInstancesByNodeNames(nodeNames []string, states ...string) ([]*ec2.Instance, error) { - names := aws.StringSlice(nodeNames) - ec2Instances := []*ec2.Instance{} +func (c *Cloud) getInstancesByNodeNames(ctx context.Context, nodeNames []string, states ...string) ([]*ec2types.Instance, error) { + ec2Instances := []*ec2types.Instance{} - for i := 0; i < len(names); i += filterNodeLimit { + for i := 0; i < len(nodeNames); i += filterNodeLimit { end := i + filterNodeLimit - if end > len(names) { - end = len(names) + if end > len(nodeNames) { + end = len(nodeNames) } - nameSlice := names[i:end] + nameSlice := nodeNames[i:end] - nodeNameFilter := &ec2.Filter{ + nodeNameFilter := ec2types.Filter{ Name: aws.String("private-dns-name"), Values: nameSlice, } - filters := []*ec2.Filter{nodeNameFilter} + filters := []ec2types.Filter{nodeNameFilter} if len(states) > 0 { filters = append(filters, newEc2Filter("instance-state-name", states...)) } - instances, err := c.describeInstances(filters) + instances, err := c.describeInstances(ctx, filters) if err != nil { klog.V(2).Infof("Failed to describe instances %v", nodeNames) return nil, err @@ -5044,20 +5023,20 @@ func (c *Cloud) getInstancesByNodeNames(nodeNames []string, states ...string) ([ } // TODO: Move to instanceCache -func (c *Cloud) describeInstances(filters []*ec2.Filter) ([]*ec2.Instance, error) { +func (c *Cloud) describeInstances(ctx context.Context, filters []ec2types.Filter) ([]*ec2types.Instance, error) { request := &ec2.DescribeInstancesInput{ Filters: filters, } - response, err := c.ec2.DescribeInstances(request) + response, err := c.ec2.DescribeInstances(ctx, request) if err != nil { return nil, err } - var matches []*ec2.Instance + var matches []*ec2types.Instance for _, instance := range response { if c.tagging.hasClusterTag(instance.Tags) { - matches = append(matches, instance) + matches = append(matches, &instance) } } return matches, nil @@ -5097,29 +5076,29 @@ func mapNodeNameToPrivateDNSName(nodeName types.NodeName) string { // // Deprecated: use instanceIDToNodeName instead. See // mapNodeNameToPrivateDNSName for details. -func mapInstanceToNodeName(i *ec2.Instance) types.NodeName { - return types.NodeName(aws.StringValue(i.PrivateDnsName)) +func mapInstanceToNodeName(i *ec2types.Instance) types.NodeName { + return types.NodeName(aws.ToString(i.PrivateDnsName)) } var aliveFilter = []string{ - ec2.InstanceStateNamePending, - ec2.InstanceStateNameRunning, - ec2.InstanceStateNameShuttingDown, - ec2.InstanceStateNameStopping, - ec2.InstanceStateNameStopped, + string(ec2types.InstanceStateNamePending), + string(ec2types.InstanceStateNameRunning), + string(ec2types.InstanceStateNameShuttingDown), + string(ec2types.InstanceStateNameStopping), + string(ec2types.InstanceStateNameStopped), } // Returns the instance with the specified node name // Returns nil if it does not exist -func (c *Cloud) findInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, error) { +func (c *Cloud) findInstanceByNodeName(ctx context.Context, nodeName types.NodeName) (*ec2types.Instance, error) { privateDNSName := mapNodeNameToPrivateDNSName(nodeName) - filters := []*ec2.Filter{ + filters := []ec2types.Filter{ newEc2Filter("private-dns-name", privateDNSName), // exclude instances in "terminated" state newEc2Filter("instance-state-name", aliveFilter...), } - instances, err := c.describeInstances(filters) + instances, err := c.describeInstances(ctx, filters) if err != nil { return nil, err } @@ -5135,8 +5114,8 @@ func (c *Cloud) findInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, // Returns the instance with the specified node name // Like findInstanceByNodeName, but returns error if node not found -func (c *Cloud) getInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, error) { - var instance *ec2.Instance +func (c *Cloud) getInstanceByNodeName(ctx context.Context, nodeName types.NodeName) (*ec2types.Instance, error) { + var instance *ec2types.Instance // we leverage node cache to try to retrieve node's instance id first, as // get instance by instance id is way more efficient than by filters in @@ -5144,9 +5123,9 @@ func (c *Cloud) getInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, e awsID, err := c.nodeNameToInstanceID(nodeName) if err != nil { klog.V(3).Infof("Unable to convert node name %q to aws instanceID, fall back to findInstanceByNodeName: %v", nodeName, err) - instance, err = c.findInstanceByNodeName(nodeName) + instance, err = c.findInstanceByNodeName(ctx, nodeName) } else { - instance, err = c.getInstanceByID(string(awsID)) + instance, err = c.getInstanceByID(ctx, string(awsID)) } if err == nil && instance == nil { return nil, cloudprovider.InstanceNotFound @@ -5154,12 +5133,12 @@ func (c *Cloud) getInstanceByNodeName(nodeName types.NodeName) (*ec2.Instance, e return instance, err } -func (c *Cloud) getFullInstance(nodeName types.NodeName) (*awsInstance, *ec2.Instance, error) { +func (c *Cloud) getFullInstance(ctx context.Context, nodeName types.NodeName) (*awsInstance, *ec2types.Instance, error) { if nodeName == "" { - instance, err := c.getInstanceByID(c.selfAWSInstance.awsID) + instance, err := c.getInstanceByID(ctx, c.selfAWSInstance.awsID) return c.selfAWSInstance, instance, err } - instance, err := c.getInstanceByNodeName(nodeName) + instance, err := c.getInstanceByNodeName(ctx, nodeName) if err != nil { return nil, nil, err } @@ -5277,10 +5256,10 @@ func getInitialAttachDetachDelay(status string) time.Duration { } // describeNetworkInterfaces returns network interface information for the given DNS name. -func (c *Cloud) describeNetworkInterfaces(nodeName string) (*ec2.NetworkInterface, error) { +func (c *Cloud) describeNetworkInterfaces(ctx context.Context, nodeName string) (*ec2types.NetworkInterface, error) { eniEndpoint := strings.TrimPrefix(nodeName, fargateNodeNamePrefix) - filters := []*ec2.Filter{ + filters := []ec2types.Filter{ newEc2Filter("attachment.status", "attached"), newEc2Filter("vpc-id", c.vpcID), } @@ -5298,7 +5277,7 @@ func (c *Cloud) describeNetworkInterfaces(nodeName string) (*ec2.NetworkInterfac Filters: filters, } - eni, err := c.ec2.DescribeNetworkInterfaces(request) + eni, err := c.ec2.DescribeNetworkInterfaces(ctx, request) if err != nil { return nil, err } @@ -5307,12 +5286,12 @@ func (c *Cloud) describeNetworkInterfaces(nodeName string) (*ec2.NetworkInterfac } if len(eni.NetworkInterfaces) != 1 { // This should not be possible - ids should be unique - return nil, fmt.Errorf("multiple interfaces found with same id %q", eni.NetworkInterfaces) + return nil, fmt.Errorf("multiple interfaces found with same id %v", eni.NetworkInterfaces) } - return eni.NetworkInterfaces[0], nil + return &eni.NetworkInterfaces[0], nil } -func getRegionFromMetadata(cfg CloudConfig, metadata EC2Metadata) (string, error) { +func getRegionFromMetadata(ctx context.Context, cfg CloudConfig, metadata EC2Metadata) (string, error) { // For backwards compatibility reasons, keeping this check to avoid breaking possible // cases where Zone was set to override the region configuration. Otherwise, fall back // to getting region the standard way. @@ -5323,5 +5302,5 @@ func getRegionFromMetadata(cfg CloudConfig, metadata EC2Metadata) (string, error return azToRegion(zone) } - return cfg.GetRegion(metadata) + return cfg.GetRegion(ctx, metadata) } diff --git a/pkg/providers/v1/aws_assumerole_provider.go b/pkg/providers/v1/aws_assumerole_provider.go deleted file mode 100644 index ad5a63b4c7..0000000000 --- a/pkg/providers/v1/aws_assumerole_provider.go +++ /dev/null @@ -1,62 +0,0 @@ -/* -Copyright 2014 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package aws - -import ( - "sync" - "time" - - "github.com/aws/aws-sdk-go/aws/credentials" -) - -const ( - invalidateCredsAfter = 1 * time.Second -) - -// assumeRoleProviderWithRateLimiting makes sure we call the underlying provider only -// once after `invalidateCredsAfter` period -type assumeRoleProviderWithRateLimiting struct { - provider credentials.Provider - invalidateCredsAfter time.Duration - sync.RWMutex - lastError error - lastValue credentials.Value - lastRetrieveTime time.Time -} - -func assumeRoleProvider(provider credentials.Provider) credentials.Provider { - return &assumeRoleProviderWithRateLimiting{provider: provider, - invalidateCredsAfter: invalidateCredsAfter} -} - -func (l *assumeRoleProviderWithRateLimiting) Retrieve() (credentials.Value, error) { - l.Lock() - defer l.Unlock() - if time.Since(l.lastRetrieveTime) < l.invalidateCredsAfter { - if l.lastError != nil { - return credentials.Value{}, l.lastError - } - return l.lastValue, nil - } - l.lastValue, l.lastError = l.provider.Retrieve() - l.lastRetrieveTime = time.Now() - return l.lastValue, l.lastError -} - -func (l *assumeRoleProviderWithRateLimiting) IsExpired() bool { - return l.provider.IsExpired() -} diff --git a/pkg/providers/v1/aws_assumerole_provider_test.go b/pkg/providers/v1/aws_assumerole_provider_test.go deleted file mode 100644 index db5af7355a..0000000000 --- a/pkg/providers/v1/aws_assumerole_provider_test.go +++ /dev/null @@ -1,132 +0,0 @@ -/* -Copyright 2014 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package aws - -import ( - "fmt" - "reflect" - "sync" - "testing" - "time" - - "github.com/aws/aws-sdk-go/aws/credentials" -) - -func Test_assumeRoleProviderWithRateLimiting_Retrieve(t *testing.T) { - type fields struct { - provider credentials.Provider - invalidateCredsAfter time.Duration - RWMutex sync.RWMutex - lastError error - lastValue credentials.Value - lastRetrieveTime time.Time - } - tests := []struct { - name string - fields *fields - want credentials.Value - wantProviderCalled bool - sleepBeforeCallingProvider time.Duration - wantErr bool - wantErrString string - }{{ - name: "Call assume role provider and verify access ID returned", - fields: &fields{provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID"}}, - want: credentials.Value{AccessKeyID: "fakeID"}, - wantProviderCalled: true, - }, { - name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last value", - fields: &fields{ - provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID"}, - invalidateCredsAfter: 100 * time.Millisecond, - lastValue: credentials.Value{AccessKeyID: "fakeID1"}, - lastRetrieveTime: time.Now(), - }, - want: credentials.Value{AccessKeyID: "fakeID1"}, - wantProviderCalled: false, - sleepBeforeCallingProvider: 10 * time.Millisecond, - }, { - name: "Assume role provider returns an error when trying to assume a role", - fields: &fields{ - provider: &fakeAssumeRoleProvider{err: fmt.Errorf("can't assume fake role")}, - invalidateCredsAfter: 10 * time.Millisecond, - lastRetrieveTime: time.Now(), - }, - wantProviderCalled: true, - wantErr: true, - wantErrString: "can't assume fake role", - sleepBeforeCallingProvider: 15 * time.Millisecond, - }, { - name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last error value", - fields: &fields{ - provider: &fakeAssumeRoleProvider{}, - invalidateCredsAfter: 100 * time.Millisecond, - lastRetrieveTime: time.Now(), - }, - want: credentials.Value{}, - wantProviderCalled: false, - wantErr: true, - wantErrString: "can't assume fake role", - }, { - name: "Delayed call to assume role API, should call the underlying provider", - fields: &fields{ - provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID2"}, - invalidateCredsAfter: 20 * time.Millisecond, - lastRetrieveTime: time.Now(), - }, - want: credentials.Value{AccessKeyID: "fakeID2"}, - wantProviderCalled: true, - sleepBeforeCallingProvider: 25 * time.Millisecond, - }} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - l := &assumeRoleProviderWithRateLimiting{ - provider: tt.fields.provider, - invalidateCredsAfter: tt.fields.invalidateCredsAfter, - lastError: tt.fields.lastError, - lastValue: tt.fields.lastValue, - lastRetrieveTime: tt.fields.lastRetrieveTime, - } - time.Sleep(tt.sleepBeforeCallingProvider) - got, err := l.Retrieve() - if (err != nil) != tt.wantErr && (tt.wantErr && reflect.DeepEqual(err, tt.wantErrString)) { - t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() got = %v, want %v", got, tt.want) - return - } - if tt.wantProviderCalled != tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled { - t.Errorf("provider called %v, want %v", tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled, tt.wantProviderCalled) - } - }) - } -} - -type fakeAssumeRoleProvider struct { - accesskeyID string - err error - providerCalled bool -} - -func (f *fakeAssumeRoleProvider) Retrieve() (credentials.Value, error) { - f.providerCalled = true - return credentials.Value{AccessKeyID: f.accesskeyID}, f.err -} - -func (f *fakeAssumeRoleProvider) IsExpired() bool { return true } diff --git a/pkg/providers/v1/aws_fakes.go b/pkg/providers/v1/aws_fakes.go index fa57928076..f97a3a4d55 100644 --- a/pkg/providers/v1/aws_fakes.go +++ b/pkg/providers/v1/aws_fakes.go @@ -17,27 +17,31 @@ limitations under the License. package aws import ( + "context" "errors" "fmt" + "io" "sort" "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/aws" + stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/kms" "k8s.io/klog/v2" ) // FakeAWSServices is an fake AWS session used for testing type FakeAWSServices struct { region string - instances []*ec2.Instance - selfInstance *ec2.Instance + instances []*ec2types.Instance + selfInstance *ec2types.Instance networkInterfacesMacs []string networkInterfacesPrivateIPs [][]string networkInterfacesVpcIDs []string @@ -66,23 +70,23 @@ func NewFakeAWSServices(clusterID string) *FakeAWSServices { s.networkInterfacesMacs = []string{"aa:bb:cc:dd:ee:00", "aa:bb:cc:dd:ee:01"} s.networkInterfacesVpcIDs = []string{"vpc-mac0", "vpc-mac1"} - selfInstance := &ec2.Instance{} + selfInstance := &ec2types.Instance{} selfInstance.InstanceId = aws.String("i-self") - selfInstance.Placement = &ec2.Placement{ + selfInstance.Placement = &ec2types.Placement{ AvailabilityZone: aws.String("us-east-1a"), } selfInstance.PrivateDnsName = aws.String("ip-172-20-0-100.ec2.internal") selfInstance.PrivateIpAddress = aws.String("192.168.0.1") selfInstance.PublicIpAddress = aws.String("1.2.3.4") s.selfInstance = selfInstance - s.instances = []*ec2.Instance{selfInstance} + s.instances = []*ec2types.Instance{selfInstance} - selfInstance.NetworkInterfaces = []*ec2.InstanceNetworkInterface{ + selfInstance.NetworkInterfaces = []ec2types.InstanceNetworkInterface{ { - Attachment: &ec2.InstanceNetworkInterfaceAttachment{ - DeviceIndex: aws.Int64(1), + Attachment: &ec2types.InstanceNetworkInterfaceAttachment{ + DeviceIndex: aws.Int32(1), }, - PrivateIpAddresses: []*ec2.InstancePrivateIpAddress{ + PrivateIpAddresses: []ec2types.InstancePrivateIpAddress{ { Primary: aws.Bool(true), PrivateDnsName: aws.String("ip-172-20-1-100.ec2.internal"), @@ -94,13 +98,13 @@ func NewFakeAWSServices(clusterID string) *FakeAWSServices { PrivateIpAddress: aws.String("172.20.1.2"), }, }, - Status: aws.String(ec2.NetworkInterfaceStatusInUse), + Status: ec2types.NetworkInterfaceStatusInUse, }, { - Attachment: &ec2.InstanceNetworkInterfaceAttachment{ - DeviceIndex: aws.Int64(0), + Attachment: &ec2types.InstanceNetworkInterfaceAttachment{ + DeviceIndex: aws.Int32(0), }, - PrivateIpAddresses: []*ec2.InstancePrivateIpAddress{ + PrivateIpAddresses: []ec2types.InstancePrivateIpAddress{ { Primary: aws.Bool(true), PrivateDnsName: aws.String("ip-172-20-0-100.ec2.internal"), @@ -112,15 +116,14 @@ func NewFakeAWSServices(clusterID string) *FakeAWSServices { PrivateIpAddress: aws.String("172.20.0.101"), }, }, - Status: aws.String(ec2.NetworkInterfaceStatusInUse), + Status: ec2types.NetworkInterfaceStatusInUse, }, } - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterLegacy) tag.Value = aws.String(clusterID) - selfInstance.Tags = []*ec2.Tag{&tag} - + selfInstance.Tags = []ec2types.Tag{tag} s.callCounts = make(map[string]int) return s @@ -129,7 +132,7 @@ func NewFakeAWSServices(clusterID string) *FakeAWSServices { // WithAz sets the ec2 placement availability zone func (s *FakeAWSServices) WithAz(az string) *FakeAWSServices { if s.selfInstance.Placement == nil { - s.selfInstance.Placement = &ec2.Placement{} + s.selfInstance.Placement = &ec2types.Placement{} } s.selfInstance.Placement.AvailabilityZone = aws.String(az) return s @@ -151,56 +154,56 @@ func (s *FakeAWSServices) countCall(service string, api string, resourceID strin } // Compute returns a fake EC2 client -func (s *FakeAWSServices) Compute(region string) (EC2, error) { +func (s *FakeAWSServices) Compute(ctx context.Context, region string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (EC2, error) { return s.ec2, nil } // LoadBalancing returns a fake ELB client -func (s *FakeAWSServices) LoadBalancing(region string) (ELB, error) { +func (s *FakeAWSServices) LoadBalancing(ctx context.Context, region string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ELB, error) { return s.elb, nil } // LoadBalancingV2 returns a fake ELBV2 client -func (s *FakeAWSServices) LoadBalancingV2(region string) (ELBV2, error) { +func (s *FakeAWSServices) LoadBalancingV2(ctx context.Context, region string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ELBV2, error) { return s.elbv2, nil } // Autoscaling returns a fake ASG client -func (s *FakeAWSServices) Autoscaling(region string) (ASG, error) { +func (s *FakeAWSServices) Autoscaling(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (ASG, error) { return s.asg, nil } // Metadata returns a fake EC2Metadata client -func (s *FakeAWSServices) Metadata() (EC2Metadata, error) { +func (s *FakeAWSServices) Metadata(ctx context.Context) (EC2Metadata, error) { return s.metadata, nil } // KeyManagement returns a fake KMS client -func (s *FakeAWSServices) KeyManagement(region string) (KMS, error) { +func (s *FakeAWSServices) KeyManagement(ctx context.Context, regionName string, assumeRoleProvider *stscredsv2.AssumeRoleProvider) (KMS, error) { return s.kms, nil } // FakeEC2 is a fake EC2 client used for testing type FakeEC2 interface { EC2 - CreateSubnet(*ec2.Subnet) (*ec2.CreateSubnetOutput, error) + CreateSubnet(*ec2types.Subnet) (*ec2.CreateSubnetOutput, error) RemoveSubnets() - CreateRouteTable(*ec2.RouteTable) (*ec2.CreateRouteTableOutput, error) + CreateRouteTable(*ec2types.RouteTable) (*ec2.CreateRouteTableOutput, error) RemoveRouteTables() } // FakeEC2Impl is an implementation of the FakeEC2 interface used for testing type FakeEC2Impl struct { aws *FakeAWSServices - Subnets []*ec2.Subnet + Subnets []ec2types.Subnet DescribeSubnetsInput *ec2.DescribeSubnetsInput - RouteTables []*ec2.RouteTable + RouteTables []ec2types.RouteTable DescribeRouteTablesInput *ec2.DescribeRouteTablesInput } // DescribeInstances returns fake instance descriptions -func (ec2i *FakeEC2Impl) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) { - matches := []*ec2.Instance{} +func (ec2i *FakeEC2Impl) DescribeInstances(ctx context.Context, request *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) ([]ec2types.Instance, error) { + matches := []ec2types.Instance{} for _, instance := range ec2i.aws.instances { if request.InstanceIds != nil { if instance.InstanceId == nil { @@ -210,7 +213,7 @@ func (ec2i *FakeEC2Impl) DescribeInstances(request *ec2.DescribeInstancesInput) found := false for _, instanceID := range request.InstanceIds { - if *instanceID == *instance.InstanceId { + if instanceID == *instance.InstanceId { found = true break } @@ -231,81 +234,86 @@ func (ec2i *FakeEC2Impl) DescribeInstances(request *ec2.DescribeInstancesInput) continue } } - matches = append(matches, instance) + matches = append(matches, *instance) } return matches, nil } +// DescribeInstanceTopology is not implemented but is required for interface conformance +func (ec2i *FakeEC2Impl) DescribeInstanceTopology(ctx context.Context, request *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) ([]ec2types.InstanceTopology, error) { + panic("Not implemented") +} + // AttachVolume is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) AttachVolume(request *ec2.AttachVolumeInput) (resp *ec2.VolumeAttachment, err error) { +func (ec2i *FakeEC2Impl) AttachVolume(ctx context.Context, request *ec2.AttachVolumeInput, optFns ...func(*ec2.Options)) (resp *ec2.AttachVolumeOutput, err error) { panic("Not implemented") } // DetachVolume is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) DetachVolume(request *ec2.DetachVolumeInput) (resp *ec2.VolumeAttachment, err error) { +func (ec2i *FakeEC2Impl) DetachVolume(ctx context.Context, request *ec2.DetachVolumeInput, optFns ...func(*ec2.Options)) (resp *ec2.DetachVolumeOutput, err error) { panic("Not implemented") } // DescribeVolumes is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) DescribeVolumes(request *ec2.DescribeVolumesInput) ([]*ec2.Volume, error) { +func (ec2i *FakeEC2Impl) DescribeVolumes(ctx context.Context, request *ec2.DescribeVolumesInput, optFns ...func(*ec2.Options)) ([]ec2types.Volume, error) { panic("Not implemented") } // CreateVolume is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) CreateVolume(request *ec2.CreateVolumeInput) (resp *ec2.Volume, err error) { +func (ec2i *FakeEC2Impl) CreateVolume(ctx context.Context, request *ec2.CreateVolumeInput, optFns ...func(*ec2.Options)) (resp *ec2.CreateVolumeOutput, err error) { panic("Not implemented") } // DeleteVolume is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) DeleteVolume(request *ec2.DeleteVolumeInput) (resp *ec2.DeleteVolumeOutput, err error) { +func (ec2i *FakeEC2Impl) DeleteVolume(ctx context.Context, request *ec2.DeleteVolumeInput, optFns ...func(*ec2.Options)) (resp *ec2.DeleteVolumeOutput, err error) { panic("Not implemented") } // DescribeSecurityGroups is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { +func (ec2i *FakeEC2Impl) DescribeSecurityGroups(ctx context.Context, request *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) ([]ec2types.SecurityGroup, error) { panic("Not implemented") } // CreateSecurityGroup is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) CreateSecurityGroup(*ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) { +func (ec2i *FakeEC2Impl) CreateSecurityGroup(ctx context.Context, request *ec2.CreateSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error) { panic("Not implemented") } // DeleteSecurityGroup is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) DeleteSecurityGroup(*ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) { +func (ec2i *FakeEC2Impl) DeleteSecurityGroup(ctx context.Context, request *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error) { panic("Not implemented") } // AuthorizeSecurityGroupIngress is not implemented but is required for // interface conformance -func (ec2i *FakeEC2Impl) AuthorizeSecurityGroupIngress(*ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { +func (ec2i *FakeEC2Impl) AuthorizeSecurityGroupIngress(ctx context.Context, request *ec2.AuthorizeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { panic("Not implemented") } // RevokeSecurityGroupIngress is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) RevokeSecurityGroupIngress(*ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) { +func (ec2i *FakeEC2Impl) RevokeSecurityGroupIngress(ctx context.Context, request *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error) { panic("Not implemented") } // DescribeVolumeModifications is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) DescribeVolumeModifications(*ec2.DescribeVolumesModificationsInput) ([]*ec2.VolumeModification, error) { +func (ec2i *FakeEC2Impl) DescribeVolumeModifications(ctx context.Context, request *ec2.DescribeVolumesModificationsInput, optFns ...func(*ec2.Options)) ([]ec2types.VolumeModification, error) { panic("Not implemented") } // ModifyVolume is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) ModifyVolume(*ec2.ModifyVolumeInput) (*ec2.ModifyVolumeOutput, error) { +func (ec2i *FakeEC2Impl) ModifyVolume(ctx context.Context, request *ec2.ModifyVolumeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyVolumeOutput, error) { panic("Not implemented") } // CreateSubnet creates fake subnets -func (ec2i *FakeEC2Impl) CreateSubnet(request *ec2.Subnet) (*ec2.CreateSubnetOutput, error) { - ec2i.Subnets = append(ec2i.Subnets, request) +func (ec2i *FakeEC2Impl) CreateSubnet(request *ec2types.Subnet) (*ec2.CreateSubnetOutput, error) { + ec2i.Subnets = append(ec2i.Subnets, *request) response := &ec2.CreateSubnetOutput{ Subnet: request, } @@ -313,7 +321,7 @@ func (ec2i *FakeEC2Impl) CreateSubnet(request *ec2.Subnet) (*ec2.CreateSubnetOut } // DescribeSubnets returns fake subnet descriptions -func (ec2i *FakeEC2Impl) DescribeSubnets(request *ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) { +func (ec2i *FakeEC2Impl) DescribeSubnets(ctx context.Context, request *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) ([]ec2types.Subnet, error) { ec2i.DescribeSubnetsInput = request return ec2i.Subnets, nil } @@ -325,8 +333,8 @@ func (ec2i *FakeEC2Impl) RemoveSubnets() { // DescribeAvailabilityZones returns fake availability zones // For every input returns a hardcoded list of fake availability zones for the moment -func (ec2i *FakeEC2Impl) DescribeAvailabilityZones(request *ec2.DescribeAvailabilityZonesInput) ([]*ec2.AvailabilityZone, error) { - var azs []*ec2.AvailabilityZone +func (ec2i *FakeEC2Impl) DescribeAvailabilityZones(ctx context.Context, request *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) ([]ec2types.AvailabilityZone, error) { + var azs []ec2types.AvailabilityZone fakeZones := [5]string{"az-local", "az-wavelength", "us-west-2a", "us-west-2b", "us-west-2c"} for _, name := range fakeZones { @@ -339,32 +347,32 @@ func (ec2i *FakeEC2Impl) DescribeAvailabilityZones(request *ec2.DescribeAvailabi default: zoneType = aws.String(regularAvailabilityZoneType) } - zone := &ec2.AvailabilityZone{ZoneName: aws.String(name), ZoneType: zoneType, ZoneId: aws.String(name)} + zone := ec2types.AvailabilityZone{ZoneName: aws.String(name), ZoneType: zoneType, ZoneId: aws.String(name)} azs = append(azs, zone) } return azs, nil } // CreateTags is a mock for CreateTags from EC2 -func (ec2i *FakeEC2Impl) CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) { +func (ec2i *FakeEC2Impl) CreateTags(ctx context.Context, input *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) { for _, id := range input.Resources { - callCount := ec2i.aws.countCall("ec2", "CreateTags", *id) - if *id == "i-error" { + callCount := ec2i.aws.countCall("ec2", "CreateTags", id) + if id == "i-error" { return nil, errors.New("Unable to tag") } - if *id == "i-not-found" { - return nil, awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil) + if id == "i-not-found" { + return nil, errors.New("InvalidInstanceID.NotFound: Instance not found") } // return an Instance not found error for the first `n` calls // instance ID should be of the format `i-not-found-count-$N-$SUFFIX` - if strings.HasPrefix(*id, "i-not-found-count-") { - notFoundCount, err := strconv.Atoi(strings.Split(*id, "-")[4]) + if strings.HasPrefix(id, "i-not-found-count-") { + notFoundCount, err := strconv.Atoi(strings.Split(id, "-")[4]) if err != nil { panic(err) } if callCount < notFoundCount { - return nil, awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil) + return nil, errors.New("InvalidInstanceID.NotFound: Instance not found") } } } @@ -372,28 +380,28 @@ func (ec2i *FakeEC2Impl) CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTags } // DeleteTags is a mock for DeleteTags from EC2 -func (ec2i *FakeEC2Impl) DeleteTags(input *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) { +func (ec2i *FakeEC2Impl) DeleteTags(ctx context.Context, input *ec2.DeleteTagsInput, optFns ...func(*ec2.Options)) (*ec2.DeleteTagsOutput, error) { for _, id := range input.Resources { - if *id == "i-error" { + if id == "i-error" { return nil, errors.New("Unable to remove tag") } - if *id == "i-not-found" { - return nil, awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil) + if id == "i-not-found" { + return nil, errors.New("InvalidInstanceID.NotFound: Instance not found") } } return &ec2.DeleteTagsOutput{}, nil } // DescribeRouteTables returns fake route table descriptions -func (ec2i *FakeEC2Impl) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ([]*ec2.RouteTable, error) { +func (ec2i *FakeEC2Impl) DescribeRouteTables(ctx context.Context, request *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) ([]ec2types.RouteTable, error) { ec2i.DescribeRouteTablesInput = request return ec2i.RouteTables, nil } // CreateRouteTable creates fake route tables -func (ec2i *FakeEC2Impl) CreateRouteTable(request *ec2.RouteTable) (*ec2.CreateRouteTableOutput, error) { - ec2i.RouteTables = append(ec2i.RouteTables, request) +func (ec2i *FakeEC2Impl) CreateRouteTable(request *ec2types.RouteTable) (*ec2.CreateRouteTableOutput, error) { + ec2i.RouteTables = append(ec2i.RouteTables, *request) response := &ec2.CreateRouteTableOutput{ RouteTable: request, } @@ -406,24 +414,24 @@ func (ec2i *FakeEC2Impl) RemoveRouteTables() { } // CreateRoute is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) CreateRoute(request *ec2.CreateRouteInput) (*ec2.CreateRouteOutput, error) { +func (ec2i *FakeEC2Impl) CreateRoute(ctx context.Context, request *ec2.CreateRouteInput, optFns ...func(*ec2.Options)) (*ec2.CreateRouteOutput, error) { panic("Not implemented") } // DeleteRoute is not implemented but is required for interface conformance -func (ec2i *FakeEC2Impl) DeleteRoute(request *ec2.DeleteRouteInput) (*ec2.DeleteRouteOutput, error) { +func (ec2i *FakeEC2Impl) DeleteRoute(ctx context.Context, request *ec2.DeleteRouteInput, optFns ...func(*ec2.Options)) (*ec2.DeleteRouteOutput, error) { panic("Not implemented") } // ModifyInstanceAttribute is not implemented but is required for interface // conformance -func (ec2i *FakeEC2Impl) ModifyInstanceAttribute(request *ec2.ModifyInstanceAttributeInput) (*ec2.ModifyInstanceAttributeOutput, error) { +func (ec2i *FakeEC2Impl) ModifyInstanceAttribute(ctx context.Context, request *ec2.ModifyInstanceAttributeInput, optFns ...func(*ec2.Options)) (*ec2.ModifyInstanceAttributeOutput, error) { panic("Not implemented") } // DescribeVpcs returns fake VPC descriptions -func (ec2i *FakeEC2Impl) DescribeVpcs(request *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { - return &ec2.DescribeVpcsOutput{Vpcs: []*ec2.Vpc{{CidrBlock: aws.String("172.20.0.0/16")}}}, nil +func (ec2i *FakeEC2Impl) DescribeVpcs(ctx context.Context, request *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) { + return &ec2.DescribeVpcsOutput{Vpcs: []ec2types.Vpc{{CidrBlock: aws.String("172.20.0.0/16")}}}, nil } // FakeMetadata is a fake EC2 metadata service client used for testing @@ -432,25 +440,26 @@ type FakeMetadata struct { } // GetMetadata returns fake EC2 metadata for testing -func (m *FakeMetadata) GetMetadata(key string) (string, error) { +func (m *FakeMetadata) GetMetadata(ctx context.Context, input *imds.GetMetadataInput, optFns ...func(*imds.Options)) (*imds.GetMetadataOutput, error) { + key := input.Path networkInterfacesPrefix := "network/interfaces/macs/" i := m.aws.selfInstance if key == "placement/availability-zone" { az := "" if i.Placement != nil { - az = aws.StringValue(i.Placement.AvailabilityZone) + az = aws.ToString(i.Placement.AvailabilityZone) } - return az, nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(az))}, nil } else if key == "instance-id" { - return aws.StringValue(i.InstanceId), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.InstanceId))}, nil } else if key == "local-hostname" { - return aws.StringValue(i.PrivateDnsName), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.PrivateDnsName))}, nil } else if key == "public-hostname" { - return aws.StringValue(i.PublicDnsName), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.PublicDnsName))}, nil } else if key == "local-ipv4" { - return aws.StringValue(i.PrivateIpAddress), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.PrivateIpAddress))}, nil } else if key == "public-ipv4" { - return aws.StringValue(i.PublicIpAddress), nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(*i.PublicIpAddress))}, nil } else if strings.HasPrefix(key, networkInterfacesPrefix) { if key == networkInterfacesPrefix { // Return the MACs sorted lexically rather than in device-number @@ -459,7 +468,8 @@ func (m *FakeMetadata) GetMetadata(key string) (string, error) { macs := make([]string, len(m.aws.networkInterfacesMacs)) copy(macs, m.aws.networkInterfacesMacs) sort.Strings(macs) - return strings.Join(macs, "/\n") + "/\n", nil + value := strings.Join(macs, "/\n") + "/\n" + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(value))}, nil } keySplit := strings.Split(key, "/") @@ -467,7 +477,7 @@ func (m *FakeMetadata) GetMetadata(key string) (string, error) { if len(keySplit) == 5 && keySplit[4] == "vpc-id" { for i, macElem := range m.aws.networkInterfacesMacs { if macParam == macElem { - return m.aws.networkInterfacesVpcIDs[i], nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(m.aws.networkInterfacesVpcIDs[i]))}, nil } } } @@ -479,27 +489,29 @@ func (m *FakeMetadata) GetMetadata(key string) (string, error) { // Introduce an artificial gap, just to test eg: [eth0, eth2] n++ } - return fmt.Sprintf("%d\n", n), nil + value := fmt.Sprintf("%d\n", n) + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(value))}, nil } } } if len(keySplit) == 5 && keySplit[4] == "local-ipv4s" { for i, macElem := range m.aws.networkInterfacesMacs { if macParam == macElem { - return strings.Join(m.aws.networkInterfacesPrivateIPs[i], "/\n"), nil + value := strings.Join(m.aws.networkInterfacesPrivateIPs[i], "/\n") + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(value))}, nil } } } - return "", nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(""))}, nil } - return "", nil + return &imds.GetMetadataOutput{Content: io.NopCloser(strings.NewReader(""))}, nil } -// Region returns AWS region -func (m *FakeMetadata) Region() (string, error) { - return m.aws.region, nil +// GetRegion returns AWS region +func (m *FakeMetadata) GetRegion(ctx context.Context, params *imds.GetRegionInput, optFns ...func(*imds.Options)) (*imds.GetRegionOutput, error) { + return &imds.GetRegionOutput{Region: m.aws.region}, nil } // FakeELB is a fake ELB client used for testing @@ -509,108 +521,108 @@ type FakeELB struct { // CreateLoadBalancer is not implemented but is required for interface // conformance -func (elb *FakeELB) CreateLoadBalancer(*elb.CreateLoadBalancerInput) (*elb.CreateLoadBalancerOutput, error) { +func (elb *FakeELB) CreateLoadBalancer(ctx context.Context, input *elb.CreateLoadBalancerInput, opts ...func(*elb.Options)) (*elb.CreateLoadBalancerOutput, error) { panic("Not implemented") } // DeleteLoadBalancer is not implemented but is required for interface // conformance -func (elb *FakeELB) DeleteLoadBalancer(input *elb.DeleteLoadBalancerInput) (*elb.DeleteLoadBalancerOutput, error) { +func (elb *FakeELB) DeleteLoadBalancer(ctx context.Context, input *elb.DeleteLoadBalancerInput, opts ...func(*elb.Options)) (*elb.DeleteLoadBalancerOutput, error) { panic("Not implemented") } // DescribeLoadBalancers is not implemented but is required for interface // conformance -func (elb *FakeELB) DescribeLoadBalancers(input *elb.DescribeLoadBalancersInput) (*elb.DescribeLoadBalancersOutput, error) { +func (elb *FakeELB) DescribeLoadBalancers(ctx context.Context, input *elb.DescribeLoadBalancersInput, opts ...func(*elb.Options)) (*elb.DescribeLoadBalancersOutput, error) { panic("Not implemented") } // AddTags is not implemented but is required for interface conformance -func (elb *FakeELB) AddTags(input *elb.AddTagsInput) (*elb.AddTagsOutput, error) { +func (elb *FakeELB) AddTags(ctx context.Context, input *elb.AddTagsInput, opts ...func(*elb.Options)) (*elb.AddTagsOutput, error) { panic("Not implemented") } // RegisterInstancesWithLoadBalancer is not implemented but is required for // interface conformance -func (elb *FakeELB) RegisterInstancesWithLoadBalancer(*elb.RegisterInstancesWithLoadBalancerInput) (*elb.RegisterInstancesWithLoadBalancerOutput, error) { +func (elb *FakeELB) RegisterInstancesWithLoadBalancer(ctx context.Context, input *elb.RegisterInstancesWithLoadBalancerInput, opts ...func(*elb.Options)) (*elb.RegisterInstancesWithLoadBalancerOutput, error) { panic("Not implemented") } // DeregisterInstancesFromLoadBalancer is not implemented but is required for // interface conformance -func (elb *FakeELB) DeregisterInstancesFromLoadBalancer(*elb.DeregisterInstancesFromLoadBalancerInput) (*elb.DeregisterInstancesFromLoadBalancerOutput, error) { +func (elb *FakeELB) DeregisterInstancesFromLoadBalancer(ctx context.Context, input *elb.DeregisterInstancesFromLoadBalancerInput, opts ...func(*elb.Options)) (*elb.DeregisterInstancesFromLoadBalancerOutput, error) { panic("Not implemented") } // DetachLoadBalancerFromSubnets is not implemented but is required for // interface conformance -func (elb *FakeELB) DetachLoadBalancerFromSubnets(*elb.DetachLoadBalancerFromSubnetsInput) (*elb.DetachLoadBalancerFromSubnetsOutput, error) { +func (elb *FakeELB) DetachLoadBalancerFromSubnets(ctx context.Context, input *elb.DetachLoadBalancerFromSubnetsInput, opts ...func(*elb.Options)) (*elb.DetachLoadBalancerFromSubnetsOutput, error) { panic("Not implemented") } // AttachLoadBalancerToSubnets is not implemented but is required for interface // conformance -func (elb *FakeELB) AttachLoadBalancerToSubnets(*elb.AttachLoadBalancerToSubnetsInput) (*elb.AttachLoadBalancerToSubnetsOutput, error) { +func (elb *FakeELB) AttachLoadBalancerToSubnets(ctx context.Context, input *elb.AttachLoadBalancerToSubnetsInput, opts ...func(*elb.Options)) (*elb.AttachLoadBalancerToSubnetsOutput, error) { panic("Not implemented") } // CreateLoadBalancerListeners is not implemented but is required for interface // conformance -func (elb *FakeELB) CreateLoadBalancerListeners(*elb.CreateLoadBalancerListenersInput) (*elb.CreateLoadBalancerListenersOutput, error) { +func (elb *FakeELB) CreateLoadBalancerListeners(ctx context.Context, input *elb.CreateLoadBalancerListenersInput, opts ...func(*elb.Options)) (*elb.CreateLoadBalancerListenersOutput, error) { panic("Not implemented") } // DeleteLoadBalancerListeners is not implemented but is required for interface // conformance -func (elb *FakeELB) DeleteLoadBalancerListeners(*elb.DeleteLoadBalancerListenersInput) (*elb.DeleteLoadBalancerListenersOutput, error) { +func (elb *FakeELB) DeleteLoadBalancerListeners(ctx context.Context, input *elb.DeleteLoadBalancerListenersInput, opts ...func(*elb.Options)) (*elb.DeleteLoadBalancerListenersOutput, error) { panic("Not implemented") } // ApplySecurityGroupsToLoadBalancer is not implemented but is required for // interface conformance -func (elb *FakeELB) ApplySecurityGroupsToLoadBalancer(*elb.ApplySecurityGroupsToLoadBalancerInput) (*elb.ApplySecurityGroupsToLoadBalancerOutput, error) { +func (elb *FakeELB) ApplySecurityGroupsToLoadBalancer(ctx context.Context, input *elb.ApplySecurityGroupsToLoadBalancerInput, opts ...func(*elb.Options)) (*elb.ApplySecurityGroupsToLoadBalancerOutput, error) { panic("Not implemented") } // ConfigureHealthCheck is not implemented but is required for interface // conformance -func (elb *FakeELB) ConfigureHealthCheck(*elb.ConfigureHealthCheckInput) (*elb.ConfigureHealthCheckOutput, error) { +func (elb *FakeELB) ConfigureHealthCheck(ctx context.Context, input *elb.ConfigureHealthCheckInput, opts ...func(*elb.Options)) (*elb.ConfigureHealthCheckOutput, error) { panic("Not implemented") } // CreateLoadBalancerPolicy is not implemented but is required for interface // conformance -func (elb *FakeELB) CreateLoadBalancerPolicy(*elb.CreateLoadBalancerPolicyInput) (*elb.CreateLoadBalancerPolicyOutput, error) { +func (elb *FakeELB) CreateLoadBalancerPolicy(ctx context.Context, input *elb.CreateLoadBalancerPolicyInput, opts ...func(*elb.Options)) (*elb.CreateLoadBalancerPolicyOutput, error) { panic("Not implemented") } // SetLoadBalancerPoliciesForBackendServer is not implemented but is required // for interface conformance -func (elb *FakeELB) SetLoadBalancerPoliciesForBackendServer(*elb.SetLoadBalancerPoliciesForBackendServerInput) (*elb.SetLoadBalancerPoliciesForBackendServerOutput, error) { +func (elb *FakeELB) SetLoadBalancerPoliciesForBackendServer(ctx context.Context, input *elb.SetLoadBalancerPoliciesForBackendServerInput, opts ...func(*elb.Options)) (*elb.SetLoadBalancerPoliciesForBackendServerOutput, error) { panic("Not implemented") } // SetLoadBalancerPoliciesOfListener is not implemented but is required for // interface conformance -func (elb *FakeELB) SetLoadBalancerPoliciesOfListener(input *elb.SetLoadBalancerPoliciesOfListenerInput) (*elb.SetLoadBalancerPoliciesOfListenerOutput, error) { +func (elb *FakeELB) SetLoadBalancerPoliciesOfListener(ctx context.Context, input *elb.SetLoadBalancerPoliciesOfListenerInput, opts ...func(*elb.Options)) (*elb.SetLoadBalancerPoliciesOfListenerOutput, error) { panic("Not implemented") } // DescribeLoadBalancerPolicies is not implemented but is required for // interface conformance -func (elb *FakeELB) DescribeLoadBalancerPolicies(input *elb.DescribeLoadBalancerPoliciesInput) (*elb.DescribeLoadBalancerPoliciesOutput, error) { +func (elb *FakeELB) DescribeLoadBalancerPolicies(ctx context.Context, input *elb.DescribeLoadBalancerPoliciesInput, opts ...func(*elb.Options)) (*elb.DescribeLoadBalancerPoliciesOutput, error) { panic("Not implemented") } // DescribeLoadBalancerAttributes is not implemented but is required for // interface conformance -func (elb *FakeELB) DescribeLoadBalancerAttributes(*elb.DescribeLoadBalancerAttributesInput) (*elb.DescribeLoadBalancerAttributesOutput, error) { +func (elb *FakeELB) DescribeLoadBalancerAttributes(ctx context.Context, input *elb.DescribeLoadBalancerAttributesInput, opts ...func(*elb.Options)) (*elb.DescribeLoadBalancerAttributesOutput, error) { panic("Not implemented") } // ModifyLoadBalancerAttributes is not implemented but is required for // interface conformance -func (elb *FakeELB) ModifyLoadBalancerAttributes(*elb.ModifyLoadBalancerAttributesInput) (*elb.ModifyLoadBalancerAttributesOutput, error) { +func (elb *FakeELB) ModifyLoadBalancerAttributes(ctx context.Context, input *elb.ModifyLoadBalancerAttributesInput, opts ...func(*elb.Options)) (*elb.ModifyLoadBalancerAttributesOutput, error) { panic("Not implemented") } @@ -620,117 +632,97 @@ type FakeELBV2 struct { } // AddTags is not implemented but is required for interface conformance -func (elb *FakeELBV2) AddTags(input *elbv2.AddTagsInput) (*elbv2.AddTagsOutput, error) { +func (elb *FakeELBV2) AddTags(ctx context.Context, input *elbv2.AddTagsInput, optFns ...func(*elbv2.Options)) (*elbv2.AddTagsOutput, error) { panic("Not implemented") } -// CreateLoadBalancer is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) CreateLoadBalancer(*elbv2.CreateLoadBalancerInput) (*elbv2.CreateLoadBalancerOutput, error) { +// CreateLoadBalancer is not implemented but is required for interface conformance +func (elb *FakeELBV2) CreateLoadBalancer(ctx context.Context, input *elbv2.CreateLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateLoadBalancerOutput, error) { panic("Not implemented") } -// DescribeLoadBalancers is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DescribeLoadBalancers(*elbv2.DescribeLoadBalancersInput) (*elbv2.DescribeLoadBalancersOutput, error) { +// DescribeLoadBalancers is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeLoadBalancers(ctx context.Context, input *elbv2.DescribeLoadBalancersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancersOutput, error) { panic("Not implemented") } -// DeleteLoadBalancer is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DeleteLoadBalancer(*elbv2.DeleteLoadBalancerInput) (*elbv2.DeleteLoadBalancerOutput, error) { +// DeleteLoadBalancer is not implemented but is required for interface conformance +func (elb *FakeELBV2) DeleteLoadBalancer(ctx context.Context, input *elbv2.DeleteLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteLoadBalancerOutput, error) { panic("Not implemented") } -// ModifyLoadBalancerAttributes is not implemented but is required for -// interface conformance -func (elb *FakeELBV2) ModifyLoadBalancerAttributes(*elbv2.ModifyLoadBalancerAttributesInput) (*elbv2.ModifyLoadBalancerAttributesOutput, error) { +// ModifyLoadBalancerAttributes is not implemented but is required for interface conformance +func (elb *FakeELBV2) ModifyLoadBalancerAttributes(ctx context.Context, input *elbv2.ModifyLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyLoadBalancerAttributesOutput, error) { panic("Not implemented") } -// DescribeLoadBalancerAttributes is not implemented but is required for -// interface conformance -func (elb *FakeELBV2) DescribeLoadBalancerAttributes(*elbv2.DescribeLoadBalancerAttributesInput) (*elbv2.DescribeLoadBalancerAttributesOutput, error) { +// DescribeLoadBalancerAttributes is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeLoadBalancerAttributes(ctx context.Context, input *elbv2.DescribeLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancerAttributesOutput, error) { panic("Not implemented") } -// CreateTargetGroup is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) CreateTargetGroup(*elbv2.CreateTargetGroupInput) (*elbv2.CreateTargetGroupOutput, error) { +// CreateTargetGroup is not implemented but is required for interface conformance +func (elb *FakeELBV2) CreateTargetGroup(ctx context.Context, input *elbv2.CreateTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateTargetGroupOutput, error) { panic("Not implemented") } -// DescribeTargetGroups is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DescribeTargetGroups(*elbv2.DescribeTargetGroupsInput) (*elbv2.DescribeTargetGroupsOutput, error) { +// DescribeTargetGroups is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeTargetGroups(ctx context.Context, input *elbv2.DescribeTargetGroupsInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupsOutput, error) { panic("Not implemented") } -// ModifyTargetGroup is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) ModifyTargetGroup(*elbv2.ModifyTargetGroupInput) (*elbv2.ModifyTargetGroupOutput, error) { +// ModifyTargetGroup is not implemented but is required for interface conformance +func (elb *FakeELBV2) ModifyTargetGroup(ctx context.Context, input *elbv2.ModifyTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupOutput, error) { panic("Not implemented") } -// DeleteTargetGroup is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DeleteTargetGroup(*elbv2.DeleteTargetGroupInput) (*elbv2.DeleteTargetGroupOutput, error) { +// DeleteTargetGroup is not implemented but is required for interface conformance +func (elb *FakeELBV2) DeleteTargetGroup(ctx context.Context, input *elbv2.DeleteTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteTargetGroupOutput, error) { panic("Not implemented") } -// DescribeTargetHealth is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DescribeTargetHealth(input *elbv2.DescribeTargetHealthInput) (*elbv2.DescribeTargetHealthOutput, error) { +// DescribeTargetHealth is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeTargetHealth(ctx context.Context, input *elbv2.DescribeTargetHealthInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) { panic("Not implemented") } -// DescribeTargetGroupAttributes is not implemented but is required for -// interface conformance -func (elb *FakeELBV2) DescribeTargetGroupAttributes(*elbv2.DescribeTargetGroupAttributesInput) (*elbv2.DescribeTargetGroupAttributesOutput, error) { +// DescribeTargetGroupAttributes is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeTargetGroupAttributes(ctx context.Context, input *elbv2.DescribeTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupAttributesOutput, error) { panic("Not implemented") } -// ModifyTargetGroupAttributes is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) ModifyTargetGroupAttributes(*elbv2.ModifyTargetGroupAttributesInput) (*elbv2.ModifyTargetGroupAttributesOutput, error) { +// ModifyTargetGroupAttributes is not implemented but is required for interface conformance +func (elb *FakeELBV2) ModifyTargetGroupAttributes(ctx context.Context, input *elbv2.ModifyTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupAttributesOutput, error) { panic("Not implemented") } // RegisterTargets is not implemented but is required for interface conformance -func (elb *FakeELBV2) RegisterTargets(*elbv2.RegisterTargetsInput) (*elbv2.RegisterTargetsOutput, error) { +func (elb *FakeELBV2) RegisterTargets(ctx context.Context, input *elbv2.RegisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) { panic("Not implemented") } -// DeregisterTargets is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DeregisterTargets(*elbv2.DeregisterTargetsInput) (*elbv2.DeregisterTargetsOutput, error) { +// DeregisterTargets is not implemented but is required for interface conformance +func (elb *FakeELBV2) DeregisterTargets(ctx context.Context, input *elbv2.DeregisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) { panic("Not implemented") } // CreateListener is not implemented but is required for interface conformance -func (elb *FakeELBV2) CreateListener(*elbv2.CreateListenerInput) (*elbv2.CreateListenerOutput, error) { +func (elb *FakeELBV2) CreateListener(ctx context.Context, input *elbv2.CreateListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateListenerOutput, error) { panic("Not implemented") } -// DescribeListeners is not implemented but is required for interface -// conformance -func (elb *FakeELBV2) DescribeListeners(*elbv2.DescribeListenersInput) (*elbv2.DescribeListenersOutput, error) { +// DescribeListeners is not implemented but is required for interface conformance +func (elb *FakeELBV2) DescribeListeners(ctx context.Context, input *elbv2.DescribeListenersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeListenersOutput, error) { panic("Not implemented") } // DeleteListener is not implemented but is required for interface conformance -func (elb *FakeELBV2) DeleteListener(*elbv2.DeleteListenerInput) (*elbv2.DeleteListenerOutput, error) { +func (elb *FakeELBV2) DeleteListener(ctx context.Context, input *elbv2.DeleteListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteListenerOutput, error) { panic("Not implemented") } // ModifyListener is not implemented but is required for interface conformance -func (elb *FakeELBV2) ModifyListener(*elbv2.ModifyListenerInput) (*elbv2.ModifyListenerOutput, error) { - panic("Not implemented") -} - -// WaitUntilLoadBalancersDeleted is not implemented but is required for -// interface conformance -func (elb *FakeELBV2) WaitUntilLoadBalancersDeleted(*elbv2.DescribeLoadBalancersInput) error { +func (elb *FakeELBV2) ModifyListener(ctx context.Context, input *elbv2.ModifyListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyListenerOutput, error) { panic("Not implemented") } @@ -741,13 +733,13 @@ type FakeASG struct { // UpdateAutoScalingGroup is not implemented but is required for interface // conformance -func (a *FakeASG) UpdateAutoScalingGroup(*autoscaling.UpdateAutoScalingGroupInput) (*autoscaling.UpdateAutoScalingGroupOutput, error) { +func (a *FakeASG) UpdateAutoScalingGroup(ctx context.Context, input *autoscaling.UpdateAutoScalingGroupInput, optFns ...func(*autoscaling.Options)) (*autoscaling.UpdateAutoScalingGroupOutput, error) { panic("Not implemented") } // DescribeAutoScalingGroups is not implemented but is required for interface // conformance -func (a *FakeASG) DescribeAutoScalingGroups(*autoscaling.DescribeAutoScalingGroupsInput) (*autoscaling.DescribeAutoScalingGroupsOutput, error) { +func (a *FakeASG) DescribeAutoScalingGroups(ctx context.Context, input *autoscaling.DescribeAutoScalingGroupsInput, optFns ...func(*autoscaling.Options)) (*autoscaling.DescribeAutoScalingGroupsOutput, error) { panic("Not implemented") } @@ -757,26 +749,26 @@ type FakeKMS struct { } // DescribeKey is not implemented but is required for interface conformance -func (kms *FakeKMS) DescribeKey(*kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) { +func (kms *FakeKMS) DescribeKey(ctx context.Context, input *kms.DescribeKeyInput, optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { panic("Not implemented") } -func instanceMatchesFilter(instance *ec2.Instance, filter *ec2.Filter) bool { +func instanceMatchesFilter(instance *ec2types.Instance, filter ec2types.Filter) bool { name := *filter.Name if name == "private-dns-name" { if instance.PrivateDnsName == nil { return false } - return contains(filter.Values, *instance.PrivateDnsName) + return contains(filter.Values, aws.ToString(instance.PrivateDnsName)) } if name == "instance-state-name" { - return contains(filter.Values, *instance.State.Name) + return contains(filter.Values, string(instance.State.Name)) } if name == "tag-key" { for _, instanceTag := range instance.Tags { - if contains(filter.Values, aws.StringValue(instanceTag.Key)) { + if contains(filter.Values, aws.ToString(instanceTag.Key)) { return true } } @@ -786,7 +778,7 @@ func instanceMatchesFilter(instance *ec2.Instance, filter *ec2.Filter) bool { if strings.HasPrefix(name, "tag:") { tagName := name[4:] for _, instanceTag := range instance.Tags { - if aws.StringValue(instanceTag.Key) == tagName && contains(filter.Values, aws.StringValue(instanceTag.Value)) { + if aws.ToString(instanceTag.Key) == tagName && contains(filter.Values, aws.ToString(instanceTag.Value)) { return true } } @@ -796,10 +788,10 @@ func instanceMatchesFilter(instance *ec2.Instance, filter *ec2.Filter) bool { panic("Unknown filter name: " + name) } -func contains(haystack []*string, needle string) bool { +func contains(haystack []string, needle string) bool { for _, s := range haystack { // (deliberately panic if s == nil) - if needle == *s { + if needle == s { return true } } @@ -807,28 +799,28 @@ func contains(haystack []*string, needle string) bool { } // DescribeNetworkInterfaces returns list of ENIs for testing -func (ec2i *FakeEC2Impl) DescribeNetworkInterfaces(input *ec2.DescribeNetworkInterfacesInput) (*ec2.DescribeNetworkInterfacesOutput, error) { - networkInterface := []*ec2.NetworkInterface{ +func (ec2i *FakeEC2Impl) DescribeNetworkInterfaces(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) { + networkInterface := []ec2types.NetworkInterface{ { PrivateIpAddress: aws.String("1.2.3.4"), AvailabilityZone: aws.String("us-west-2c"), }, } for _, filter := range input.Filters { - if strings.HasPrefix(*filter.Values[0], fargateNodeNamePrefix) { + if strings.HasPrefix(filter.Values[0], fargateNodeNamePrefix) { // verify filter doesn't have fargate prefix - panic(fmt.Sprintf("invalid endpoint specified for DescribeNetworkInterface call %s", *filter.Values[0])) - } else if strings.HasPrefix(*filter.Values[0], "not-found") { + panic(fmt.Sprintf("invalid endpoint specified for DescribeNetworkInterface call %s", filter.Values[0])) + } else if strings.HasPrefix(filter.Values[0], "not-found") { // for negative testing return &ec2.DescribeNetworkInterfacesOutput{}, nil } - if strings.Contains(*filter.Values[0], "return.private.dns.name") { + if strings.Contains(filter.Values[0], "return.private.dns.name") { networkInterface[0].PrivateDnsName = aws.String("ip-1-2-3-4.compute.amazon.com") } - if *filter.Values[0] == "return.private.dns.name.ipv6" { - networkInterface[0].Ipv6Addresses = []*ec2.NetworkInterfaceIpv6Address{ + if filter.Values[0] == "return.private.dns.name.ipv6" { + networkInterface[0].Ipv6Addresses = []ec2types.NetworkInterfaceIpv6Address{ { Ipv6Address: aws.String("2001:db8:3333:4444:5555:6666:7777:8888"), }, diff --git a/pkg/providers/v1/aws_instancegroups.go b/pkg/providers/v1/aws_instancegroups.go index 84b6c7b988..d4234d6a88 100644 --- a/pkg/providers/v1/aws_instancegroups.go +++ b/pkg/providers/v1/aws_instancegroups.go @@ -17,10 +17,12 @@ limitations under the License. package aws import ( + "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + autoscalingtypes "github.com/aws/aws-sdk-go-v2/service/autoscaling/types" "k8s.io/klog/v2" ) @@ -30,12 +32,12 @@ var _ InstanceGroups = &Cloud{} // ResizeInstanceGroup sets the size of the specificed instancegroup Exported // so it can be used by the e2e tests, which don't want to instantiate a full // cloudprovider. -func ResizeInstanceGroup(asg ASG, instanceGroupName string, size int) error { +func ResizeInstanceGroup(ctx context.Context, asg ASG, instanceGroupName string, size int) error { request := &autoscaling.UpdateAutoScalingGroupInput{ AutoScalingGroupName: aws.String(instanceGroupName), - DesiredCapacity: aws.Int64(int64(size)), + DesiredCapacity: aws.Int32(int32(size)), } - if _, err := asg.UpdateAutoScalingGroup(request); err != nil { + if _, err := asg.UpdateAutoScalingGroup(ctx, request); err != nil { return fmt.Errorf("error resizing AWS autoscaling group: %q", err) } return nil @@ -43,18 +45,18 @@ func ResizeInstanceGroup(asg ASG, instanceGroupName string, size int) error { // ResizeInstanceGroup implements InstanceGroups.ResizeInstanceGroup // Set the size to the fixed size -func (c *Cloud) ResizeInstanceGroup(instanceGroupName string, size int) error { - return ResizeInstanceGroup(c.asg, instanceGroupName, size) +func (c *Cloud) ResizeInstanceGroup(ctx context.Context, instanceGroupName string, size int) error { + return ResizeInstanceGroup(ctx, c.asg, instanceGroupName, size) } // DescribeInstanceGroup gets info about the specified instancegroup // Exported so it can be used by the e2e tests, // which don't want to instantiate a full cloudprovider. -func DescribeInstanceGroup(asg ASG, instanceGroupName string) (InstanceGroupInfo, error) { +func DescribeInstanceGroup(ctx context.Context, asg ASG, instanceGroupName string) (InstanceGroupInfo, error) { request := &autoscaling.DescribeAutoScalingGroupsInput{ - AutoScalingGroupNames: []*string{aws.String(instanceGroupName)}, + AutoScalingGroupNames: []string{instanceGroupName}, } - response, err := asg.DescribeAutoScalingGroups(request) + response, err := asg.DescribeAutoScalingGroups(ctx, request) if err != nil { return nil, fmt.Errorf("error listing AWS autoscaling group (%s): %q", instanceGroupName, err) } @@ -66,20 +68,20 @@ func DescribeInstanceGroup(asg ASG, instanceGroupName string) (InstanceGroupInfo klog.Warning("AWS returned multiple autoscaling groups with name ", instanceGroupName) } group := response.AutoScalingGroups[0] - return &awsInstanceGroup{group: group}, nil + return &awsInstanceGroup{group: &group}, nil } // DescribeInstanceGroup implements InstanceGroups.DescribeInstanceGroup // Queries the cloud provider for information about the specified instance group -func (c *Cloud) DescribeInstanceGroup(instanceGroupName string) (InstanceGroupInfo, error) { - return DescribeInstanceGroup(c.asg, instanceGroupName) +func (c *Cloud) DescribeInstanceGroup(ctx context.Context, instanceGroupName string) (InstanceGroupInfo, error) { + return DescribeInstanceGroup(ctx, c.asg, instanceGroupName) } // awsInstanceGroup implements InstanceGroupInfo var _ InstanceGroupInfo = &awsInstanceGroup{} type awsInstanceGroup struct { - group *autoscaling.Group + group *autoscalingtypes.AutoScalingGroup } // Implement InstanceGroupInfo.CurrentSize diff --git a/pkg/providers/v1/aws_loadbalancer.go b/pkg/providers/v1/aws_loadbalancer.go index c39ea3de37..b82d378c2e 100644 --- a/pkg/providers/v1/aws_loadbalancer.go +++ b/pkg/providers/v1/aws_loadbalancer.go @@ -17,8 +17,10 @@ limitations under the License. package aws import ( + "context" "crypto/sha1" "encoding/hex" + "errors" "fmt" "reflect" "regexp" @@ -26,11 +28,13 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" @@ -58,13 +62,13 @@ const ( var ( // Defaults for ELB Healthcheck - defaultElbHCHealthyThreshold = int64(2) - defaultElbHCUnhealthyThreshold = int64(6) - defaultElbHCTimeout = int64(5) - defaultElbHCInterval = int64(10) - defaultNlbHealthCheckInterval = int64(30) - defaultNlbHealthCheckTimeout = int64(10) - defaultNlbHealthCheckThreshold = int64(3) + defaultElbHCHealthyThreshold = int32(2) + defaultElbHCUnhealthyThreshold = int32(6) + defaultElbHCTimeout = int32(5) + defaultElbHCInterval = int32(10) + defaultNlbHealthCheckInterval = int32(30) + defaultNlbHealthCheckTimeout = int32(10) + defaultNlbHealthCheckThreshold = int32(3) defaultHealthCheckPort = "traffic-port" defaultHealthCheckPath = "/" @@ -90,19 +94,19 @@ func isLBExternal(annotations map[string]string) bool { type healthCheckConfig struct { Port string Path string - Protocol string - Interval int64 - Timeout int64 - HealthyThreshold int64 - UnhealthyThreshold int64 + Protocol elbv2types.ProtocolEnum + Interval int32 + Timeout int32 + HealthyThreshold int32 + UnhealthyThreshold int32 } type nlbPortMapping struct { - FrontendPort int64 - FrontendProtocol string + FrontendPort int32 + FrontendProtocol elbv2types.ProtocolEnum - TrafficPort int64 - TrafficProtocol string + TrafficPort int32 + TrafficProtocol elbv2types.ProtocolEnum SSLCertificateARN string SSLPolicy string @@ -138,8 +142,8 @@ func getKeyValuePropertiesFromAnnotation(annotations map[string]string, annotati } // ensureLoadBalancerv2 ensures a v2 load balancer is created -func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBalancerName string, mappings []nlbPortMapping, instanceIDs, discoveredSubnetIDs []string, internalELB bool, annotations map[string]string) (*elbv2.LoadBalancer, error) { - loadBalancer, err := c.describeLoadBalancerv2(loadBalancerName) +func (c *Cloud) ensureLoadBalancerv2(ctx context.Context, namespacedName types.NamespacedName, loadBalancerName string, mappings []nlbPortMapping, instanceIDs, discoveredSubnetIDs []string, internalELB bool, annotations map[string]string) (*elbv2types.LoadBalancer, error) { + loadBalancer, err := c.describeLoadBalancerv2(ctx, loadBalancerName) if err != nil { return nil, err } @@ -155,11 +159,11 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa if loadBalancer == nil { // Create the LB createRequest := &elbv2.CreateLoadBalancerInput{ - Type: aws.String(elbv2.LoadBalancerTypeEnumNetwork), + Type: elbv2types.LoadBalancerTypeEnumNetwork, Name: aws.String(loadBalancerName), } if internalELB { - createRequest.Scheme = aws.String("internal") + createRequest.Scheme = elbv2types.LoadBalancerSchemeEnumInternal } var allocationIDs []string @@ -175,27 +179,27 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa createRequest.SubnetMappings = createSubnetMappings(discoveredSubnetIDs, allocationIDs) for k, v := range tags { - createRequest.Tags = append(createRequest.Tags, &elbv2.Tag{ + createRequest.Tags = append(createRequest.Tags, elbv2types.Tag{ Key: aws.String(k), Value: aws.String(v), }) } klog.Infof("Creating load balancer for %v with name: %s", namespacedName, loadBalancerName) - createResponse, err := c.elbv2.CreateLoadBalancer(createRequest) + createResponse, err := c.elbv2.CreateLoadBalancer(ctx, createRequest) if err != nil { return nil, fmt.Errorf("error creating load balancer: %q", err) } - loadBalancer = createResponse.LoadBalancers[0] + loadBalancer = &createResponse.LoadBalancers[0] for i := range mappings { // It is easier to keep track of updates by having possibly // duplicate target groups where the backend port is the same - _, err := c.createListenerV2(createResponse.LoadBalancers[0].LoadBalancerArn, mappings[i], namespacedName, instanceIDs, *createResponse.LoadBalancers[0].VpcId, tags) + _, err := c.createListenerV2(ctx, createResponse.LoadBalancers[0].LoadBalancerArn, mappings[i], namespacedName, instanceIDs, *createResponse.LoadBalancers[0].VpcId, tags) if err != nil { return nil, fmt.Errorf("error creating listener: %q", err) } } - if err := c.reconcileLBAttributes(aws.StringValue(loadBalancer.LoadBalancerArn), annotations); err != nil { + if err := c.reconcileLBAttributes(ctx, aws.ToString(loadBalancer.LoadBalancerArn), annotations); err != nil { return nil, err } } else { @@ -203,7 +207,7 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa // sync mappings { - listenerDescriptions, err := c.elbv2.DescribeListeners( + listenerDescriptions, err := c.elbv2.DescribeListeners(ctx, &elbv2.DescribeListenersInput{ LoadBalancerArn: loadBalancer.LoadBalancerArn, }, @@ -213,15 +217,15 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa } // actual maps FrontendPort to an elbv2.Listener - actual := map[int64]map[string]*elbv2.Listener{} + actual := map[int32]map[elbv2types.ProtocolEnum]*elbv2types.Listener{} for _, listener := range listenerDescriptions.Listeners { if actual[*listener.Port] == nil { - actual[*listener.Port] = map[string]*elbv2.Listener{} + actual[*listener.Port] = map[elbv2types.ProtocolEnum]*elbv2types.Listener{} } - actual[*listener.Port][*listener.Protocol] = listener + actual[*listener.Port][listener.Protocol] = &listener } - actualTargetGroups, err := c.elbv2.DescribeTargetGroups( + actualTargetGroups, err := c.elbv2.DescribeTargetGroups(ctx, &elbv2.DescribeTargetGroupsInput{ LoadBalancerArn: loadBalancer.LoadBalancerArn, }, @@ -230,9 +234,9 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa return nil, fmt.Errorf("error listing target groups: %q", err) } - nodePortTargetGroup := map[int64]*elbv2.TargetGroup{} + nodePortTargetGroup := map[int32]*elbv2types.TargetGroup{} for _, targetGroup := range actualTargetGroups.TargetGroups { - nodePortTargetGroup[*targetGroup.Port] = targetGroup + nodePortTargetGroup[*targetGroup.Port] = &targetGroup } // Handle additions/modifications @@ -244,22 +248,22 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa if listener, ok := actual[frontendPort][frontendProtocol]; ok { listenerNeedsModification := false - if aws.StringValue(listener.Protocol) != mapping.FrontendProtocol { + if listener.Protocol != mapping.FrontendProtocol { listenerNeedsModification = true } switch mapping.FrontendProtocol { - case elbv2.ProtocolEnumTls: + case elbv2types.ProtocolEnumTls: { - if aws.StringValue(listener.SslPolicy) != mapping.SSLPolicy { + if aws.ToString(listener.SslPolicy) != mapping.SSLPolicy { listenerNeedsModification = true } - if len(listener.Certificates) == 0 || aws.StringValue(listener.Certificates[0].CertificateArn) != mapping.SSLCertificateARN { + if len(listener.Certificates) == 0 || aws.ToString(listener.Certificates[0].CertificateArn) != mapping.SSLCertificateARN { listenerNeedsModification = true } } - case elbv2.ProtocolEnumTcp: + case elbv2types.ProtocolEnumTcp: { - if aws.StringValue(listener.SslPolicy) != "" { + if aws.ToString(listener.SslPolicy) != "" { listenerNeedsModification = true } if len(listener.Certificates) != 0 { @@ -273,14 +277,14 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa targetGroupRecreated := false targetGroup, ok := nodePortTargetGroup[nodePort] - if targetGroup != nil && (!strings.EqualFold(mapping.HealthCheckConfig.Protocol, aws.StringValue(targetGroup.HealthCheckProtocol)) || - mapping.HealthCheckConfig.Interval != aws.Int64Value(targetGroup.HealthCheckIntervalSeconds)) { + if targetGroup != nil && (!strings.EqualFold(string(mapping.HealthCheckConfig.Protocol), string(targetGroup.HealthCheckProtocol)) || + mapping.HealthCheckConfig.Interval != aws.ToInt32(targetGroup.HealthCheckIntervalSeconds)) { healthCheckModified = true } - if !ok || aws.StringValue(targetGroup.Protocol) != mapping.TrafficProtocol || healthCheckModified { + if !ok || targetGroup.Protocol != mapping.TrafficProtocol || healthCheckModified { // create new target group - targetGroup, err = c.ensureTargetGroup( + targetGroup, err = c.ensureTargetGroup(ctx, nil, namespacedName, mapping, @@ -298,38 +302,38 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa if listenerNeedsModification { modifyListenerInput := &elbv2.ModifyListenerInput{ ListenerArn: listener.ListenerArn, - Port: aws.Int64(frontendPort), - Protocol: aws.String(mapping.FrontendProtocol), - DefaultActions: []*elbv2.Action{{ + Port: aws.Int32(frontendPort), + Protocol: mapping.FrontendProtocol, + DefaultActions: []elbv2types.Action{{ TargetGroupArn: targetGroup.TargetGroupArn, - Type: aws.String("forward"), + Type: elbv2types.ActionTypeEnumForward, }}, } - if mapping.FrontendProtocol == elbv2.ProtocolEnumTls { + if mapping.FrontendProtocol == elbv2types.ProtocolEnumTls { if mapping.SSLPolicy != "" { modifyListenerInput.SslPolicy = aws.String(mapping.SSLPolicy) } - modifyListenerInput.Certificates = []*elbv2.Certificate{ + modifyListenerInput.Certificates = []elbv2types.Certificate{ { CertificateArn: aws.String(mapping.SSLCertificateARN), }, } } - if _, err := c.elbv2.ModifyListener(modifyListenerInput); err != nil { + if _, err := c.elbv2.ModifyListener(ctx, modifyListenerInput); err != nil { return nil, fmt.Errorf("error updating load balancer listener: %q", err) } } // Delete old targetGroup if needed if targetGroupRecreated { - if _, err := c.elbv2.DeleteTargetGroup(&elbv2.DeleteTargetGroupInput{ + if _, err := c.elbv2.DeleteTargetGroup(ctx, &elbv2.DeleteTargetGroupInput{ TargetGroupArn: listener.DefaultActions[0].TargetGroupArn, }); err != nil { return nil, fmt.Errorf("error deleting old target group: %q", err) } } else { // Run ensureTargetGroup to make sure instances in service are up-to-date - _, err = c.ensureTargetGroup( + _, err = c.ensureTargetGroup(ctx, targetGroup, namespacedName, mapping, @@ -346,17 +350,17 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa } // Additions - _, err := c.createListenerV2(loadBalancer.LoadBalancerArn, mapping, namespacedName, instanceIDs, *loadBalancer.VpcId, tags) + _, err := c.createListenerV2(ctx, loadBalancer.LoadBalancerArn, mapping, namespacedName, instanceIDs, *loadBalancer.VpcId, tags) if err != nil { return nil, err } dirty = true } - frontEndPorts := map[int64]map[string]bool{} + frontEndPorts := map[int32]map[elbv2types.ProtocolEnum]bool{} for i := range mappings { if frontEndPorts[mappings[i].FrontendPort] == nil { - frontEndPorts[mappings[i].FrontendPort] = map[string]bool{} + frontEndPorts[mappings[i].FrontendPort] = map[elbv2types.ProtocolEnum]bool{} } frontEndPorts[mappings[i].FrontendPort][mappings[i].FrontendProtocol] = true } @@ -365,7 +369,7 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa for port := range actual { for protocol := range actual[port] { if _, ok := frontEndPorts[port][protocol]; !ok { - err := c.deleteListenerV2(actual[port][protocol]) + err := c.deleteListenerV2(ctx, actual[port][protocol]) if err != nil { return nil, err } @@ -374,29 +378,29 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa } } } - if err := c.reconcileLBAttributes(aws.StringValue(loadBalancer.LoadBalancerArn), annotations); err != nil { + if err := c.reconcileLBAttributes(ctx, aws.ToString(loadBalancer.LoadBalancerArn), annotations); err != nil { return nil, err } // Subnets cannot be modified on NLBs if dirty { - loadBalancers, err := c.elbv2.DescribeLoadBalancers( + loadBalancers, err := c.elbv2.DescribeLoadBalancers(ctx, &elbv2.DescribeLoadBalancersInput{ - LoadBalancerArns: []*string{ - loadBalancer.LoadBalancerArn, + LoadBalancerArns: []string{ + aws.ToString(loadBalancer.LoadBalancerArn), }, }, ) if err != nil { return nil, fmt.Errorf("error retrieving load balancer after update: %q", err) } - loadBalancer = loadBalancers.LoadBalancers[0] + loadBalancer = &loadBalancers.LoadBalancers[0] } } return loadBalancer, nil } -func (c *Cloud) reconcileLBAttributes(loadBalancerArn string, annotations map[string]string) error { +func (c *Cloud) reconcileLBAttributes(ctx context.Context, loadBalancerArn string, annotations map[string]string) error { desiredLoadBalancerAttributes := map[string]string{} desiredLoadBalancerAttributes[lbAttrLoadBalancingCrossZoneEnabled] = "false" @@ -435,25 +439,25 @@ func (c *Cloud) reconcileLBAttributes(loadBalancerArn string, annotations map[st desiredLoadBalancerAttributes[lbAttrAccessLogsS3Prefix] = annotations[ServiceAnnotationLoadBalancerAccessLogS3BucketPrefix] currentLoadBalancerAttributes := map[string]string{} - describeAttributesOutput, err := c.elbv2.DescribeLoadBalancerAttributes(&elbv2.DescribeLoadBalancerAttributesInput{ + describeAttributesOutput, err := c.elbv2.DescribeLoadBalancerAttributes(ctx, &elbv2.DescribeLoadBalancerAttributesInput{ LoadBalancerArn: aws.String(loadBalancerArn), }) if err != nil { return fmt.Errorf("unable to retrieve load balancer attributes during attribute sync: %q", err) } for _, attr := range describeAttributesOutput.Attributes { - currentLoadBalancerAttributes[aws.StringValue(attr.Key)] = aws.StringValue(attr.Value) + currentLoadBalancerAttributes[aws.ToString(attr.Key)] = aws.ToString(attr.Value) } - var changedAttributes []*elbv2.LoadBalancerAttribute + var changedAttributes []elbv2types.LoadBalancerAttribute if desiredLoadBalancerAttributes[lbAttrLoadBalancingCrossZoneEnabled] != currentLoadBalancerAttributes[lbAttrLoadBalancingCrossZoneEnabled] { - changedAttributes = append(changedAttributes, &elbv2.LoadBalancerAttribute{ + changedAttributes = append(changedAttributes, elbv2types.LoadBalancerAttribute{ Key: aws.String(lbAttrLoadBalancingCrossZoneEnabled), Value: aws.String(desiredLoadBalancerAttributes[lbAttrLoadBalancingCrossZoneEnabled]), }) } if desiredLoadBalancerAttributes[lbAttrAccessLogsS3Enabled] != currentLoadBalancerAttributes[lbAttrAccessLogsS3Enabled] { - changedAttributes = append(changedAttributes, &elbv2.LoadBalancerAttribute{ + changedAttributes = append(changedAttributes, elbv2types.LoadBalancerAttribute{ Key: aws.String(lbAttrAccessLogsS3Enabled), Value: aws.String(desiredLoadBalancerAttributes[lbAttrAccessLogsS3Enabled]), }) @@ -462,13 +466,13 @@ func (c *Cloud) reconcileLBAttributes(loadBalancerArn string, annotations map[st // ELBV2 API forbids us to set bucket to an empty bucket, so we keep it unchanged if AccessLogsS3Enabled==false. if desiredLoadBalancerAttributes[lbAttrAccessLogsS3Enabled] == "true" { if desiredLoadBalancerAttributes[lbAttrAccessLogsS3Bucket] != currentLoadBalancerAttributes[lbAttrAccessLogsS3Bucket] { - changedAttributes = append(changedAttributes, &elbv2.LoadBalancerAttribute{ + changedAttributes = append(changedAttributes, elbv2types.LoadBalancerAttribute{ Key: aws.String(lbAttrAccessLogsS3Bucket), Value: aws.String(desiredLoadBalancerAttributes[lbAttrAccessLogsS3Bucket]), }) } if desiredLoadBalancerAttributes[lbAttrAccessLogsS3Prefix] != currentLoadBalancerAttributes[lbAttrAccessLogsS3Prefix] { - changedAttributes = append(changedAttributes, &elbv2.LoadBalancerAttribute{ + changedAttributes = append(changedAttributes, elbv2types.LoadBalancerAttribute{ Key: aws.String(lbAttrAccessLogsS3Prefix), Value: aws.String(desiredLoadBalancerAttributes[lbAttrAccessLogsS3Prefix]), }) @@ -478,7 +482,7 @@ func (c *Cloud) reconcileLBAttributes(loadBalancerArn string, annotations map[st if len(changedAttributes) > 0 { klog.V(2).Infof("updating load-balancer attributes for %q", loadBalancerArn) - _, err = c.elbv2.ModifyLoadBalancerAttributes(&elbv2.ModifyLoadBalancerAttributesInput{ + _, err = c.elbv2.ModifyLoadBalancerAttributes(ctx, &elbv2.ModifyLoadBalancerAttributesInput{ LoadBalancerArn: aws.String(loadBalancerArn), Attributes: changedAttributes, }) @@ -494,17 +498,17 @@ var invalidELBV2NameRegex = regexp.MustCompile("[^[:alnum:]]") // buildTargetGroupName will build unique name for targetGroup of service & port. // the name is in format k8s-{namespace:8}-{name:8}-{uuid:10} (chosen to benefit most common use cases). // Note: nodePort & targetProtocol & targetType are included since they cannot be modified on existing targetGroup. -func (c *Cloud) buildTargetGroupName(serviceName types.NamespacedName, servicePort int64, nodePort int64, targetProtocol string, targetType string, mapping nlbPortMapping) string { +func (c *Cloud) buildTargetGroupName(serviceName types.NamespacedName, servicePort int32, nodePort int32, targetProtocol elbv2types.ProtocolEnum, targetType elbv2types.TargetTypeEnum, mapping nlbPortMapping) string { hasher := sha1.New() _, _ = hasher.Write([]byte(c.tagging.clusterID())) _, _ = hasher.Write([]byte(serviceName.Namespace)) _, _ = hasher.Write([]byte(serviceName.Name)) - _, _ = hasher.Write([]byte(strconv.FormatInt(servicePort, 10))) - _, _ = hasher.Write([]byte(strconv.FormatInt(nodePort, 10))) + _, _ = hasher.Write([]byte(strconv.FormatInt(int64(servicePort), 10))) + _, _ = hasher.Write([]byte(strconv.FormatInt(int64(nodePort), 10))) _, _ = hasher.Write([]byte(targetProtocol)) _, _ = hasher.Write([]byte(targetType)) _, _ = hasher.Write([]byte(mapping.HealthCheckConfig.Protocol)) - _, _ = hasher.Write([]byte(strconv.FormatInt(mapping.HealthCheckConfig.Interval, 10))) + _, _ = hasher.Write([]byte(strconv.FormatInt(int64(mapping.HealthCheckConfig.Interval), 10))) tgUUID := hex.EncodeToString(hasher.Sum(nil)) sanitizedNamespace := invalidELBV2NameRegex.ReplaceAllString(serviceName.Namespace, "") @@ -512,8 +516,8 @@ func (c *Cloud) buildTargetGroupName(serviceName types.NamespacedName, servicePo return fmt.Sprintf("k8s-%.8s-%.8s-%.10s", sanitizedNamespace, sanitizedServiceName, tgUUID) } -func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping, namespacedName types.NamespacedName, instanceIDs []string, vpcID string, tags map[string]string) (listener *elbv2.Listener, err error) { - target, err := c.ensureTargetGroup( +func (c *Cloud) createListenerV2(ctx context.Context, loadBalancerArn *string, mapping nlbPortMapping, namespacedName types.NamespacedName, instanceIDs []string, vpcID string, tags map[string]string) (listener *elbv2types.Listener, err error) { + target, err := c.ensureTargetGroup(ctx, nil, namespacedName, mapping, @@ -525,9 +529,9 @@ func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping return nil, err } - elbTags := []*elbv2.Tag{} + elbTags := []elbv2types.Tag{} for k, v := range tags { - elbTags = append(elbTags, &elbv2.Tag{ + elbTags = append(elbTags, elbv2types.Tag{ Key: aws.String(k), Value: aws.String(v), }) @@ -535,11 +539,11 @@ func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping createListernerInput := &elbv2.CreateListenerInput{ LoadBalancerArn: loadBalancerArn, - Port: aws.Int64(mapping.FrontendPort), - Protocol: aws.String(mapping.FrontendProtocol), - DefaultActions: []*elbv2.Action{{ + Port: aws.Int32(mapping.FrontendPort), + Protocol: mapping.FrontendProtocol, + DefaultActions: []elbv2types.Action{{ TargetGroupArn: target.TargetGroupArn, - Type: aws.String(elbv2.ActionTypeEnumForward), + Type: elbv2types.ActionTypeEnumForward, }}, Tags: elbTags, } @@ -547,7 +551,7 @@ func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping if mapping.SSLPolicy != "" { createListernerInput.SslPolicy = aws.String(mapping.SSLPolicy) } - createListernerInput.Certificates = []*elbv2.Certificate{ + createListernerInput.Certificates = []elbv2types.Certificate{ { CertificateArn: aws.String(mapping.SSLCertificateARN), }, @@ -555,20 +559,20 @@ func (c *Cloud) createListenerV2(loadBalancerArn *string, mapping nlbPortMapping } klog.Infof("Creating load balancer listener for %v", namespacedName) - createListenerOutput, err := c.elbv2.CreateListener(createListernerInput) + createListenerOutput, err := c.elbv2.CreateListener(ctx, createListernerInput) if err != nil { return nil, fmt.Errorf("error creating load balancer listener: %q", err) } - return createListenerOutput.Listeners[0], nil + return &createListenerOutput.Listeners[0], nil } // cleans up listener and corresponding target group -func (c *Cloud) deleteListenerV2(listener *elbv2.Listener) error { - _, err := c.elbv2.DeleteListener(&elbv2.DeleteListenerInput{ListenerArn: listener.ListenerArn}) +func (c *Cloud) deleteListenerV2(ctx context.Context, listener *elbv2types.Listener) error { + _, err := c.elbv2.DeleteListener(ctx, &elbv2.DeleteListenerInput{ListenerArn: listener.ListenerArn}) if err != nil { return fmt.Errorf("error deleting load balancer listener: %q", err) } - _, err = c.elbv2.DeleteTargetGroup(&elbv2.DeleteTargetGroupInput{TargetGroupArn: listener.DefaultActions[0].TargetGroupArn}) + _, err = c.elbv2.DeleteTargetGroup(ctx, &elbv2.DeleteTargetGroupInput{TargetGroupArn: listener.DefaultActions[0].TargetGroupArn}) if err != nil { return fmt.Errorf("error deleting load balancer target group: %q", err) } @@ -576,41 +580,41 @@ func (c *Cloud) deleteListenerV2(listener *elbv2.Listener) error { } // ensureTargetGroup creates a target group with a set of instances. -func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName types.NamespacedName, mapping nlbPortMapping, instances []string, vpcID string, tags map[string]string) (*elbv2.TargetGroup, error) { +func (c *Cloud) ensureTargetGroup(ctx context.Context, targetGroup *elbv2types.TargetGroup, serviceName types.NamespacedName, mapping nlbPortMapping, instances []string, vpcID string, tags map[string]string) (*elbv2types.TargetGroup, error) { dirty := false expectedTargets := c.computeTargetGroupExpectedTargets(instances, mapping.TrafficPort) if targetGroup == nil { - targetType := "instance" + targetType := elbv2types.TargetTypeEnumInstance name := c.buildTargetGroupName(serviceName, mapping.FrontendPort, mapping.TrafficPort, mapping.TrafficProtocol, targetType, mapping) klog.Infof("Creating load balancer target group for %v with name: %s", serviceName, name) input := &elbv2.CreateTargetGroupInput{ VpcId: aws.String(vpcID), Name: aws.String(name), - Port: aws.Int64(mapping.TrafficPort), - Protocol: aws.String(mapping.TrafficProtocol), - TargetType: aws.String(targetType), - HealthCheckIntervalSeconds: aws.Int64(mapping.HealthCheckConfig.Interval), + Port: aws.Int32(mapping.TrafficPort), + Protocol: mapping.TrafficProtocol, + TargetType: targetType, + HealthCheckIntervalSeconds: aws.Int32(mapping.HealthCheckConfig.Interval), HealthCheckPort: aws.String(mapping.HealthCheckConfig.Port), - HealthCheckProtocol: aws.String(mapping.HealthCheckConfig.Protocol), - HealthyThresholdCount: aws.Int64(mapping.HealthCheckConfig.HealthyThreshold), - UnhealthyThresholdCount: aws.Int64(mapping.HealthCheckConfig.UnhealthyThreshold), + HealthCheckProtocol: mapping.HealthCheckConfig.Protocol, + HealthyThresholdCount: aws.Int32(mapping.HealthCheckConfig.HealthyThreshold), + UnhealthyThresholdCount: aws.Int32(mapping.HealthCheckConfig.UnhealthyThreshold), // HealthCheckTimeoutSeconds: Currently not configurable, 6 seconds for HTTP, 10 for TCP/HTTPS } - if mapping.HealthCheckConfig.Protocol != elbv2.ProtocolEnumTcp { + if mapping.HealthCheckConfig.Protocol != elbv2types.ProtocolEnumTcp { input.HealthCheckPath = aws.String(mapping.HealthCheckConfig.Path) } if len(tags) != 0 { - targetGroupTags := make([]*elbv2.Tag, 0, len(tags)) + targetGroupTags := make([]elbv2types.Tag, 0, len(tags)) for k, v := range tags { - targetGroupTags = append(targetGroupTags, &elbv2.Tag{ + targetGroupTags = append(targetGroupTags, elbv2types.Tag{ Key: aws.String(k), Value: aws.String(v), }) } input.Tags = targetGroupTags } - result, err := c.elbv2.CreateTargetGroup(input) + result, err := c.elbv2.CreateTargetGroup(ctx, input) if err != nil { return nil, fmt.Errorf("error creating load balancer target group: %q", err) } @@ -619,21 +623,21 @@ func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName ty } tg := result.TargetGroups[0] - tgARN := aws.StringValue(tg.TargetGroupArn) - if err := c.ensureTargetGroupTargets(tgARN, expectedTargets, nil); err != nil { + tgARN := aws.ToString(tg.TargetGroupArn) + if err := c.ensureTargetGroupTargets(ctx, tgARN, expectedTargets, nil); err != nil { return nil, err } - return tg, nil + return &tg, nil } // handle instances in service { - tgARN := aws.StringValue(targetGroup.TargetGroupArn) - actualTargets, err := c.obtainTargetGroupActualTargets(tgARN) + tgARN := aws.ToString(targetGroup.TargetGroupArn) + actualTargets, err := c.obtainTargetGroupActualTargets(ctx, tgARN) if err != nil { return nil, err } - if err := c.ensureTargetGroupTargets(tgARN, expectedTargets, actualTargets); err != nil { + if err := c.ensureTargetGroupTargets(ctx, tgARN, expectedTargets, actualTargets); err != nil { return nil, err } } @@ -645,24 +649,24 @@ func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName ty input := &elbv2.ModifyTargetGroupInput{ TargetGroupArn: targetGroup.TargetGroupArn, } - if mapping.HealthCheckConfig.Port != aws.StringValue(targetGroup.HealthCheckPort) { + if mapping.HealthCheckConfig.Port != aws.ToString(targetGroup.HealthCheckPort) { input.HealthCheckPort = aws.String(mapping.HealthCheckConfig.Port) dirtyHealthCheck = true } - if mapping.HealthCheckConfig.HealthyThreshold != aws.Int64Value(targetGroup.HealthyThresholdCount) { + if mapping.HealthCheckConfig.HealthyThreshold != aws.ToInt32(targetGroup.HealthyThresholdCount) { dirtyHealthCheck = true - input.HealthyThresholdCount = aws.Int64(mapping.HealthCheckConfig.HealthyThreshold) - input.UnhealthyThresholdCount = aws.Int64(mapping.HealthCheckConfig.UnhealthyThreshold) + input.HealthyThresholdCount = aws.Int32(mapping.HealthCheckConfig.HealthyThreshold) + input.UnhealthyThresholdCount = aws.Int32(mapping.HealthCheckConfig.UnhealthyThreshold) } - if !strings.EqualFold(mapping.HealthCheckConfig.Protocol, elbv2.ProtocolEnumTcp) { - if mapping.HealthCheckConfig.Path != aws.StringValue(input.HealthCheckPath) { + if !strings.EqualFold(string(mapping.HealthCheckConfig.Protocol), string(elbv2types.ProtocolEnumTcp)) { + if mapping.HealthCheckConfig.Path != aws.ToString(input.HealthCheckPath) { input.HealthCheckPath = aws.String(mapping.HealthCheckConfig.Path) dirtyHealthCheck = true } } if dirtyHealthCheck { - _, err := c.elbv2.ModifyTargetGroup(input) + _, err := c.elbv2.ModifyTargetGroup(ctx, input) if err != nil { return nil, fmt.Errorf("error modifying target group health check: %q", err) } @@ -672,19 +676,19 @@ func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName ty } if dirty { - result, err := c.elbv2.DescribeTargetGroups(&elbv2.DescribeTargetGroupsInput{ - TargetGroupArns: []*string{targetGroup.TargetGroupArn}, + result, err := c.elbv2.DescribeTargetGroups(ctx, &elbv2.DescribeTargetGroupsInput{ + TargetGroupArns: []string{aws.ToString(targetGroup.TargetGroupArn)}, }) if err != nil { return nil, fmt.Errorf("error retrieving target group after creation/update: %q", err) } - targetGroup = result.TargetGroups[0] + targetGroup = &result.TargetGroups[0] } return targetGroup, nil } -func (c *Cloud) ensureTargetGroupTargets(tgARN string, expectedTargets []*elbv2.TargetDescription, actualTargets []*elbv2.TargetDescription) error { +func (c *Cloud) ensureTargetGroupTargets(ctx context.Context, tgARN string, expectedTargets []*elbv2types.TargetDescription, actualTargets []*elbv2types.TargetDescription) error { targetsToRegister, targetsToDeregister := c.diffTargetGroupTargets(expectedTargets, actualTargets) if len(targetsToRegister) > 0 { targetsToRegisterChunks := c.chunkTargetDescriptions(targetsToRegister, defaultRegisterTargetsChunkSize) @@ -693,7 +697,7 @@ func (c *Cloud) ensureTargetGroupTargets(tgARN string, expectedTargets []*elbv2. TargetGroupArn: aws.String(tgARN), Targets: targetsChunk, } - if _, err := c.elbv2.RegisterTargets(req); err != nil { + if _, err := c.elbv2.RegisterTargets(ctx, req); err != nil { return fmt.Errorf("error trying to register targets in target group: %q", err) } } @@ -705,7 +709,7 @@ func (c *Cloud) ensureTargetGroupTargets(tgARN string, expectedTargets []*elbv2. TargetGroupArn: aws.String(tgARN), Targets: targetsChunk, } - if _, err := c.elbv2.DeregisterTargets(req); err != nil { + if _, err := c.elbv2.DeregisterTargets(ctx, req); err != nil { return fmt.Errorf("error trying to deregister targets in target group: %q", err) } } @@ -713,28 +717,28 @@ func (c *Cloud) ensureTargetGroupTargets(tgARN string, expectedTargets []*elbv2. return nil } -func (c *Cloud) computeTargetGroupExpectedTargets(instanceIDs []string, port int64) []*elbv2.TargetDescription { - expectedTargets := make([]*elbv2.TargetDescription, 0, len(instanceIDs)) +func (c *Cloud) computeTargetGroupExpectedTargets(instanceIDs []string, port int32) []*elbv2types.TargetDescription { + expectedTargets := make([]*elbv2types.TargetDescription, 0, len(instanceIDs)) for _, instanceID := range instanceIDs { - expectedTargets = append(expectedTargets, &elbv2.TargetDescription{ + expectedTargets = append(expectedTargets, &elbv2types.TargetDescription{ Id: aws.String(instanceID), - Port: aws.Int64(port), + Port: aws.Int32(port), }) } return expectedTargets } -func (c *Cloud) obtainTargetGroupActualTargets(tgARN string) ([]*elbv2.TargetDescription, error) { +func (c *Cloud) obtainTargetGroupActualTargets(ctx context.Context, tgARN string) ([]*elbv2types.TargetDescription, error) { req := &elbv2.DescribeTargetHealthInput{ TargetGroupArn: aws.String(tgARN), } - resp, err := c.elbv2.DescribeTargetHealth(req) + resp, err := c.elbv2.DescribeTargetHealth(ctx, req) if err != nil { return nil, fmt.Errorf("error describing target group health: %q", err) } - actualTargets := make([]*elbv2.TargetDescription, 0, len(resp.TargetHealthDescriptions)) + actualTargets := make([]*elbv2types.TargetDescription, 0, len(resp.TargetHealthDescriptions)) for _, targetDesc := range resp.TargetHealthDescriptions { - if targetDesc.TargetHealth.Reason != nil && aws.StringValue(targetDesc.TargetHealth.Reason) == elbv2.TargetHealthReasonEnumTargetDeregistrationInProgress { + if targetDesc.TargetHealth.Reason == elbv2types.TargetHealthReasonEnumDeregistrationInProgress { continue } actualTargets = append(actualTargets, targetDesc.Target) @@ -743,16 +747,16 @@ func (c *Cloud) obtainTargetGroupActualTargets(tgARN string) ([]*elbv2.TargetDes } // diffTargetGroupTargets computes the targets to register and targets to deregister based on existingTargets and desired instances. -func (c *Cloud) diffTargetGroupTargets(expectedTargets []*elbv2.TargetDescription, actualTargets []*elbv2.TargetDescription) (targetsToRegister []*elbv2.TargetDescription, targetsToDeregister []*elbv2.TargetDescription) { - expectedTargetsByUID := make(map[string]*elbv2.TargetDescription, len(expectedTargets)) +func (c *Cloud) diffTargetGroupTargets(expectedTargets []*elbv2types.TargetDescription, actualTargets []*elbv2types.TargetDescription) (targetsToRegister []elbv2types.TargetDescription, targetsToDeregister []elbv2types.TargetDescription) { + expectedTargetsByUID := make(map[string]elbv2types.TargetDescription, len(expectedTargets)) for _, target := range expectedTargets { - targetUID := fmt.Sprintf("%v:%v", aws.StringValue(target.Id), aws.Int64Value(target.Port)) - expectedTargetsByUID[targetUID] = target + targetUID := fmt.Sprintf("%v:%v", aws.ToString(target.Id), aws.ToInt32(target.Port)) + expectedTargetsByUID[targetUID] = *target } - actualTargetsByUID := make(map[string]*elbv2.TargetDescription, len(actualTargets)) + actualTargetsByUID := make(map[string]elbv2types.TargetDescription, len(actualTargets)) for _, target := range actualTargets { - targetUID := fmt.Sprintf("%v:%v", aws.StringValue(target.Id), aws.Int64Value(target.Port)) - actualTargetsByUID[targetUID] = target + targetUID := fmt.Sprintf("%v:%v", aws.ToString(target.Id), aws.ToInt32(target.Port)) + actualTargetsByUID[targetUID] = *target } expectedTargetsUIDs := sets.StringKeySet(expectedTargetsByUID) @@ -767,8 +771,8 @@ func (c *Cloud) diffTargetGroupTargets(expectedTargets []*elbv2.TargetDescriptio } // chunkTargetDescriptions will split slice of TargetDescription into chunks -func (c *Cloud) chunkTargetDescriptions(targets []*elbv2.TargetDescription, chunkSize int) [][]*elbv2.TargetDescription { - var chunks [][]*elbv2.TargetDescription +func (c *Cloud) chunkTargetDescriptions(targets []elbv2types.TargetDescription, chunkSize int) [][]elbv2types.TargetDescription { + var chunks [][]elbv2types.TargetDescription for i := 0; i < len(targets); i += chunkSize { end := i + chunkSize if end > len(targets) { @@ -781,12 +785,12 @@ func (c *Cloud) chunkTargetDescriptions(targets []*elbv2.TargetDescription, chun // updateInstanceSecurityGroupsForNLB will adjust securityGroup's settings to allow inbound traffic into instances from clientCIDRs and portMappings. // TIP: if either instances or clientCIDRs or portMappings are nil, then the securityGroup rules for lbName are cleared. -func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[InstanceID]*ec2.Instance, subnetCIDRs []string, clientCIDRs []string, portMappings []nlbPortMapping) error { +func (c *Cloud) updateInstanceSecurityGroupsForNLB(ctx context.Context, lbName string, instances map[InstanceID]*ec2types.Instance, subnetCIDRs []string, clientCIDRs []string, portMappings []nlbPortMapping) error { if c.cfg.Global.DisableSecurityGroupIngress { return nil } - clusterSGs, err := c.getTaggedSecurityGroups() + clusterSGs, err := c.getTaggedSecurityGroups(ctx) if err != nil { return fmt.Errorf("error querying for tagged security groups: %q", err) } @@ -798,17 +802,17 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ return err } if sg == nil { - klog.Warningf("Ignoring instance without security group: %s", aws.StringValue(instance.InstanceId)) + klog.Warningf("Ignoring instance without security group: %s", aws.ToString(instance.InstanceId)) continue } - desiredSGIDs.Insert(aws.StringValue(sg.GroupId)) + desiredSGIDs.Insert(aws.ToString(sg.GroupId)) } // TODO(@M00nF1sh): do we really needs to support SG without cluster tag at current version? // findSecurityGroupForInstance might return SG that are not tagged. { for sgID := range desiredSGIDs.Difference(sets.StringKeySet(clusterSGs)) { - sg, err := c.findSecurityGroup(sgID) + sg, err := c.findSecurityGroup(ctx, sgID) if err != nil { return fmt.Errorf("error finding instance group: %q", err) } @@ -817,20 +821,21 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ } { - clientPorts := sets.Int64{} + clientPorts := sets.Set[int32]{} clientProtocol := "tcp" - healthCheckPorts := sets.Int64{} + healthCheckPorts := sets.Set[int32]{} for _, port := range portMappings { clientPorts.Insert(port.TrafficPort) hcPort := port.TrafficPort if port.HealthCheckConfig.Port != defaultHealthCheckPort { - var err error - if hcPort, err = strconv.ParseInt(port.HealthCheckConfig.Port, 10, 0); err != nil { + hcPort64, err := strconv.ParseInt(port.HealthCheckConfig.Port, 10, 0) + if err != nil { return fmt.Errorf("Invalid health check port %v", port.HealthCheckConfig.Port) } + hcPort = int32(hcPort64) } healthCheckPorts.Insert(hcPort) - if port.TrafficProtocol == string(v1.ProtocolUDP) { + if port.TrafficProtocol == elbv2types.ProtocolEnumUdp { clientProtocol = "udp" } } @@ -842,23 +847,23 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ // If the client rule is 1) all addresses 2) tcp and 3) has same ports as the healthcheck, // then the health rules are a subset of the client rule and are not needed. if len(clientCIDRs) != 1 || clientCIDRs[0] != "0.0.0.0/0" || clientProtocol != "tcp" || !healthCheckPorts.Equal(clientPorts) { - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, subnetCIDRs); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(ctx, sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, subnetCIDRs); err != nil { return err } } - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, clientProtocol, clientPorts, clientCIDRs); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(ctx, sgID, sgPerms, clientRuleAnnotation, clientProtocol, clientPorts, clientCIDRs); err != nil { return err } } else { - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", nil, nil); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(ctx, sgID, sgPerms, healthRuleAnnotation, "tcp", nil, nil); err != nil { return err } - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, clientProtocol, nil, nil); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(ctx, sgID, sgPerms, clientRuleAnnotation, clientProtocol, nil, nil); err != nil { return err } } if !sgPerms.Equal(NewIPPermissionSet(sg.IpPermissions...).Ungroup()) { - if err := c.updateInstanceSecurityGroupForNLBMTU(sgID, sgPerms); err != nil { + if err := c.updateInstanceSecurityGroupForNLBMTU(ctx, sgID, sgPerms); err != nil { return err } } @@ -869,15 +874,15 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ // updateInstanceSecurityGroupForNLBTraffic will manage permissions set(identified by ruleDesc) on securityGroup to match desired set(allow protocol traffic from ports/cidr). // Note: sgPerms will be updated to reflect the current permission set on SG after update. -func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IPPermissionSet, ruleDesc string, protocol string, ports sets.Int64, cidrs []string) error { +func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(ctx context.Context, sgID string, sgPerms IPPermissionSet, ruleDesc string, protocol string, ports sets.Set[int32], cidrs []string) error { desiredPerms := NewIPPermissionSet() for port := range ports { for _, cidr := range cidrs { - desiredPerms.Insert(&ec2.IpPermission{ + desiredPerms.Insert(ec2types.IpPermission{ IpProtocol: aws.String(protocol), - FromPort: aws.Int64(port), - ToPort: aws.Int64(port), - IpRanges: []*ec2.IpRange{ + FromPort: aws.Int32(int32(port)), + ToPort: aws.Int32(int32(port)), + IpRanges: []ec2types.IpRange{ { CidrIp: aws.String(cidr), Description: aws.String(ruleDesc), @@ -892,7 +897,7 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IP permsToRevoke.DeleteIf(IPPermissionNotMatch{IPPermissionMatchDesc{ruleDesc}}) if len(permsToRevoke) > 0 { permsToRevokeList := permsToRevoke.List() - changed, err := c.removeSecurityGroupIngress(sgID, permsToRevokeList) + changed, err := c.removeSecurityGroupIngress(ctx, sgID, permsToRevokeList) if err != nil { klog.Warningf("Error remove traffic permission from security group: %q", err) return err @@ -904,7 +909,7 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IP } if len(permsToGrant) > 0 { permsToGrantList := permsToGrant.List() - changed, err := c.addSecurityGroupIngress(sgID, permsToGrantList) + changed, err := c.addSecurityGroupIngress(ctx, sgID, permsToGrantList) if err != nil { klog.Warningf("Error add traffic permission to security group: %q", err) return err @@ -918,16 +923,16 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IP } // Note: sgPerms will be updated to reflect the current permission set on SG after update. -func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPermissionSet) error { +func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(ctx context.Context, sgID string, sgPerms IPPermissionSet) error { desiredPerms := NewIPPermissionSet() for _, perm := range sgPerms { for _, ipRange := range perm.IpRanges { - if strings.Contains(aws.StringValue(ipRange.Description), NLBClientRuleDescription) { - desiredPerms.Insert(&ec2.IpPermission{ + if strings.Contains(aws.ToString(ipRange.Description), NLBClientRuleDescription) { + desiredPerms.Insert(ec2types.IpPermission{ IpProtocol: aws.String("icmp"), - FromPort: aws.Int64(3), - ToPort: aws.Int64(4), - IpRanges: []*ec2.IpRange{ + FromPort: aws.Int32(3), + ToPort: aws.Int32(4), + IpRanges: []ec2types.IpRange{ { CidrIp: ipRange.CidrIp, Description: aws.String(NLBMtuDiscoveryRuleDescription), @@ -943,7 +948,7 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPerm permsToRevoke.DeleteIf(IPPermissionNotMatch{IPPermissionMatchDesc{NLBMtuDiscoveryRuleDescription}}) if len(permsToRevoke) > 0 { permsToRevokeList := permsToRevoke.List() - changed, err := c.removeSecurityGroupIngress(sgID, permsToRevokeList) + changed, err := c.removeSecurityGroupIngress(ctx, sgID, permsToRevokeList) if err != nil { klog.Warningf("Error remove MTU permission from security group: %q", err) return err @@ -956,7 +961,7 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPerm } if len(permsToGrant) > 0 { permsToGrantList := permsToGrant.List() - changed, err := c.addSecurityGroupIngress(sgID, permsToGrantList) + changed, err := c.addSecurityGroupIngress(ctx, sgID, permsToGrantList) if err != nil { klog.Warningf("Error add MTU permission to security group: %q", err) return err @@ -969,8 +974,8 @@ func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPerm return nil } -func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBalancerName string, listeners []*elb.Listener, subnetIDs []string, securityGroupIDs []string, internalELB, proxyProtocol bool, loadBalancerAttributes *elb.LoadBalancerAttributes, annotations map[string]string) (*elb.LoadBalancerDescription, error) { - loadBalancer, err := c.describeLoadBalancer(loadBalancerName) +func (c *Cloud) ensureLoadBalancer(ctx context.Context, namespacedName types.NamespacedName, loadBalancerName string, listeners []elbtypes.Listener, subnetIDs []string, securityGroupIDs []string, internalELB, proxyProtocol bool, loadBalancerAttributes *elbtypes.LoadBalancerAttributes, annotations map[string]string) (*elbtypes.LoadBalancerDescription, error) { + loadBalancer, err := c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { return nil, err } @@ -992,13 +997,13 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if subnetIDs == nil { createRequest.Subnets = nil } else { - createRequest.Subnets = aws.StringSlice(subnetIDs) + createRequest.Subnets = subnetIDs } if securityGroupIDs == nil { createRequest.SecurityGroups = nil } else { - createRequest.SecurityGroups = aws.StringSlice(securityGroupIDs) + createRequest.SecurityGroups = securityGroupIDs } // Get additional tags set by the user @@ -1009,26 +1014,26 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala tags = c.tagging.buildTags(ResourceLifecycleOwned, tags) for k, v := range tags { - createRequest.Tags = append(createRequest.Tags, &elb.Tag{ + createRequest.Tags = append(createRequest.Tags, elbtypes.Tag{ Key: aws.String(k), Value: aws.String(v), }) } klog.Infof("Creating load balancer for %v with name: %s", namespacedName, loadBalancerName) - _, err := c.elb.CreateLoadBalancer(createRequest) + _, err := c.elb.CreateLoadBalancer(ctx, createRequest) if err != nil { return nil, err } if proxyProtocol { - err = c.createProxyProtocolPolicy(loadBalancerName) + err = c.createProxyProtocolPolicy(ctx, loadBalancerName) if err != nil { return nil, err } for _, listener := range listeners { klog.V(2).Infof("Adjusting AWS loadbalancer proxy protocol on node port %d. Setting to true", *listener.InstancePort) - err := c.setBackendPolicies(loadBalancerName, *listener.InstancePort, []*string{aws.String(ProxyProtocolPolicyName)}) + err := c.setBackendPolicies(ctx, loadBalancerName, listener.InstancePort, []*string{aws.String(ProxyProtocolPolicyName)}) if err != nil { return nil, err } @@ -1041,8 +1046,8 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala { // Sync subnets - expected := sets.NewString(subnetIDs...) - actual := stringSetFromPointers(loadBalancer.Subnets) + expected := sets.New[string](subnetIDs...) + actual := sets.New[string](loadBalancer.Subnets...) additions := expected.Difference(actual) removals := actual.Difference(expected) @@ -1050,9 +1055,9 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if removals.Len() != 0 { request := &elb.DetachLoadBalancerFromSubnetsInput{} request.LoadBalancerName = aws.String(loadBalancerName) - request.Subnets = stringSetToPointers(removals) + request.Subnets = stringSetToList(removals) klog.V(2).Info("Detaching load balancer from removed subnets") - _, err := c.elb.DetachLoadBalancerFromSubnets(request) + _, err := c.elb.DetachLoadBalancerFromSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error detaching AWS loadbalancer from subnets: %q", err) } @@ -1062,9 +1067,9 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if additions.Len() != 0 { request := &elb.AttachLoadBalancerToSubnetsInput{} request.LoadBalancerName = aws.String(loadBalancerName) - request.Subnets = stringSetToPointers(additions) + request.Subnets = stringSetToList(additions) klog.V(2).Info("Attaching load balancer to added subnets") - _, err := c.elb.AttachLoadBalancerToSubnets(request) + _, err := c.elb.AttachLoadBalancerToSubnets(ctx, request) if err != nil { return nil, fmt.Errorf("error attaching AWS loadbalancer to subnets: %q", err) } @@ -1074,8 +1079,8 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala { // Sync security groups - expected := sets.NewString(securityGroupIDs...) - actual := stringSetFromPointers(loadBalancer.SecurityGroups) + expected := sets.New[string](securityGroupIDs...) + actual := stringSetFromList(loadBalancer.SecurityGroups) if !expected.Equal(actual) { // This call just replaces the security groups, unlike e.g. subnets (!) @@ -1084,10 +1089,10 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if securityGroupIDs == nil { request.SecurityGroups = nil } else { - request.SecurityGroups = aws.StringSlice(securityGroupIDs) + request.SecurityGroups = securityGroupIDs } klog.V(2).Info("Applying updated security groups to load balancer") - _, err := c.elb.ApplySecurityGroupsToLoadBalancer(request) + _, err := c.elb.ApplySecurityGroupsToLoadBalancer(ctx, request) if err != nil { return nil, fmt.Errorf("error applying AWS loadbalancer security groups: %q", err) } @@ -1103,7 +1108,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala request.LoadBalancerName = aws.String(loadBalancerName) request.LoadBalancerPorts = removals klog.V(2).Info("Deleting removed load balancer listeners") - if _, err := c.elb.DeleteLoadBalancerListeners(request); err != nil { + if _, err := c.elb.DeleteLoadBalancerListeners(ctx, request); err != nil { return nil, fmt.Errorf("error deleting AWS loadbalancer listeners: %q", err) } dirty = true @@ -1114,7 +1119,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala request.LoadBalancerName = aws.String(loadBalancerName) request.Listeners = additions klog.V(2).Info("Creating added load balancer listeners") - if _, err := c.elb.CreateLoadBalancerListeners(request); err != nil { + if _, err := c.elb.CreateLoadBalancerListeners(ctx, request); err != nil { return nil, fmt.Errorf("error creating AWS loadbalancer listeners: %q", err) } dirty = true @@ -1132,7 +1137,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala // back if a policy of the same name already exists. However, the aws-sdk does not // seem to return an error to us in these cases. Therefore, this will issue an API // request every time. - err := c.createProxyProtocolPolicy(loadBalancerName) + err := c.createProxyProtocolPolicy(ctx, loadBalancerName) if err != nil { return nil, err } @@ -1140,11 +1145,11 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala proxyPolicies = append(proxyPolicies, aws.String(ProxyProtocolPolicyName)) } - foundBackends := make(map[int64]bool) - proxyProtocolBackends := make(map[int64]bool) + foundBackends := make(map[int32]bool) + proxyProtocolBackends := make(map[int32]bool) for _, backendListener := range loadBalancer.BackendServerDescriptions { - foundBackends[*backendListener.InstancePort] = false - proxyProtocolBackends[*backendListener.InstancePort] = proxyProtocolEnabled(backendListener) + foundBackends[aws.ToInt32(backendListener.InstancePort)] = false + proxyProtocolBackends[aws.ToInt32(backendListener.InstancePort)] = proxyProtocolEnabled(backendListener) } for _, listener := range listeners { @@ -1165,7 +1170,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala if setPolicy { klog.V(2).Infof("Adjusting AWS loadbalancer proxy protocol on node port %d. Setting to %t", instancePort, proxyProtocol) - err := c.setBackendPolicies(loadBalancerName, instancePort, proxyPolicies) + err := c.setBackendPolicies(ctx, loadBalancerName, aws.Int32(instancePort), proxyPolicies) if err != nil { return nil, err } @@ -1179,7 +1184,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala for instancePort, found := range foundBackends { if !found { klog.V(2).Infof("Adjusting AWS loadbalancer proxy protocol on node port %d. Setting to false", instancePort) - err := c.setBackendPolicies(loadBalancerName, instancePort, []*string{}) + err := c.setBackendPolicies(ctx, loadBalancerName, aws.Int32(instancePort), []*string{}) if err != nil { return nil, err } @@ -1193,7 +1198,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala klog.V(2).Infof("Creating additional load balancer tags for %s", loadBalancerName) tags := getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags) if len(tags) > 0 { - err := c.addLoadBalancerTags(loadBalancerName, tags) + err := c.addLoadBalancerTags(ctx, loadBalancerName, tags) if err != nil { return nil, fmt.Errorf("unable to create additional load balancer tags: %v", err) } @@ -1207,7 +1212,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala { describeAttributesRequest := &elb.DescribeLoadBalancerAttributesInput{} describeAttributesRequest.LoadBalancerName = aws.String(loadBalancerName) - describeAttributesOutput, err := c.elb.DescribeLoadBalancerAttributes(describeAttributesRequest) + describeAttributesOutput, err := c.elb.DescribeLoadBalancerAttributes(ctx, describeAttributesRequest) if err != nil { klog.Warning("Unable to retrieve load balancer attributes during attribute sync") return nil, err @@ -1222,7 +1227,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala modifyAttributesRequest := &elb.ModifyLoadBalancerAttributesInput{} modifyAttributesRequest.LoadBalancerName = aws.String(loadBalancerName) modifyAttributesRequest.LoadBalancerAttributes = loadBalancerAttributes - _, err = c.elb.ModifyLoadBalancerAttributes(modifyAttributesRequest) + _, err = c.elb.ModifyLoadBalancerAttributes(ctx, modifyAttributesRequest) if err != nil { return nil, fmt.Errorf("Unable to update load balancer attributes during attribute sync: %q", err) } @@ -1231,7 +1236,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala } if dirty { - loadBalancer, err = c.describeLoadBalancer(loadBalancerName) + loadBalancer, err = c.describeLoadBalancer(ctx, loadBalancerName) if err != nil { klog.Warning("Unable to retrieve load balancer after creation/update") return nil, err @@ -1245,10 +1250,10 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala // NOTE: there exists an O(nlgn) implementation for this function. However, as the default limit of // // listeners per elb is 100, this implementation is reduced from O(m*n) => O(n). -func syncElbListeners(loadBalancerName string, listeners []*elb.Listener, listenerDescriptions []*elb.ListenerDescription) ([]*elb.Listener, []*int64) { +func syncElbListeners(loadBalancerName string, listeners []elbtypes.Listener, listenerDescriptions []elbtypes.ListenerDescription) ([]elbtypes.Listener, []int32) { foundSet := make(map[int]bool) - removals := []*int64{} - additions := []*elb.Listener{} + removals := []int32{} + additions := []elbtypes.Listener{} for _, listenerDescription := range listenerDescriptions { actual := listenerDescription.Listener @@ -1259,11 +1264,7 @@ func syncElbListeners(loadBalancerName string, listeners []*elb.Listener, listen found := false for i, expected := range listeners { - if expected == nil { - klog.Warning("Ignoring empty desired listener for loadbalancer: ", loadBalancerName) - continue - } - if elbListenersAreEqual(actual, expected) { + if elbListenersAreEqual(*actual, expected) { // The current listener on the actual // elb is in the set of desired listeners. foundSet[i] = true @@ -1285,17 +1286,17 @@ func syncElbListeners(loadBalancerName string, listeners []*elb.Listener, listen return additions, removals } -func elbListenersAreEqual(actual, expected *elb.Listener) bool { +func elbListenersAreEqual(actual, expected elbtypes.Listener) bool { if !elbProtocolsAreEqual(actual.Protocol, expected.Protocol) { return false } if !elbProtocolsAreEqual(actual.InstanceProtocol, expected.InstanceProtocol) { return false } - if aws.Int64Value(actual.InstancePort) != aws.Int64Value(expected.InstancePort) { + if aws.ToInt32(actual.InstancePort) != aws.ToInt32(expected.InstancePort) { return false } - if aws.Int64Value(actual.LoadBalancerPort) != aws.Int64Value(expected.LoadBalancerPort) { + if actual.LoadBalancerPort != expected.LoadBalancerPort { return false } if !awsArnEquals(actual.SSLCertificateId, expected.SSLCertificateId) { @@ -1304,11 +1305,11 @@ func elbListenersAreEqual(actual, expected *elb.Listener) bool { return true } -func createSubnetMappings(subnetIDs []string, allocationIDs []string) []*elbv2.SubnetMapping { - response := []*elbv2.SubnetMapping{} +func createSubnetMappings(subnetIDs []string, allocationIDs []string) []elbv2types.SubnetMapping { + response := []elbv2types.SubnetMapping{} for index, id := range subnetIDs { - sm := &elbv2.SubnetMapping{SubnetId: aws.String(id)} + sm := elbv2types.SubnetMapping{SubnetId: aws.String(id)} if len(allocationIDs) > 0 { sm.AllocationId = aws.String(allocationIDs[index]) } @@ -1324,7 +1325,7 @@ func elbProtocolsAreEqual(l, r *string) bool { if l == nil || r == nil { return l == r } - return strings.EqualFold(aws.StringValue(l), aws.StringValue(r)) + return strings.EqualFold(aws.ToString(l), aws.ToString(r)) } // awsArnEquals checks if two ARN strings are considered the same @@ -1333,23 +1334,23 @@ func awsArnEquals(l, r *string) bool { if l == nil || r == nil { return l == r } - return strings.EqualFold(aws.StringValue(l), aws.StringValue(r)) + return strings.EqualFold(aws.ToString(l), aws.ToString(r)) } // getExpectedHealthCheck returns an elb.Healthcheck for the provided target // and using either sensible defaults or overrides via Service annotations -func (c *Cloud) getExpectedHealthCheck(target string, annotations map[string]string) (*elb.HealthCheck, error) { - healthcheck := &elb.HealthCheck{Target: &target} - getOrDefault := func(annotation string, defaultValue int64) (*int64, error) { - i64 := defaultValue - var err error +func (c *Cloud) getExpectedHealthCheck(target string, annotations map[string]string) (*elbtypes.HealthCheck, error) { + healthcheck := &elbtypes.HealthCheck{Target: &target} + getOrDefault := func(annotation string, defaultValue int32) (*int32, error) { + i32 := defaultValue if s, ok := annotations[annotation]; ok { - i64, err = strconv.ParseInt(s, 10, 0) + i64, err := strconv.ParseInt(s, 10, 0) if err != nil { return nil, fmt.Errorf("failed parsing health check annotation value: %v", err) } + i32 = int32(i64) } - return &i64, nil + return &i32, nil } var err error healthcheck.HealthyThreshold, err = getOrDefault(ServiceAnnotationLoadBalancerHCHealthyThreshold, defaultElbHCHealthyThreshold) @@ -1368,15 +1369,15 @@ func (c *Cloud) getExpectedHealthCheck(target string, annotations map[string]str if err != nil { return nil, err } - if err = healthcheck.Validate(); err != nil { + if err = ValidateHealthCheck(healthcheck); err != nil { return nil, fmt.Errorf("some of the load balancer health check parameters are invalid: %v", err) } return healthcheck, nil } // Makes sure that the health check for an ELB matches the configured health check node port -func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDescription, protocol string, port int32, path string, annotations map[string]string) error { - name := aws.StringValue(loadBalancer.LoadBalancerName) +func (c *Cloud) ensureLoadBalancerHealthCheck(ctx context.Context, loadBalancer *elbtypes.LoadBalancerDescription, protocol string, port int32, path string, annotations map[string]string) error { + name := aws.ToString(loadBalancer.LoadBalancerName) actual := loadBalancer.HealthCheck // Override healthcheck protocol, port and path based on annotations @@ -1409,11 +1410,11 @@ func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDesc // comparing attributes 1 by 1 to avoid breakage in case a new field is // added to the HC which breaks the equality - if aws.StringValue(expected.Target) == aws.StringValue(actual.Target) && - aws.Int64Value(expected.HealthyThreshold) == aws.Int64Value(actual.HealthyThreshold) && - aws.Int64Value(expected.UnhealthyThreshold) == aws.Int64Value(actual.UnhealthyThreshold) && - aws.Int64Value(expected.Interval) == aws.Int64Value(actual.Interval) && - aws.Int64Value(expected.Timeout) == aws.Int64Value(actual.Timeout) { + if aws.ToString(expected.Target) == aws.ToString(actual.Target) && + aws.ToInt32(expected.HealthyThreshold) == aws.ToInt32(actual.HealthyThreshold) && + aws.ToInt32(expected.UnhealthyThreshold) == aws.ToInt32(actual.UnhealthyThreshold) && + aws.ToInt32(expected.Interval) == aws.ToInt32(actual.Interval) && + aws.ToInt32(expected.Timeout) == aws.ToInt32(actual.Timeout) { return nil } @@ -1421,7 +1422,7 @@ func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDesc request.HealthCheck = expected request.LoadBalancerName = loadBalancer.LoadBalancerName - _, err = c.elb.ConfigureHealthCheck(request) + _, err = c.elb.ConfigureHealthCheck(ctx, request) if err != nil { return fmt.Errorf("error configuring load balancer health check for %q: %q", name, err) } @@ -1430,7 +1431,7 @@ func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDesc } // Makes sure that exactly the specified hosts are registered as instances with the load balancer -func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances []*elb.Instance, instanceIDs map[InstanceID]*ec2.Instance) error { +func (c *Cloud) ensureLoadBalancerInstances(ctx context.Context, loadBalancerName string, lbInstances []elbtypes.Instance, instanceIDs map[InstanceID]*ec2types.Instance) error { expected := sets.NewString() for id := range instanceIDs { expected.Insert(string(id)) @@ -1438,22 +1439,22 @@ func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances actual := sets.NewString() for _, lbInstance := range lbInstances { - actual.Insert(aws.StringValue(lbInstance.InstanceId)) + actual.Insert(aws.ToString(lbInstance.InstanceId)) } additions := expected.Difference(actual) removals := actual.Difference(expected) - addInstances := []*elb.Instance{} + addInstances := []elbtypes.Instance{} for _, instanceID := range additions.List() { - addInstance := &elb.Instance{} + addInstance := elbtypes.Instance{} addInstance.InstanceId = aws.String(instanceID) addInstances = append(addInstances, addInstance) } - removeInstances := []*elb.Instance{} + removeInstances := []elbtypes.Instance{} for _, instanceID := range removals.List() { - removeInstance := &elb.Instance{} + removeInstance := elbtypes.Instance{} removeInstance.InstanceId = aws.String(instanceID) removeInstances = append(removeInstances, removeInstance) } @@ -1462,7 +1463,7 @@ func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances registerRequest := &elb.RegisterInstancesWithLoadBalancerInput{} registerRequest.Instances = addInstances registerRequest.LoadBalancerName = aws.String(loadBalancerName) - _, err := c.elb.RegisterInstancesWithLoadBalancer(registerRequest) + _, err := c.elb.RegisterInstancesWithLoadBalancer(ctx, registerRequest) if err != nil { return err } @@ -1473,7 +1474,7 @@ func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances deregisterRequest := &elb.DeregisterInstancesFromLoadBalancerInput{} deregisterRequest.Instances = removeInstances deregisterRequest.LoadBalancerName = aws.String(loadBalancerName) - _, err := c.elb.DeregisterInstancesFromLoadBalancer(deregisterRequest) + _, err := c.elb.DeregisterInstancesFromLoadBalancer(ctx, deregisterRequest) if err != nil { return err } @@ -1483,33 +1484,30 @@ func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances return nil } -func (c *Cloud) getLoadBalancerTLSPorts(loadBalancer *elb.LoadBalancerDescription) []int64 { +func (c *Cloud) getLoadBalancerTLSPorts(loadBalancer *elbtypes.LoadBalancerDescription) []int64 { ports := []int64{} for _, listenerDescription := range loadBalancer.ListenerDescriptions { - protocol := aws.StringValue(listenerDescription.Listener.Protocol) + protocol := aws.ToString(listenerDescription.Listener.Protocol) if protocol == "SSL" || protocol == "HTTPS" { - ports = append(ports, aws.Int64Value(listenerDescription.Listener.LoadBalancerPort)) + ports = append(ports, int64(listenerDescription.Listener.LoadBalancerPort)) } } return ports } -func (c *Cloud) ensureSSLNegotiationPolicy(loadBalancer *elb.LoadBalancerDescription, policyName string) error { +func (c *Cloud) ensureSSLNegotiationPolicy(ctx context.Context, loadBalancer *elbtypes.LoadBalancerDescription, policyName string) error { klog.V(2).Info("Describing load balancer policies on load balancer") - result, err := c.elb.DescribeLoadBalancerPolicies(&elb.DescribeLoadBalancerPoliciesInput{ + result, err := c.elb.DescribeLoadBalancerPolicies(ctx, &elb.DescribeLoadBalancerPoliciesInput{ LoadBalancerName: loadBalancer.LoadBalancerName, - PolicyNames: []*string{ - aws.String(fmt.Sprintf(SSLNegotiationPolicyNameFormat, policyName)), + PolicyNames: []string{ + fmt.Sprintf(SSLNegotiationPolicyNameFormat, policyName), }, }) if err != nil { - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case elb.ErrCodePolicyNotFoundException: - default: - return fmt.Errorf("error describing security policies on load balancer: %q", err) - } + var notFoundErr *elbtypes.PolicyNotFoundException + if !errors.As(err, ¬FoundErr) { + return fmt.Errorf("error describing security policies on load balancer: %q", err) } } @@ -1520,11 +1518,11 @@ func (c *Cloud) ensureSSLNegotiationPolicy(loadBalancer *elb.LoadBalancerDescrip klog.V(2).Infof("Creating SSL negotiation policy '%s' on load balancer", fmt.Sprintf(SSLNegotiationPolicyNameFormat, policyName)) // there is an upper limit of 98 policies on an ELB, we're pretty safe from // running into it - _, err = c.elb.CreateLoadBalancerPolicy(&elb.CreateLoadBalancerPolicyInput{ + _, err = c.elb.CreateLoadBalancerPolicy(ctx, &elb.CreateLoadBalancerPolicyInput{ LoadBalancerName: loadBalancer.LoadBalancerName, PolicyName: aws.String(fmt.Sprintf(SSLNegotiationPolicyNameFormat, policyName)), PolicyTypeName: aws.String("SSLNegotiationPolicyType"), - PolicyAttributes: []*elb.PolicyAttribute{ + PolicyAttributes: []elbtypes.PolicyAttribute{ { AttributeName: aws.String("Reference-Security-Policy"), AttributeValue: aws.String(policyName), @@ -1537,29 +1535,27 @@ func (c *Cloud) ensureSSLNegotiationPolicy(loadBalancer *elb.LoadBalancerDescrip return nil } -func (c *Cloud) setSSLNegotiationPolicy(loadBalancerName, sslPolicyName string, port int64) error { +func (c *Cloud) setSSLNegotiationPolicy(ctx context.Context, loadBalancerName, sslPolicyName string, port int64) error { policyName := fmt.Sprintf(SSLNegotiationPolicyNameFormat, sslPolicyName) request := &elb.SetLoadBalancerPoliciesOfListenerInput{ LoadBalancerName: aws.String(loadBalancerName), - LoadBalancerPort: aws.Int64(port), - PolicyNames: []*string{ - aws.String(policyName), - }, + LoadBalancerPort: int32(port), + PolicyNames: []string{policyName}, } klog.V(2).Infof("Setting SSL negotiation policy '%s' on load balancer", policyName) - _, err := c.elb.SetLoadBalancerPoliciesOfListener(request) + _, err := c.elb.SetLoadBalancerPoliciesOfListener(ctx, request) if err != nil { return fmt.Errorf("error setting SSL negotiation policy '%s' on load balancer: %q", policyName, err) } return nil } -func (c *Cloud) createProxyProtocolPolicy(loadBalancerName string) error { +func (c *Cloud) createProxyProtocolPolicy(ctx context.Context, loadBalancerName string) error { request := &elb.CreateLoadBalancerPolicyInput{ LoadBalancerName: aws.String(loadBalancerName), PolicyName: aws.String(ProxyProtocolPolicyName), PolicyTypeName: aws.String("ProxyProtocolPolicyType"), - PolicyAttributes: []*elb.PolicyAttribute{ + PolicyAttributes: []elbtypes.PolicyAttribute{ { AttributeName: aws.String("ProxyProtocol"), AttributeValue: aws.String("true"), @@ -1567,7 +1563,7 @@ func (c *Cloud) createProxyProtocolPolicy(loadBalancerName string) error { }, } klog.V(2).Info("Creating proxy protocol policy on load balancer") - _, err := c.elb.CreateLoadBalancerPolicy(request) + _, err := c.elb.CreateLoadBalancerPolicy(ctx, request) if err != nil { return fmt.Errorf("error creating proxy protocol policy on load balancer: %q", err) } @@ -1575,18 +1571,18 @@ func (c *Cloud) createProxyProtocolPolicy(loadBalancerName string) error { return nil } -func (c *Cloud) setBackendPolicies(loadBalancerName string, instancePort int64, policies []*string) error { +func (c *Cloud) setBackendPolicies(ctx context.Context, loadBalancerName string, instancePort *int32, policies []*string) error { request := &elb.SetLoadBalancerPoliciesForBackendServerInput{ - InstancePort: aws.Int64(instancePort), + InstancePort: instancePort, LoadBalancerName: aws.String(loadBalancerName), - PolicyNames: policies, + PolicyNames: aws.ToStringSlice(policies), } if len(policies) > 0 { klog.V(2).Infof("Adding AWS loadbalancer backend policies on node port %d", instancePort) } else { klog.V(2).Infof("Removing AWS loadbalancer backend policies on node port %d", instancePort) } - _, err := c.elb.SetLoadBalancerPoliciesForBackendServer(request) + _, err := c.elb.SetLoadBalancerPoliciesForBackendServer(ctx, request) if err != nil { return fmt.Errorf("error adjusting AWS loadbalancer backend policies: %q", err) } @@ -1594,9 +1590,9 @@ func (c *Cloud) setBackendPolicies(loadBalancerName string, instancePort int64, return nil } -func proxyProtocolEnabled(backend *elb.BackendServerDescription) bool { +func proxyProtocolEnabled(backend elbtypes.BackendServerDescription) bool { for _, policy := range backend.PolicyNames { - if aws.StringValue(policy) == ProxyProtocolPolicyName { + if policy == ProxyProtocolPolicyName { return true } } @@ -1607,7 +1603,7 @@ func proxyProtocolEnabled(backend *elb.BackendServerDescription) bool { // findInstancesForELB gets the EC2 instances corresponding to the Nodes, for setting up an ELB // We ignore Nodes (with a log message) where the instanceid cannot be determined from the provider, // and we ignore instances which are not found -func (c *Cloud) findInstancesForELB(nodes []*v1.Node, annotations map[string]string) (map[InstanceID]*ec2.Instance, error) { +func (c *Cloud) findInstancesForELB(ctx context.Context, nodes []*v1.Node, annotations map[string]string) (map[InstanceID]*ec2types.Instance, error) { targetNodes := filterTargetNodes(nodes, annotations) @@ -1618,7 +1614,7 @@ func (c *Cloud) findInstancesForELB(nodes []*v1.Node, annotations map[string]str MaxAge: defaultEC2InstanceCacheMaxAge, HasInstances: instanceIDs, // Refresh if any of the instance ids are missing } - snapshot, err := c.instanceCache.describeAllInstancesCached(cacheCriteria) + snapshot, err := c.instanceCache.describeAllInstancesCached(ctx, cacheCriteria) if err != nil { return nil, err } @@ -1660,3 +1656,48 @@ func filterTargetNodes(nodes []*v1.Node, annotations map[string]string) []*v1.No return targetNodes } + +// ValidateHealthCheck replaces ELB.HealthCheck.Validate() from AWS SDK Go V1, which has been deprecated in V2 +// V1 implementation: https://github.com/aws/aws-sdk-go/blob/v1.55.7/service/elb/api.go#L5346 +func ValidateHealthCheck(s *elbtypes.HealthCheck) error { + var validationErrors []string + + if s == nil { + validationErrors = append(validationErrors, "HealthCheck is nil") + return fmt.Errorf("HealthCheck validation errors: %s", strings.Join(validationErrors, "; ")) + } + + if s.HealthyThreshold == nil { + validationErrors = append(validationErrors, "HealthyThreshold is required") + } else if *s.HealthyThreshold < 2 { + validationErrors = append(validationErrors, "HealthyThreshold must be at least 2") + } + + if s.Interval == nil { + validationErrors = append(validationErrors, "Interval is required") + } else if *s.Interval < 5 { + validationErrors = append(validationErrors, "Interval must be at least 5") + } + + if s.Target == nil { + validationErrors = append(validationErrors, "Target is required") + } + + if s.Timeout == nil { + validationErrors = append(validationErrors, "Timeout is required") + } else if *s.Timeout < 2 { + validationErrors = append(validationErrors, "Timeout must be at least 2") + } + + if s.UnhealthyThreshold == nil { + validationErrors = append(validationErrors, "UnhealthyThreshold is required") + } else if *s.UnhealthyThreshold < 2 { + validationErrors = append(validationErrors, "UnhealthyThreshold must be at least 2") + } + + if len(validationErrors) > 0 { + return fmt.Errorf("HealthCheck validation errors: %s", strings.Join(validationErrors, "; ")) + } + + return nil +} diff --git a/pkg/providers/v1/aws_loadbalancer_test.go b/pkg/providers/v1/aws_loadbalancer_test.go index 866742c3d9..3aec714a6d 100644 --- a/pkg/providers/v1/aws_loadbalancer_test.go +++ b/pkg/providers/v1/aws_loadbalancer_test.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "context" "fmt" "reflect" "testing" @@ -26,10 +27,10 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/stretchr/testify/assert" ) @@ -213,76 +214,70 @@ func TestSyncElbListeners(t *testing.T) { tests := []struct { name string loadBalancerName string - listeners []*elb.Listener - listenerDescriptions []*elb.ListenerDescription - toCreate []*elb.Listener - toDelete []*int64 + listeners []elbtypes.Listener + listenerDescriptions []elbtypes.ListenerDescription + toCreate []elbtypes.Listener + toDelete []int32 }{ { name: "no edge cases", loadBalancerName: "lb_one", - listeners: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, - {InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, - {InstancePort: aws.Int64(8443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(8443), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + listeners: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, + {InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + {InstancePort: aws.Int32(8443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 8443, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, }, - listenerDescriptions: []*elb.ListenerDescription{ - {Listener: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}}, - {Listener: &elb.Listener{InstancePort: aws.Int64(8443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(8443), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}}, + listenerDescriptions: []elbtypes.ListenerDescription{ + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}}, + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(8443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 8443, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}}, }, - toDelete: []*int64{ - aws.Int64(80), - }, - toCreate: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, - {InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + toDelete: []int32{80}, + toCreate: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, + {InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, }, }, { name: "no listeners to delete", loadBalancerName: "lb_two", - listeners: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, - {InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + listeners: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, + {InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, }, - listenerDescriptions: []*elb.ListenerDescription{ - {Listener: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, + listenerDescriptions: []elbtypes.ListenerDescription{ + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, }, - toCreate: []*elb.Listener{ - {InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, + toCreate: []elbtypes.Listener{ + {InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP"), SSLCertificateId: aws.String("def-456")}, }, - toDelete: []*int64{}, + toDelete: []int32{}, }, { name: "no listeners to create", loadBalancerName: "lb_three", - listeners: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, - }, - listenerDescriptions: []*elb.ListenerDescription{ - {Listener: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}}, - {Listener: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, + listeners: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}, }, - toDelete: []*int64{ - aws.Int64(80), + listenerDescriptions: []elbtypes.ListenerDescription{ + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}}, + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, }, - toCreate: []*elb.Listener{}, + toDelete: []int32{80}, + toCreate: []elbtypes.Listener{}, }, { name: "nil actual listener", loadBalancerName: "lb_four", - listeners: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP")}, + listeners: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP")}, }, - listenerDescriptions: []*elb.ListenerDescription{ - {Listener: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, + listenerDescriptions: []elbtypes.ListenerDescription{ + {Listener: &elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP"), SSLCertificateId: aws.String("abc-123")}}, {Listener: nil}, }, - toDelete: []*int64{ - aws.Int64(443), - }, - toCreate: []*elb.Listener{ - {InstancePort: aws.Int64(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("HTTP")}, + toDelete: []int32{443}, + toCreate: []elbtypes.Listener{ + {InstancePort: aws.Int32(443), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 443, Protocol: aws.String("HTTP")}, }, }, } @@ -299,37 +294,37 @@ func TestSyncElbListeners(t *testing.T) { func TestElbListenersAreEqual(t *testing.T) { tests := []struct { name string - expected, actual *elb.Listener + expected, actual elbtypes.Listener equal bool }{ { name: "should be equal", - expected: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, equal: true, }, { name: "instance port should be different", - expected: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, equal: false, }, { name: "instance protocol should be different", - expected: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("HTTP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(80), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, equal: false, }, { name: "load balancer port should be different", - expected: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(443), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 443, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, equal: false, }, { name: "protocol should be different", - expected: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("TCP")}, - actual: &elb.Listener{InstancePort: aws.Int64(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: aws.Int64(80), Protocol: aws.String("HTTP")}, + expected: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("TCP")}, + actual: elbtypes.Listener{InstancePort: aws.Int32(443), InstanceProtocol: aws.String("TCP"), LoadBalancerPort: 80, Protocol: aws.String("HTTP")}, equal: false, }, } @@ -344,10 +339,10 @@ func TestElbListenersAreEqual(t *testing.T) { func TestBuildTargetGroupName(t *testing.T) { type args struct { serviceName types.NamespacedName - servicePort int64 - nodePort int64 - targetProtocol string - targetType string + servicePort int32 + nodePort int32 + targetProtocol elbv2types.ProtocolEnum + targetType elbv2types.TargetTypeEnum nlbConfig nlbPortMapping } tests := []struct { @@ -363,8 +358,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-7fa2e07508", @@ -376,8 +371,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-719ee635da", @@ -389,8 +384,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "another", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-another-servicea-f66e09847d", @@ -402,8 +397,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-b"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-serviceb-196c19c881", @@ -415,8 +410,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 9090, nodePort: 8080, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-06876706cb", @@ -428,8 +423,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 9090, - targetProtocol: "TCP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-119f844ec0", @@ -441,8 +436,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "UDP", - targetType: "instance", + targetProtocol: elbv2types.ProtocolEnumUdp, + targetType: elbv2types.TargetTypeEnumInstance, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-3868761686", @@ -454,8 +449,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "ip", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumIp, nlbConfig: nlbPortMapping{}, }, want: "k8s-default-servicea-0fa31f4b0f", @@ -467,8 +462,8 @@ func TestBuildTargetGroupName(t *testing.T) { serviceName: types.NamespacedName{Namespace: "default", Name: "service-a"}, servicePort: 80, nodePort: 8080, - targetProtocol: "TCP", - targetType: "ip", + targetProtocol: elbv2types.ProtocolEnumTcp, + targetType: elbv2types.TargetTypeEnumIp, nlbConfig: nlbPortMapping{ HealthCheckConfig: healthCheckConfig{ Protocol: "HTTP", @@ -545,11 +540,11 @@ func TestFilterTargetNodes(t *testing.T) { } } -func makeNodeInstancePair(offset int) (*v1.Node, *ec2.Instance) { +func makeNodeInstancePair(offset int) (*v1.Node, *ec2types.Instance) { instanceID := fmt.Sprintf("i-%x", int64(0x03bcc3496da09f78e)+int64(offset)) - instance := &ec2.Instance{ + instance := &ec2types.Instance{ InstanceId: aws.String(instanceID), - Placement: &ec2.Placement{ + Placement: &ec2types.Placement{ AvailabilityZone: aws.String("us-east-1b"), }, PrivateDnsName: aws.String(fmt.Sprintf("ip-192-168-32-%d.ec2.internal", 101+offset)), @@ -557,10 +552,10 @@ func makeNodeInstancePair(offset int) (*v1.Node, *ec2.Instance) { PublicIpAddress: aws.String(fmt.Sprintf("1.2.3.%d", 1+offset)), } - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(fmt.Sprintf("%s%s", TagNameKubernetesClusterPrefix, TestClusterID)) tag.Value = aws.String("owned") - instance.Tags = []*ec2.Tag{&tag} + instance.Tags = []ec2types.Tag{tag} node := &v1.Node{ ObjectMeta: metav1.ObjectMeta{ @@ -584,32 +579,32 @@ func TestCloud_findInstancesForELB(t *testing.T) { } newNode, newInstance := makeNodeInstancePair(1) awsServices := NewFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return } - want := map[InstanceID]*ec2.Instance{ + want := map[InstanceID]*ec2types.Instance{ "i-self": awsServices.selfInstance, } - got, err := c.findInstancesForELB([]*v1.Node{defaultNode}, nil) + got, err := c.findInstancesForELB(context.TODO(), []*v1.Node{defaultNode}, nil) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(want, got)) // Add a new EC2 instance awsServices.instances = append(awsServices.instances, newInstance) - want = map[InstanceID]*ec2.Instance{ + want = map[InstanceID]*ec2types.Instance{ "i-self": awsServices.selfInstance, - InstanceID(aws.StringValue(newInstance.InstanceId)): newInstance, + InstanceID(aws.ToString(newInstance.InstanceId)): newInstance, } - got, err = c.findInstancesForELB([]*v1.Node{defaultNode, newNode}, nil) + got, err = c.findInstancesForELB(context.TODO(), []*v1.Node{defaultNode, newNode}, nil) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(want, got)) // Verify existing instance cache gets used cacheExpiryOld := c.instanceCache.snapshot.timestamp - got, err = c.findInstancesForELB([]*v1.Node{defaultNode, newNode}, nil) + got, err = c.findInstancesForELB(context.TODO(), []*v1.Node{defaultNode, newNode}, nil) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(want, got)) cacheExpiryNew := c.instanceCache.snapshot.timestamp @@ -618,7 +613,7 @@ func TestCloud_findInstancesForELB(t *testing.T) { // Force cache expiry and verify cache gets updated with new timestamp cacheExpiryOld = c.instanceCache.snapshot.timestamp c.instanceCache.snapshot.timestamp = c.instanceCache.snapshot.timestamp.Add(-(defaultEC2InstanceCacheMaxAge + 1*time.Second)) - got, err = c.findInstancesForELB([]*v1.Node{defaultNode, newNode}, nil) + got, err = c.findInstancesForELB(context.TODO(), []*v1.Node{defaultNode, newNode}, nil) assert.NoError(t, err) assert.True(t, reflect.DeepEqual(want, got)) cacheExpiryNew = c.instanceCache.snapshot.timestamp @@ -627,56 +622,56 @@ func TestCloud_findInstancesForELB(t *testing.T) { func TestCloud_chunkTargetDescriptions(t *testing.T) { type args struct { - targets []*elbv2.TargetDescription + targets []elbv2types.TargetDescription chunkSize int } tests := []struct { name string args args - want [][]*elbv2.TargetDescription + want [][]elbv2types.TargetDescription }{ { name: "can be evenly chunked", args: args{ - targets: []*elbv2.TargetDescription{ + targets: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, chunkSize: 2, }, - want: [][]*elbv2.TargetDescription{ + want: [][]elbv2types.TargetDescription{ { { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, { { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -684,46 +679,46 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { { name: "cannot be evenly chunked", args: args{ - targets: []*elbv2.TargetDescription{ + targets: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, chunkSize: 3, }, - want: [][]*elbv2.TargetDescription{ + want: [][]elbv2types.TargetDescription{ { { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, { { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -731,43 +726,43 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { { name: "chunkSize equal to total count", args: args{ - targets: []*elbv2.TargetDescription{ + targets: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, chunkSize: 4, }, - want: [][]*elbv2.TargetDescription{ + want: [][]elbv2types.TargetDescription{ { { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -775,43 +770,43 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { { name: "chunkSize greater than total count", args: args{ - targets: []*elbv2.TargetDescription{ + targets: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, chunkSize: 10, }, - want: [][]*elbv2.TargetDescription{ + want: [][]elbv2types.TargetDescription{ { { Id: aws.String("i-abcdefg1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdefg4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -827,7 +822,7 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { { name: "chunk empty slice", args: args{ - targets: []*elbv2.TargetDescription{}, + targets: []elbv2types.TargetDescription{}, chunkSize: 2, }, want: nil, @@ -844,38 +839,38 @@ func TestCloud_chunkTargetDescriptions(t *testing.T) { func TestCloud_diffTargetGroupTargets(t *testing.T) { type args struct { - expectedTargets []*elbv2.TargetDescription - actualTargets []*elbv2.TargetDescription + expectedTargets []*elbv2types.TargetDescription + actualTargets []*elbv2types.TargetDescription } tests := []struct { name string args args - wantTargetsToRegister []*elbv2.TargetDescription - wantTargetsToDeregister []*elbv2.TargetDescription + wantTargetsToRegister []elbv2types.TargetDescription + wantTargetsToDeregister []elbv2types.TargetDescription }{ { name: "all targets to register", args: args{ - expectedTargets: []*elbv2.TargetDescription{ + expectedTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, actualTargets: nil, }, - wantTargetsToRegister: []*elbv2.TargetDescription{ + wantTargetsToRegister: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, wantTargetsToDeregister: nil, @@ -884,79 +879,79 @@ func TestCloud_diffTargetGroupTargets(t *testing.T) { name: "all targets to deregister", args: args{ expectedTargets: nil, - actualTargets: []*elbv2.TargetDescription{ + actualTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, wantTargetsToRegister: nil, - wantTargetsToDeregister: []*elbv2.TargetDescription{ + wantTargetsToDeregister: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, { name: "some targets to register and deregister", args: args{ - expectedTargets: []*elbv2.TargetDescription{ + expectedTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef5"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, - actualTargets: []*elbv2.TargetDescription{ + actualTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, - wantTargetsToRegister: []*elbv2.TargetDescription{ + wantTargetsToRegister: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdef4"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef5"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, - wantTargetsToDeregister: []*elbv2.TargetDescription{ + wantTargetsToDeregister: []elbv2types.TargetDescription{ { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -972,32 +967,32 @@ func TestCloud_diffTargetGroupTargets(t *testing.T) { { name: "expected and actual targets equals", args: args{ - expectedTargets: []*elbv2.TargetDescription{ + expectedTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, - actualTargets: []*elbv2.TargetDescription{ + actualTargets: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -1018,12 +1013,12 @@ func TestCloud_diffTargetGroupTargets(t *testing.T) { func TestCloud_computeTargetGroupExpectedTargets(t *testing.T) { type args struct { instanceIDs []string - port int64 + port int32 } tests := []struct { name string args args - want []*elbv2.TargetDescription + want []*elbv2types.TargetDescription }{ { name: "no instance", @@ -1031,7 +1026,7 @@ func TestCloud_computeTargetGroupExpectedTargets(t *testing.T) { instanceIDs: nil, port: 8080, }, - want: []*elbv2.TargetDescription{}, + want: []*elbv2types.TargetDescription{}, }, { name: "one instance", @@ -1039,10 +1034,10 @@ func TestCloud_computeTargetGroupExpectedTargets(t *testing.T) { instanceIDs: []string{"i-abcdef1"}, port: 8080, }, - want: []*elbv2.TargetDescription{ + want: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, @@ -1052,18 +1047,18 @@ func TestCloud_computeTargetGroupExpectedTargets(t *testing.T) { instanceIDs: []string{"i-abcdef1", "i-abcdef2", "i-abcdef3"}, port: 8080, }, - want: []*elbv2.TargetDescription{ + want: []*elbv2types.TargetDescription{ { Id: aws.String("i-abcdef1"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef2"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, { Id: aws.String("i-abcdef3"), - Port: aws.Int64(8080), + Port: aws.Int32(8080), }, }, }, diff --git a/pkg/providers/v1/aws_routes.go b/pkg/providers/v1/aws_routes.go index e3e7c5b7a4..2fea34553e 100644 --- a/pkg/providers/v1/aws_routes.go +++ b/pkg/providers/v1/aws_routes.go @@ -20,22 +20,23 @@ import ( "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "k8s.io/klog/v2" cloudprovider "k8s.io/cloud-provider" ) -func (c *Cloud) findRouteTable(clusterName string) (*ec2.RouteTable, error) { +func (c *Cloud) findRouteTable(ctx context.Context, clusterName string) (*ec2types.RouteTable, error) { // This should be unnecessary (we already filter on TagNameKubernetesCluster, // and something is broken if cluster name doesn't match, but anyway... // TODO: All clouds should be cluster-aware by default - var tables []*ec2.RouteTable + var tables []ec2types.RouteTable if c.cfg.Global.RouteTableID != "" { - request := &ec2.DescribeRouteTablesInput{Filters: []*ec2.Filter{newEc2Filter("route-table-id", c.cfg.Global.RouteTableID)}} - response, err := c.ec2.DescribeRouteTables(request) + request := &ec2.DescribeRouteTablesInput{Filters: []ec2types.Filter{newEc2Filter("route-table-id", c.cfg.Global.RouteTableID)}} + response, err := c.ec2.DescribeRouteTables(ctx, request) if err != nil { return nil, err } @@ -43,7 +44,7 @@ func (c *Cloud) findRouteTable(clusterName string) (*ec2.RouteTable, error) { tables = response } else { request := &ec2.DescribeRouteTablesInput{} - response, err := c.ec2.DescribeRouteTables(request) + response, err := c.ec2.DescribeRouteTables(ctx, request) if err != nil { return nil, err } @@ -62,37 +63,37 @@ func (c *Cloud) findRouteTable(clusterName string) (*ec2.RouteTable, error) { if len(tables) != 1 { return nil, fmt.Errorf("found multiple matching AWS route tables for AWS cluster: %s", clusterName) } - return tables[0], nil + return &tables[0], nil } // ListRoutes implements Routes.ListRoutes // List all routes that match the filter func (c *Cloud) ListRoutes(ctx context.Context, clusterName string) ([]*cloudprovider.Route, error) { - table, err := c.findRouteTable(clusterName) + table, err := c.findRouteTable(ctx, clusterName) if err != nil { return nil, err } var routes []*cloudprovider.Route - var instanceIDs []*string + var instanceIDs []string for _, r := range table.Routes { - instanceID := aws.StringValue(r.InstanceId) + instanceID := aws.ToString(r.InstanceId) if instanceID == "" { continue } - instanceIDs = append(instanceIDs, &instanceID) + instanceIDs = append(instanceIDs, instanceID) } - instances, err := c.getInstancesByIDs(instanceIDs) + instances, err := c.getInstancesByIDs(ctx, instanceIDs) if err != nil { return nil, err } for _, r := range table.Routes { - destinationCIDR := aws.StringValue(r.DestinationCidrBlock) + destinationCIDR := aws.ToString(r.DestinationCidrBlock) if destinationCIDR == "" { continue } @@ -103,14 +104,14 @@ func (c *Cloud) ListRoutes(ctx context.Context, clusterName string) ([]*cloudpro } // Capture blackhole routes - if aws.StringValue(r.State) == ec2.RouteStateBlackhole { + if r.State == ec2types.RouteStateBlackhole { route.Blackhole = true routes = append(routes, route) continue } // Capture instance routes - instanceID := aws.StringValue(r.InstanceId) + instanceID := aws.ToString(r.InstanceId) if instanceID != "" { _, found := instances[instanceID] if found { @@ -130,12 +131,12 @@ func (c *Cloud) ListRoutes(ctx context.Context, clusterName string) ([]*cloudpro } // Sets the instance attribute "source-dest-check" to the specified value -func (c *Cloud) configureInstanceSourceDestCheck(instanceID string, sourceDestCheck bool) error { +func (c *Cloud) configureInstanceSourceDestCheck(ctx context.Context, instanceID string, sourceDestCheck bool) error { request := &ec2.ModifyInstanceAttributeInput{} request.InstanceId = aws.String(instanceID) - request.SourceDestCheck = &ec2.AttributeBooleanValue{Value: aws.Bool(sourceDestCheck)} + request.SourceDestCheck = &ec2types.AttributeBooleanValue{Value: aws.Bool(sourceDestCheck)} - _, err := c.ec2.ModifyInstanceAttribute(request) + _, err := c.ec2.ModifyInstanceAttribute(ctx, request) if err != nil { return fmt.Errorf("error configuring source-dest-check on instance %s: %q", instanceID, err) } @@ -145,46 +146,46 @@ func (c *Cloud) configureInstanceSourceDestCheck(instanceID string, sourceDestCh // CreateRoute implements Routes.CreateRoute // Create the described route func (c *Cloud) CreateRoute(ctx context.Context, clusterName string, nameHint string, route *cloudprovider.Route) error { - instance, err := c.getInstanceByNodeName(route.TargetNode) + instance, err := c.getInstanceByNodeName(ctx, route.TargetNode) if err != nil { return err } // In addition to configuring the route itself, we also need to configure the instance to accept that traffic // On AWS, this requires turning source-dest checks off - err = c.configureInstanceSourceDestCheck(aws.StringValue(instance.InstanceId), false) + err = c.configureInstanceSourceDestCheck(ctx, aws.ToString(instance.InstanceId), false) if err != nil { return err } - table, err := c.findRouteTable(clusterName) + table, err := c.findRouteTable(ctx, clusterName) if err != nil { return err } - var deleteRoute *ec2.Route + var deleteRoute *ec2types.Route for _, r := range table.Routes { - destinationCIDR := aws.StringValue(r.DestinationCidrBlock) + destinationCIDR := aws.ToString(r.DestinationCidrBlock) if destinationCIDR != route.DestinationCIDR { continue } - if aws.StringValue(r.State) == ec2.RouteStateBlackhole { - deleteRoute = r + if r.State == ec2types.RouteStateBlackhole { + deleteRoute = &r } } if deleteRoute != nil { - klog.Infof("deleting blackholed route: %s", aws.StringValue(deleteRoute.DestinationCidrBlock)) + klog.Infof("deleting blackholed route: %s", aws.ToString(deleteRoute.DestinationCidrBlock)) request := &ec2.DeleteRouteInput{} request.DestinationCidrBlock = deleteRoute.DestinationCidrBlock request.RouteTableId = table.RouteTableId - _, err = c.ec2.DeleteRoute(request) + _, err = c.ec2.DeleteRoute(ctx, request) if err != nil { - return fmt.Errorf("error deleting blackholed AWS route (%s): %q", aws.StringValue(deleteRoute.DestinationCidrBlock), err) + return fmt.Errorf("error deleting blackholed AWS route (%s): %q", aws.ToString(deleteRoute.DestinationCidrBlock), err) } } @@ -194,7 +195,7 @@ func (c *Cloud) CreateRoute(ctx context.Context, clusterName string, nameHint st request.InstanceId = instance.InstanceId request.RouteTableId = table.RouteTableId - _, err = c.ec2.CreateRoute(request) + _, err = c.ec2.CreateRoute(ctx, request) if err != nil { return fmt.Errorf("error creating AWS route (%s): %q", route.DestinationCIDR, err) } @@ -205,7 +206,7 @@ func (c *Cloud) CreateRoute(ctx context.Context, clusterName string, nameHint st // DeleteRoute implements Routes.DeleteRoute // Delete the specified route func (c *Cloud) DeleteRoute(ctx context.Context, clusterName string, route *cloudprovider.Route) error { - table, err := c.findRouteTable(clusterName) + table, err := c.findRouteTable(ctx, clusterName) if err != nil { return err } @@ -214,7 +215,7 @@ func (c *Cloud) DeleteRoute(ctx context.Context, clusterName string, route *clou request.DestinationCidrBlock = aws.String(route.DestinationCIDR) request.RouteTableId = table.RouteTableId - _, err = c.ec2.DeleteRoute(request) + _, err = c.ec2.DeleteRoute(ctx, request) if err != nil { return fmt.Errorf("error deleting AWS route (%s): %q", route.DestinationCIDR, err) } diff --git a/pkg/providers/v1/aws_sdk_client_middleware.go b/pkg/providers/v1/aws_sdk_client_middleware.go new file mode 100644 index 0000000000..eef4bc1ac6 --- /dev/null +++ b/pkg/providers/v1/aws_sdk_client_middleware.go @@ -0,0 +1,322 @@ +package aws + +import ( + "context" + "fmt" + "net/url" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/service/ec2" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/kms" + smithyendpoints "github.com/aws/smithy-go/endpoints" + smithymiddleware "github.com/aws/smithy-go/middleware" + + "k8s.io/client-go/pkg/version" +) + +// Adds middleware to AWS SDK Go V2 clients. +func (p *awsSDKProvider) AddMiddleware(ctx context.Context, regionName string, cfg *aws.Config) { + cfg.APIOptions = append(cfg.APIOptions, + middleware.AddUserAgentKeyValue("kubernetes", version.Get().String()), + func(stack *smithymiddleware.Stack) error { + return stack.Finalize.Add(awsHandlerLoggerMiddleware(), smithymiddleware.Before) + }, + ) + + delayer := p.getCrossRequestRetryDelay(regionName) + if delayer != nil { + cfg.APIOptions = append(cfg.APIOptions, + func(stack *smithymiddleware.Stack) error { + stack.Finalize.Add(delayPreSign(delayer), smithymiddleware.Before) + stack.Finalize.Insert(delayAfterRetry(delayer), "Retry", smithymiddleware.Before) + return nil + }, + ) + } + + p.addAPILoggingMiddleware(cfg) +} + +// Adds logging middleware for AWS SDK Go V2 clients +func (p *awsSDKProvider) addAPILoggingMiddleware(cfg *aws.Config) { + cfg.APIOptions = append(cfg.APIOptions, + func(stack *smithymiddleware.Stack) error { + stack.Serialize.Add(awsSendHandlerLoggerMiddleware(), smithymiddleware.After) + stack.Deserialize.Add(awsValidateResponseHandlerLoggerMiddleware(), smithymiddleware.Before) + return nil + }, + ) +} + +// GetEC2EndpointOpts returns client configuration options that override +// the signing name and region, if appropriate. +func (cfg *CloudConfig) GetEC2EndpointOpts(region string) []func(*ec2.Options) { + opts := []func(*ec2.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == ec2.ServiceID && override.Region == region { + opts = append(opts, + ec2.WithSigV4SigningName(override.SigningName), + ec2.WithSigV4SigningRegion(override.SigningRegion), + ) + } + } + return opts +} + +// GetCustomEC2Resolver returns an endpoint resolver for EC2 Clients +func (cfg *CloudConfig) GetCustomEC2Resolver() ec2.EndpointResolverV2 { + return &EC2Resolver{ + Resolver: ec2.NewDefaultEndpointResolverV2(), + Cfg: cfg, + } +} + +// EC2Resolver overrides the endpoint for an AWS SDK Go V2 EC2 Client, +// using the provided CloudConfig to determine if an override +// is appropriate. +type EC2Resolver struct { + Resolver ec2.EndpointResolverV2 + Cfg *CloudConfig +} + +// ResolveEndpoint resolves the endpoint, overriding when custom configurations are set. +func (r *EC2Resolver) ResolveEndpoint( + ctx context.Context, params ec2.EndpointParameters, +) ( + endpoint smithyendpoints.Endpoint, err error, +) { + for _, override := range r.Cfg.ServiceOverride { + if override.Service == ec2.ServiceID && override.Region == aws.ToString(params.Region) { + customURL, err := url.Parse(override.URL) + if err != nil { + return smithyendpoints.Endpoint{}, fmt.Errorf("could not parse override URL, %w", err) + } + return smithyendpoints.Endpoint{ + URI: *customURL, + }, nil + } + } + return r.Resolver.ResolveEndpoint(ctx, params) +} + +// GetELBEndpointOpts returns client configuration options that override +// the signing name and region, if appropriate. +func (cfg *CloudConfig) GetELBEndpointOpts(region string) []func(*elb.Options) { + opts := []func(*elb.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == elb.ServiceID && override.Region == region { + opts = append(opts, + elb.WithSigV4SigningName(override.SigningName), + elb.WithSigV4SigningRegion(override.SigningRegion), + ) + } + } + return opts +} + +// GetCustomELBResolver returns an endpoint resolver for ELB Clients +func (cfg *CloudConfig) GetCustomELBResolver() elb.EndpointResolverV2 { + return &ELBResolver{ + Resolver: elb.NewDefaultEndpointResolverV2(), + Cfg: cfg, + } +} + +// ELBResolver overrides the endpoint for an AWS SDK Go V2 ELB Client, +// using the provided CloudConfig to determine if an override +// is appropriate. +type ELBResolver struct { + Resolver elb.EndpointResolverV2 + Cfg *CloudConfig +} + +// ResolveEndpoint resolves the endpoint, overriding when custom configurations are set. +func (r *ELBResolver) ResolveEndpoint( + ctx context.Context, params elb.EndpointParameters, +) ( + endpoint smithyendpoints.Endpoint, err error, +) { + for _, override := range r.Cfg.ServiceOverride { + if override.Service == elb.ServiceID && override.Region == aws.ToString(params.Region) { + customURL, err := url.Parse(override.URL) + if err != nil { + return smithyendpoints.Endpoint{}, fmt.Errorf("could not parse override URL, %w", err) + } + return smithyendpoints.Endpoint{ + URI: *customURL, + }, nil + } + } + return r.Resolver.ResolveEndpoint(ctx, params) +} + +// GetELBV2EndpointOpts returns client configuration options that override +// the signing name and region, if appropriate. +func (cfg *CloudConfig) GetELBV2EndpointOpts(region string) []func(*elbv2.Options) { + opts := []func(*elbv2.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == elbv2.ServiceID && override.Region == region { + opts = append(opts, + elbv2.WithSigV4SigningName(override.SigningName), + elbv2.WithSigV4SigningRegion(override.SigningRegion), + ) + } + } + return opts +} + +// GetCustomELBV2Resolver returns an endpoint resolver for ELB Clients +func (cfg *CloudConfig) GetCustomELBV2Resolver() elbv2.EndpointResolverV2 { + return &ELBV2Resolver{ + Resolver: elbv2.NewDefaultEndpointResolverV2(), + Cfg: cfg, + } +} + +// ELBV2Resolver overrides the endpoint for an AWS SDK Go V2 ELB Client, +// using the provided CloudConfig to determine if an override +// is appropriate. +type ELBV2Resolver struct { + Resolver elbv2.EndpointResolverV2 + Cfg *CloudConfig +} + +// ResolveEndpoint resolves the endpoint, overriding when custom configurations are set. +func (r *ELBV2Resolver) ResolveEndpoint( + ctx context.Context, params elbv2.EndpointParameters, +) ( + endpoint smithyendpoints.Endpoint, err error, +) { + for _, override := range r.Cfg.ServiceOverride { + if override.Service == elbv2.ServiceID && override.Region == aws.ToString(params.Region) { + customURL, err := url.Parse(override.URL) + if err != nil { + return smithyendpoints.Endpoint{}, fmt.Errorf("could not parse override URL, %w", err) + } + return smithyendpoints.Endpoint{ + URI: *customURL, + }, nil + } + } + return r.Resolver.ResolveEndpoint(ctx, params) +} + +// GetKMSEndpointOpts returns client configuration options that override +// the signing name and region, if appropriate. +func (cfg *CloudConfig) GetKMSEndpointOpts(region string) []func(*kms.Options) { + opts := []func(*kms.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == kms.ServiceID && override.Region == region { + opts = append(opts, + kms.WithSigV4SigningName(override.SigningName), + kms.WithSigV4SigningRegion(override.SigningRegion), + ) + } + } + return opts +} + +// GetCustomKMSResolver returns an endpoint resolver for KMS Clients +func (cfg *CloudConfig) GetCustomKMSResolver() kms.EndpointResolverV2 { + return &KMSResolver{ + Resolver: kms.NewDefaultEndpointResolverV2(), + Cfg: cfg, + } +} + +// KMSResolver overrides the endpoint for an AWS SDK Go V2 KMS Client, +// using the provided CloudConfig to determine if an override +// is appropriate. +type KMSResolver struct { + Resolver kms.EndpointResolverV2 + Cfg *CloudConfig +} + +// ResolveEndpoint resolves the endpoint, overriding when custom configurations are set. +func (r *KMSResolver) ResolveEndpoint( + ctx context.Context, params kms.EndpointParameters, +) ( + endpoint smithyendpoints.Endpoint, err error, +) { + for _, override := range r.Cfg.ServiceOverride { + if override.Service == kms.ServiceID && override.Region == aws.ToString(params.Region) { + customURL, err := url.Parse(override.URL) + if err != nil { + return smithyendpoints.Endpoint{}, fmt.Errorf("could not parse override URL, %w", err) + } + return smithyendpoints.Endpoint{ + URI: *customURL, + }, nil + } + } + return r.Resolver.ResolveEndpoint(ctx, params) +} + +// GetIMDSEndpointOpts overrides the endpoint URL for IMDS clients +func (cfg *CloudConfig) GetIMDSEndpointOpts() []func(*imds.Options) { + opts := []func(*imds.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == imds.ServiceID { + opts = append(opts, func(o *imds.Options) { + o.Endpoint = override.URL + }) + } + } + return opts +} + +// GetAutoscalingEndpointOpts returns client configuration options that override +// the signing name and region, if appropriate. +func (cfg *CloudConfig) GetAutoscalingEndpointOpts(region string) []func(*autoscaling.Options) { + opts := []func(*autoscaling.Options){} + for _, override := range cfg.ServiceOverride { + if override.Service == autoscaling.ServiceID && override.Region == region { + opts = append(opts, + autoscaling.WithSigV4SigningName(override.SigningName), + autoscaling.WithSigV4SigningRegion(override.SigningRegion), + ) + } + } + return opts +} + +// GetCustomAutoscalingResolver returns an endpoint resolver for Autoscaling Clients +func (cfg *CloudConfig) GetCustomAutoscalingResolver() autoscaling.EndpointResolverV2 { + return &AutoscalingResolver{ + Resolver: autoscaling.NewDefaultEndpointResolverV2(), + Cfg: cfg, + } +} + +// AutoscalingResolver overrides the endpoint for an AWS SDK Go V2 Autoscaling Client, +// using the provided CloudConfig to determine if an override +// is appropriate. +type AutoscalingResolver struct { + Resolver autoscaling.EndpointResolverV2 + Cfg *CloudConfig +} + +// ResolveEndpoint resolves the endpoint, overriding when custom configurations are set. +func (r *AutoscalingResolver) ResolveEndpoint( + ctx context.Context, params autoscaling.EndpointParameters, +) ( + endpoint smithyendpoints.Endpoint, err error, +) { + for _, override := range r.Cfg.ServiceOverride { + if override.Service == autoscaling.ServiceID && override.Region == aws.ToString(params.Region) { + customURL, err := url.Parse(override.URL) + if err != nil { + return smithyendpoints.Endpoint{}, fmt.Errorf("could not parse override URL, %w", err) + } + return smithyendpoints.Endpoint{ + URI: *customURL, + }, nil + } + } + return r.Resolver.ResolveEndpoint(ctx, params) +} diff --git a/pkg/providers/v1/aws_sdk_client_test.go b/pkg/providers/v1/aws_sdk_client_test.go new file mode 100644 index 0000000000..9d6d20c554 --- /dev/null +++ b/pkg/providers/v1/aws_sdk_client_test.go @@ -0,0 +1,741 @@ +package aws + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "regexp" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/service/ec2" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/kms" + + "github.com/stretchr/testify/assert" +) + +type requestInfo struct { + usedCustomEndpoint bool + credential string +} + +// Given an override, a custom endpoint should be used when making API requests +func TestClientsEndpointOverride(t *testing.T) { + reqInfo := requestInfo{} // stores information about requests, should be reset between API calls + // Dummy server that sets usedCustomEndpoint when called, and collects information about the request + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqInfo.usedCustomEndpoint = true + // Extract credential from auth header + auth := r.Header.Get("Authorization") + credRe := regexp.MustCompile(`Credential=([^,]+)`) + credMatch := credRe.FindStringSubmatch(auth) + if len(credMatch) == 2 { // true when it's able to find exactly one match for the Credential header + reqInfo.credential = credMatch[1] + } + })) + defer testServer.Close() + + // Clients should be able to have their default signing region and name overridden + t.Run("With overridden URL, signing region, and signing name", func(t *testing.T) { + cfgWithServiceOverride := CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "custom-service", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "custom-service", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "custom-service", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "custom-service", + }, + "5": { + Service: autoscaling.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "custom-service", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // Test EC2 client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating EC2 client, %v", err) + } + _, err = ec2Client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "EC2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "EC2: signing name was not properly overridden") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "EC2: signing region was not properly overridden") + + // Test ELB client + reqInfo = requestInfo{} + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELB client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELB: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "ELB: signing name was not properly overridden") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "ELB: signing region was not properly overridden") + + // Test ELBV2 client + reqInfo = requestInfo{} + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELBV2 client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELBV2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "ELBV2: signing name was not properly overridden") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "ELBV2: signing region was not properly overridden") + + // Test KMS client + reqInfo = requestInfo{} + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating KMS client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, reqInfo.usedCustomEndpoint, "KMS: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "KMS: signing name was not properly overridden") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "KMS: signing region was not properly overridden") + + // Test Autoscaling client + reqInfo = requestInfo{} + autoscalingClient, err := mockProvider.Autoscaling(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating Autoscaling client, %v", err) + } + _, err = autoscalingClient.DescribeAutoScalingGroups(context.TODO(), &autoscaling.DescribeAutoScalingGroupsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "Autoscaling: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "Autoscaling: signing name was not properly overridden") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "Autoscaling: signing region was not properly overridden") + + }) + + // When the signing name is overridden but not the signing region, the signing name should be + // whatever is configured in the override, and the signing region should fall back to the request region. + t.Run("With overridden signing name and default region", func(t *testing.T) { + cfgWithServiceOverride := CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "custom-service", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "custom-service", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "custom-service", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "custom-service", + }, + "5": { + Service: autoscaling.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "custom-service", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // Test EC2 client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating EC2 client, %v", err) + } + _, err = ec2Client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "EC2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "EC2: blank signing region should fall back to request region") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "EC2: signing name was not properly overridden") + + // Test ELB client + reqInfo = requestInfo{} + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELB client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELB: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "ELB: blank signing region should fall back to request region") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "ELB: signing name was not properly overridden") + + // Test ELBV2 client + reqInfo = requestInfo{} + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELBV2 client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELBV2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "ELBV2: blank signing region should fall back to request region") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "ELBV2: signing name was not properly overridden") + + // Test KMS client + reqInfo = requestInfo{} + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating KMS client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, reqInfo.usedCustomEndpoint, "KMS: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "KMS: blank signing region should fall back to request region") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "KMS: signing name was not properly overridden") + + // Test Autoscaling client + reqInfo = requestInfo{} + autoscalingClient, err := mockProvider.Autoscaling(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating Autoscaling client, %v", err) + } + _, err = autoscalingClient.DescribeAutoScalingGroups(context.Background(), &autoscaling.DescribeAutoScalingGroupsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "Autoscaling: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "Autoscaling: blank signing region should fall back to request region") + assert.True(t, strings.Contains(reqInfo.credential, "custom-service"), "Autoscaling: signing name was not properly overridden") + }) + + // When the signing region is overridden but not the signing name, the signing region should be + // whatever is configured in the override, and the signing name should fall back to the client's service name. + t.Run("With overriden signing region and default name", func(t *testing.T) { + cfgWithServiceOverride := CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "", + }, + "5": { + Service: autoscaling.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "custom-region", + SigningName: "", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // Test EC2 client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating EC2 client, %v", err) + } + _, err = ec2Client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "EC2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ToLower(ec2.ServiceID)), "EC2: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "EC2: signing region was not properly overridden") + + // Test ELB client + reqInfo = requestInfo{} + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELB client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELB: custom endpoint was not used") + // remove whitespace due to multi-word service name + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(elb.ServiceID), " ", "")), "ELB: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "ELB: signing region was not properly overridden") + + // Test ELBV2 client + reqInfo = requestInfo{} + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELBV2 client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELBV2: custom endpoint was not used") + // ELB and ELBV2 use the same default signing name (https://docs.aws.amazon.com/general/latest/gr/elb.html) + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(elb.ServiceID), " ", "")), "ELBV2: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "ELBV2: signing region was not properly overridden") + + // Test KMS client + reqInfo = requestInfo{} + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating KMS client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, reqInfo.usedCustomEndpoint, "KMS: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ToLower(kms.ServiceID)), "KMS: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "KMS: signing region was not properly overridden") + + // Test Autoscaling client + reqInfo = requestInfo{} + autoscalingClient, err := mockProvider.Autoscaling(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating Autoscaling client, %v", err) + } + _, err = autoscalingClient.DescribeAutoScalingGroups(context.TODO(), &autoscaling.DescribeAutoScalingGroupsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "Autoscaling: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(autoscaling.ServiceID), " ", "")), "Autoscaling: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "custom-region"), "Autoscaling: signing region was not properly overridden") + }) + + // When only the URL is overridden, and not the signing region or name, the URL should be whatever is configured in + // the override, the region should fall back to the request region, and the name should fall back to the client's + // service name. + t.Run("Only URL override", func(t *testing.T) { + cfgWithServiceOverride := CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + "5": { + Service: imds.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + "6": { + Service: autoscaling.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "", + SigningName: "", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // Test EC2 client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating EC2 client, %v", err) + } + _, err = ec2Client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "EC2: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ToLower(ec2.ServiceID)), "EC2: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "EC2: blank signing region should fall back to request region") + + // Test ELB client + reqInfo = requestInfo{} + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELB client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELB: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(elb.ServiceID), " ", "")), "ELB: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "ELB: blank signing region should fall back to request region") + + // Test ELBV2 client + reqInfo = requestInfo{} + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating ELBV2 client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "ELBV2: custom endpoint was not used") + // ELB and ELBV2 use the same default signing name (https://docs.aws.amazon.com/general/latest/gr/elb.html) + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(elb.ServiceID), " ", "")), "ELBV2: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "ELBV2: blank signing region should fall back to request region") + + // Test KMS client + reqInfo = requestInfo{} + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating KMS client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, reqInfo.usedCustomEndpoint, "KMS: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ToLower(kms.ServiceID)), "KMS: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "KMS: blank signing region should fall back to request region") + + val := os.Getenv("AWS_EC2_METADATA_DISABLED") + // Test Metadata client. This client only supports overriding the URL, not the signing name and region. + reqInfo = requestInfo{} + // This client can only successfully make requests when AWS_EC2_METADATA_DISABLED = false. + // https://docs.aws.amazon.com/sdkref/latest/guide/feature-imds-credentials.html + os.Setenv("AWS_EC2_METADATA_DISABLED", "false") + // Make a IMDS client and make a request + metadataClient, err := mockProvider.Metadata(context.TODO()) + if err != nil { + t.Errorf("error creating Metadata client, %v", err) + } + _, err = metadataClient.GetRegion(context.TODO(), &imds.GetRegionInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "IMDS: custom endpoint was not used") + + // reset AWS_EC2_METADATA_DISABLED + os.Setenv("AWS_EC2_METADATA_DISABLED", val) + + // Test Autoscaling client + reqInfo = requestInfo{} + autoscalingClient, err := mockProvider.Autoscaling(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating Autoscaling client, %v", err) + } + _, err = autoscalingClient.DescribeAutoScalingGroups(context.TODO(), &autoscaling.DescribeAutoScalingGroupsInput{}) + assert.True(t, reqInfo.usedCustomEndpoint, "Autoscaling: custom endpoint was not used") + assert.True(t, strings.Contains(reqInfo.credential, strings.ReplaceAll(strings.ToLower(autoscaling.ServiceID), " ", "")), "Autoscaling: blank signing name should fall back to request service") + assert.True(t, strings.Contains(reqInfo.credential, "us-west-2"), "Autoscaling: blank signing region should fall back to request region") + + }) +} + +// Test whether SDK clients refrain from retrying an API request when given a nonRetryableError. +func TestClientsNoRetry(t *testing.T) { + attemptCount := 0 + // Dummy server that counts attempts and returns a nonRetryableError + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + w.Header().Set("Content-Type", "text/xml") + w.WriteHeader(http.StatusBadRequest) + + // Insert the nonRetryableError error message + errorXML := fmt.Sprintf(` + + + + %d + %s + + + 12345678-1234-1234-1234-123456789012 + `, http.StatusBadRequest, nonRetryableError) + + w.Write([]byte(errorXML)) + })) + defer testServer.Close() + + // Override service endpoints with dummy server URL + cfgWithServiceOverride := CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "5": { + Service: autoscaling.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // EC2 Client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = ec2Client.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{}) + // Ensure that only 1 attempt was made, signifying no retries + assert.True(t, attemptCount == 1, fmt.Sprintf("expected an attempt count of 1 for EC2 client, got %d", attemptCount)) + + // ELB Client + attemptCount = 0 // reset attempt count for next request + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, attemptCount == 1, fmt.Sprintf("expected an attempt count of 1 for ELB client, got %d", attemptCount)) + + // ELBV2 Client + attemptCount = 0 + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, attemptCount == 1, fmt.Sprintf("expected an attempt count of 1 for ELBV2 client, got %d", attemptCount)) + + // KMS Client + attemptCount = 0 + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, attemptCount == 1, fmt.Sprintf("expected an attempt count of 1 for KMS client, got %d", attemptCount)) + + // Autoscaling Client + attemptCount = 0 + autoscalingClient, err := mockProvider.Autoscaling(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = autoscalingClient.DescribeAutoScalingGroups(context.TODO(), &autoscaling.DescribeAutoScalingGroupsInput{}) + assert.True(t, attemptCount == 1, fmt.Sprintf("expected an attempt count of 1 for Autoscalig client, got %d", attemptCount)) + +} + +// Test whether SDK clients retry an API request when given a retryable error code. +func TestClientsWithRetry(t *testing.T) { + attemptCount := 0 + // Dummy server that counts attempts and returns a retryable error + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // 500 status codes are retried by SDK (see https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws/retry) + http.Error(w, "RequestTimeout", 500) + })) + + // Override service endpoints with dummy server URL + cfgWithServiceOverride := CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: ec2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "2": { + Service: elb.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "3": { + Service: elbv2.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "4": { + Service: kms.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + "5": { + Service: autoscaling.ServiceID, + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: make(map[string]*CrossRequestRetryDelay), + } + + // EC2 Client + ec2Client, err := mockProvider.Compute(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = ec2Client.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{}) + // Ensure that more than 1 attempt was made, signifying retries + assert.True(t, attemptCount > 1, fmt.Sprintf("expected an attempt count of >1 for EC2 client, got %d", attemptCount)) + + // ELB Client + attemptCount = 0 // Reset the attempt count before the next request + elbClient, err := mockProvider.LoadBalancing(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = elbClient.DescribeLoadBalancers(context.TODO(), &elb.DescribeLoadBalancersInput{}) + assert.True(t, attemptCount > 1, fmt.Sprintf("expected an attempt count of >1 for ELB client, got %d", attemptCount)) + + // ELBV2 Client + attemptCount = 0 + elbv2Client, err := mockProvider.LoadBalancingV2(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = elbv2Client.DescribeLoadBalancers(context.TODO(), &elbv2.DescribeLoadBalancersInput{}) + assert.True(t, attemptCount > 1, fmt.Sprintf("expected an attempt count of >1 for ELB client, got %d", attemptCount)) + + // KMS Client + attemptCount = 0 + kmsClient, err := mockProvider.KeyManagement(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = kmsClient.DescribeKey(context.TODO(), &kms.DescribeKeyInput{KeyId: aws.String("dummy")}) + assert.True(t, attemptCount > 1, fmt.Sprintf("expected an attempt count of >1 for KMS client, got %d", attemptCount)) + + // Autoscaling Client + attemptCount = 0 + autoscalingClient, err := mockProvider.Autoscaling(context.TODO(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + _, err = autoscalingClient.DescribeAutoScalingGroups(context.TODO(), &autoscaling.DescribeAutoScalingGroupsInput{}) + assert.True(t, attemptCount > 1, fmt.Sprintf("expected an attempt count of >1 for Autoscaling client, got %d", attemptCount)) +} diff --git a/pkg/providers/v1/aws_test.go b/pkg/providers/v1/aws_test.go index d04cbca790..ff7175c781 100644 --- a/pkg/providers/v1/aws_test.go +++ b/pkg/providers/v1/aws_test.go @@ -27,12 +27,16 @@ import ( "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing" + elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + + "github.com/aws/aws-sdk-go-v2/aws" + + "github.com/aws/smithy-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -55,54 +59,59 @@ type MockedFakeEC2 struct { } func (m *MockedFakeEC2) expectDescribeSecurityGroups(clusterID, groupName string) { - tags := []*ec2.Tag{ + tags := []ec2types.Tag{ {Key: aws.String(TagNameKubernetesClusterLegacy), Value: aws.String(clusterID)}, {Key: aws.String(fmt.Sprintf("%s%s", TagNameKubernetesClusterPrefix, clusterID)), Value: aws.String(ResourceLifecycleOwned)}, } - m.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{Filters: []*ec2.Filter{ + m.On("DescribeSecurityGroups", context.TODO(), &ec2.DescribeSecurityGroupsInput{Filters: []ec2types.Filter{ newEc2Filter("group-name", groupName), newEc2Filter("vpc-id", ""), - }}).Return([]*ec2.SecurityGroup{{Tags: tags}}) + }}).Return([]ec2types.SecurityGroup{{Tags: tags}}) } -func (m *MockedFakeEC2) DescribeVolumes(request *ec2.DescribeVolumesInput) ([]*ec2.Volume, error) { - args := m.Called(request) - return args.Get(0).([]*ec2.Volume), nil +func (m *MockedFakeEC2) DescribeVolumes(ctx context.Context, request *ec2.DescribeVolumesInput, optFns ...func(*ec2.Options)) ([]ec2types.Volume, error) { + args := m.Called(ctx, request) + return args.Get(0).([]ec2types.Volume), nil } -func (m *MockedFakeEC2) DeleteVolume(request *ec2.DeleteVolumeInput) (*ec2.DeleteVolumeOutput, error) { - args := m.Called(request) +func (m *MockedFakeEC2) DeleteVolume(ctx context.Context, request *ec2.DeleteVolumeInput, optFns ...func(*ec2.Options)) (*ec2.DeleteVolumeOutput, error) { + args := m.Called(ctx, request) return args.Get(0).(*ec2.DeleteVolumeOutput), nil } -func (m *MockedFakeEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { - args := m.Called(request) - return args.Get(0).([]*ec2.SecurityGroup), nil +func (m *MockedFakeEC2) DescribeSecurityGroups(ctx context.Context, request *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) ([]ec2types.SecurityGroup, error) { + args := m.Called(ctx, request) + return args.Get(0).([]ec2types.SecurityGroup), nil } -func (m *MockedFakeEC2) CreateVolume(request *ec2.CreateVolumeInput) (*ec2.Volume, error) { +func (m *MockedFakeEC2) CreateVolume(ctx context.Context, request *ec2.CreateVolumeInput, optFns ...func(*ec2.Options)) (*ec2.CreateVolumeOutput, error) { // mock requires stable input, and in CreateDisk we invoke buildTags which uses // a map to create tags, which then get converted into an array. This leads to // unstable sorting order which confuses mock. Sorted tags are not needed in // regular code, but are a must in tests here: for i := 0; i < len(request.TagSpecifications); i++ { - if request.TagSpecifications[i] == nil { - continue - } tags := request.TagSpecifications[i].Tags sort.Slice(tags, func(i, j int) bool { - if tags[i] == nil && tags[j] != nil { + if tags[i].Key == nil && tags[j].Key != nil { return false } - if tags[i] != nil && tags[j] == nil { + if tags[i].Key != nil && tags[j].Key == nil { return true } return *tags[i].Key < *tags[j].Key }) } - args := m.Called(request) - return args.Get(0).(*ec2.Volume), nil + args := m.Called(ctx, request) + return args.Get(0).(*ec2.CreateVolumeOutput), nil +} + +func (m *MockedFakeEC2) DescribeInstanceTopology(ctx context.Context, request *ec2.DescribeInstanceTopologyInput, optFns ...func(*ec2.Options)) ([]ec2types.InstanceTopology, error) { + args := m.Called(ctx, request) + if args.Get(1) != nil { + return nil, args.Get(1).(error) + } + return args.Get(0).([]ec2types.InstanceTopology), nil } type MockedFakeELB struct { @@ -110,33 +119,33 @@ type MockedFakeELB struct { mock.Mock } -func (m *MockedFakeELB) DescribeLoadBalancers(input *elb.DescribeLoadBalancersInput) (*elb.DescribeLoadBalancersOutput, error) { - args := m.Called(input) +func (m *MockedFakeELB) DescribeLoadBalancers(ctx context.Context, input *elb.DescribeLoadBalancersInput, optFns ...func(*elb.Options)) (*elb.DescribeLoadBalancersOutput, error) { + args := m.Called(ctx, input) return args.Get(0).(*elb.DescribeLoadBalancersOutput), nil } func (m *MockedFakeELB) expectDescribeLoadBalancers(loadBalancerName string) { - m.On("DescribeLoadBalancers", &elb.DescribeLoadBalancersInput{LoadBalancerNames: []*string{aws.String(loadBalancerName)}}).Return(&elb.DescribeLoadBalancersOutput{ - LoadBalancerDescriptions: []*elb.LoadBalancerDescription{{}}, + m.On("DescribeLoadBalancers", context.TODO(), &elb.DescribeLoadBalancersInput{LoadBalancerNames: []string{loadBalancerName}}).Return(&elb.DescribeLoadBalancersOutput{ + LoadBalancerDescriptions: []elbtypes.LoadBalancerDescription{{}}, }) } -func (m *MockedFakeELB) AddTags(input *elb.AddTagsInput) (*elb.AddTagsOutput, error) { - args := m.Called(input) +func (m *MockedFakeELB) AddTags(ctx context.Context, input *elb.AddTagsInput, optFns ...func(*elb.Options)) (*elb.AddTagsOutput, error) { + args := m.Called(ctx, input) return args.Get(0).(*elb.AddTagsOutput), nil } -func (m *MockedFakeELB) ConfigureHealthCheck(input *elb.ConfigureHealthCheckInput) (*elb.ConfigureHealthCheckOutput, error) { - args := m.Called(input) +func (m *MockedFakeELB) ConfigureHealthCheck(ctx context.Context, input *elb.ConfigureHealthCheckInput, optFns ...func(*elb.Options)) (*elb.ConfigureHealthCheckOutput, error) { + args := m.Called(ctx, input) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*elb.ConfigureHealthCheckOutput), args.Error(1) } -func (m *MockedFakeELB) expectConfigureHealthCheck(loadBalancerName *string, expectedHC *elb.HealthCheck, returnErr error) { +func (m *MockedFakeELB) expectConfigureHealthCheck(loadBalancerName *string, expectedHC *elbtypes.HealthCheck, returnErr error) { expected := &elb.ConfigureHealthCheckInput{HealthCheck: expectedHC, LoadBalancerName: loadBalancerName} - call := m.On("ConfigureHealthCheck", expected) + call := m.On("ConfigureHealthCheck", context.TODO(), expected) if returnErr != nil { call.Return(nil, returnErr) } else { @@ -193,16 +202,15 @@ type ServiceDescriptor struct { signingName string } -func TestOverridesActiveConfig(t *testing.T) { +func TestValidateOverridesActiveConfig(t *testing.T) { tests := []struct { name string reader io.Reader aws Services - expectError bool - active bool - servicesOverridden []ServiceDescriptor + expectError bool + active bool }{ { "No overrides", @@ -211,7 +219,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, false, false, - []ServiceDescriptor{}, }, { "Missing Service Name", @@ -226,7 +233,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, true, false, - []ServiceDescriptor{}, }, { "Missing Service Region", @@ -241,7 +247,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, true, false, - []ServiceDescriptor{}, }, { "Missing URL", @@ -256,7 +261,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, true, false, - []ServiceDescriptor{}, }, { "Missing Signing Region", @@ -271,7 +275,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, true, false, - []ServiceDescriptor{}, }, { "Active Overrides", @@ -287,7 +290,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, false, true, - []ServiceDescriptor{{name: "s3", region: "sregion", signingRegion: "sregion", signingMethod: "v4"}}, }, { "Multiple Overridden Services", @@ -310,8 +312,6 @@ func TestOverridesActiveConfig(t *testing.T) { SigningMethod = v4`), nil, false, true, - []ServiceDescriptor{{name: "s3", region: "sregion1", signingRegion: "sregion1", signingMethod: "v4"}, - {name: "ec2", region: "sregion2", signingRegion: "sregion2", signingMethod: "v4"}}, }, { "Duplicate Services", @@ -334,7 +334,6 @@ func TestOverridesActiveConfig(t *testing.T) { SigningMethod = sign`), nil, true, false, - []ServiceDescriptor{}, }, { "Multiple Overridden Services in Multiple regions", @@ -356,8 +355,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, false, true, - []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: ""}, - {name: "ec2", region: "region2", signingRegion: "sregion", signingMethod: "v4"}}, }, { "Multiple regions, Same Service", @@ -381,8 +378,6 @@ func TestOverridesActiveConfig(t *testing.T) { `), nil, false, true, - []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: "v3"}, - {name: "s3", region: "region2", signingRegion: "sregion1", signingMethod: "v4", signingName: "name"}}, }, } @@ -400,71 +395,6 @@ func TestOverridesActiveConfig(t *testing.T) { if err != nil { t.Errorf("Should succeed for case: %s, got %v", test.name, err) } - - if len(cfg.ServiceOverride) != len(test.servicesOverridden) { - t.Errorf("Expected %d overridden services, received %d for case %s", - len(test.servicesOverridden), len(cfg.ServiceOverride), test.name) - } else { - for _, sd := range test.servicesOverridden { - var found *struct { - Service string - Region string - URL string - SigningRegion string - SigningMethod string - SigningName string - } - for _, v := range cfg.ServiceOverride { - if v.Service == sd.name && v.Region == sd.region { - found = v - break - } - } - if found == nil { - t.Errorf("Missing override for service %s in case %s", - sd.name, test.name) - } else { - if found.SigningRegion != sd.signingRegion { - t.Errorf("Expected signing region '%s', received '%s' for case %s", - sd.signingRegion, found.SigningRegion, test.name) - } - if found.SigningMethod != sd.signingMethod { - t.Errorf("Expected signing method '%s', received '%s' for case %s", - sd.signingMethod, found.SigningRegion, test.name) - } - targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) - if found.URL != targetName { - t.Errorf("Expected Endpoint '%s', received '%s' for case %s", - targetName, found.URL, test.name) - } - if found.SigningName != sd.signingName { - t.Errorf("Expected signing name '%s', received '%s' for case %s", - sd.signingName, found.SigningName, test.name) - } - - fn := cfg.getResolver() - ep1, e := fn(sd.name, sd.region, nil) - if e != nil { - t.Errorf("Expected a valid endpoint for %s in case %s", - sd.name, test.name) - } else { - targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) - if ep1.URL != targetName { - t.Errorf("Expected endpoint url: %s, received %s in case %s", - targetName, ep1.URL, test.name) - } - if ep1.SigningRegion != sd.signingRegion { - t.Errorf("Expected signing region '%s', received '%s' in case %s", - sd.signingRegion, ep1.SigningRegion, test.name) - } - if ep1.SigningMethod != sd.signingMethod { - t.Errorf("Expected signing method '%s', received '%s' in case %s", - sd.signingMethod, ep1.SigningRegion, test.name) - } - } - } - } - } } } } @@ -497,7 +427,7 @@ func TestNewAWSCloud(t *testing.T) { cfg, err := readAWSCloudConfig(test.reader) var c *Cloud if err == nil { - c, err = newAWSCloud(*cfg, test.awsServices) + c, err = newAWSCloud(*cfg, test.awsServices, nil) } if test.expectError { if err == nil { @@ -514,11 +444,11 @@ func TestNewAWSCloud(t *testing.T) { } } -func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (*Cloud, *FakeAWSServices) { +func mockInstancesResp(selfInstance *ec2types.Instance, instances []*ec2types.Instance) (*Cloud, *FakeAWSServices) { awsServices := newMockedFakeAWSServices(TestClusterID) awsServices.instances = instances awsServices.selfInstance = selfInstance - awsCloud, err := newAWSCloud(CloudConfig{}, awsServices) + awsCloud, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { panic(err) } @@ -542,7 +472,7 @@ func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (* func mockZone(region, availabilityZone string) *Cloud { awsServices := newMockedFakeAWSServices(TestClusterID).WithAz(availabilityZone).WithRegion(region) - awsCloud, err := newAWSCloud(CloudConfig{}, awsServices) + awsCloud, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { panic(err) } @@ -559,30 +489,30 @@ func testHasNodeAddress(t *testing.T, addrs []v1.NodeAddress, addressType v1.Nod t.Errorf("Did not find expected address: %s:%s in %v", addressType, address, addrs) } -func makeInstance(instanceID string, privateIP, publicIP, privateDNSName, publicDNSName string, ipv6s []string, setNetInterface bool) ec2.Instance { - var tag ec2.Tag +func makeInstance(instanceID string, privateIP, publicIP, privateDNSName, publicDNSName string, ipv6s []string, setNetInterface bool) ec2types.Instance { + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterLegacy) tag.Value = aws.String(TestClusterID) - tags := []*ec2.Tag{&tag} + tags := []ec2types.Tag{tag} - instance := ec2.Instance{ + instance := ec2types.Instance{ InstanceId: &instanceID, PrivateDnsName: aws.String(privateDNSName), PrivateIpAddress: aws.String(privateIP), PublicDnsName: aws.String(publicDNSName), PublicIpAddress: aws.String(publicIP), - InstanceType: aws.String("c3.large"), + InstanceType: ec2types.InstanceTypeC3Large, Tags: tags, - Placement: &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")}, - State: &ec2.InstanceState{ - Name: aws.String("running"), + Placement: &ec2types.Placement{AvailabilityZone: aws.String("us-east-1a")}, + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, } if setNetInterface == true { - instance.NetworkInterfaces = []*ec2.InstanceNetworkInterface{ + instance.NetworkInterfaces = []ec2types.InstanceNetworkInterface{ { - Status: aws.String(ec2.NetworkInterfaceStatusInUse), - PrivateIpAddresses: []*ec2.InstancePrivateIpAddress{ + Status: ec2types.NetworkInterfaceStatusInUse, + PrivateIpAddresses: []ec2types.InstancePrivateIpAddress{ { PrivateIpAddress: aws.String(privateIP), }, @@ -590,7 +520,7 @@ func makeInstance(instanceID string, privateIP, publicIP, privateDNSName, public }, } if len(ipv6s) > 0 { - instance.NetworkInterfaces[0].Ipv6Addresses = []*ec2.InstanceIpv6Address{ + instance.NetworkInterfaces[0].Ipv6Addresses = []ec2types.InstanceIpv6Address{ { Ipv6Address: aws.String(ipv6s[0]), }, @@ -654,7 +584,7 @@ func TestNodeAddressesByProviderID(t *testing.T) { } { t.Run(tc.Name, func(t *testing.T) { instance := makeInstance(tc.InstanceID, tc.PrivateIP, tc.PublicIP, tc.PrivateDNSName, tc.PublicDNSName, tc.Ipv6s, tc.SetNetInterface) - aws1, _ := mockInstancesResp(&instance, []*ec2.Instance{&instance}) + aws1, _ := mockInstancesResp(&instance, []*ec2types.Instance{&instance}) _, err := aws1.NodeAddressesByProviderID(context.TODO(), "i-xxx") if err == nil { t.Errorf("Should error when no instance found") @@ -760,7 +690,7 @@ func TestNodeAddresses(t *testing.T) { } { t.Run(tc.Name, func(t *testing.T) { instance := makeInstance(tc.InstanceID, tc.PrivateIP, tc.PublicIP, tc.PrivateDNSName, tc.PublicDNSName, tc.Ipv6s, tc.SetNetInterface) - aws1, _ := mockInstancesResp(&instance, []*ec2.Instance{&instance}) + aws1, _ := mockInstancesResp(&instance, []*ec2types.Instance{&instance}) _, err := aws1.NodeAddresses(context.TODO(), "instance-mismatch.ec2.internal") if err == nil { t.Errorf("Should error when no instance found") @@ -819,12 +749,12 @@ func TestGetRegion(t *testing.T) { func TestFindVPCID(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return } - vpcID, err := c.findVPCID() + vpcID, err := c.findVPCID(context.TODO()) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -833,7 +763,7 @@ func TestFindVPCID(t *testing.T) { } } -func constructSubnets(subnetsIn map[int]map[string]string) (subnetsOut []*ec2.Subnet) { +func constructSubnets(subnetsIn map[int]map[string]string) (subnetsOut []*ec2types.Subnet) { for i := range subnetsIn { subnetsOut = append( subnetsOut, @@ -846,18 +776,18 @@ func constructSubnets(subnetsIn map[int]map[string]string) (subnetsOut []*ec2.Su return } -func constructSubnet(id string, az string) *ec2.Subnet { - return &ec2.Subnet{ +func constructSubnet(id string, az string) *ec2types.Subnet { + return &ec2types.Subnet{ SubnetId: &id, AvailabilityZone: &az, } } -func constructRouteTables(routeTablesIn map[string]bool) (routeTablesOut []*ec2.RouteTable) { +func constructRouteTables(routeTablesIn map[string]bool) (routeTablesOut []*ec2types.RouteTable) { routeTablesOut = append(routeTablesOut, - &ec2.RouteTable{ - Associations: []*ec2.RouteTableAssociation{{Main: aws.Bool(true)}}, - Routes: []*ec2.Route{{ + &ec2types.RouteTable{ + Associations: []ec2types.RouteTableAssociation{{Main: aws.Bool(true)}}, + Routes: []ec2types.Route{{ DestinationCidrBlock: aws.String("0.0.0.0/0"), GatewayId: aws.String("igw-main"), }}, @@ -875,16 +805,16 @@ func constructRouteTables(routeTablesIn map[string]bool) (routeTablesOut []*ec2. return } -func constructRouteTable(subnetID string, public bool) *ec2.RouteTable { +func constructRouteTable(subnetID string, public bool) *ec2types.RouteTable { var gatewayID string if public { gatewayID = "igw-" + subnetID[len(subnetID)-8:8] } else { gatewayID = "vgw-" + subnetID[len(subnetID)-8:8] } - return &ec2.RouteTable{ - Associations: []*ec2.RouteTableAssociation{{SubnetId: aws.String(subnetID)}}, - Routes: []*ec2.Route{{ + return &ec2types.RouteTable{ + Associations: []ec2types.RouteTableAssociation{{SubnetId: aws.String(subnetID)}}, + Routes: []ec2types.Route{{ DestinationCidrBlock: aws.String("0.0.0.0/0"), GatewayId: aws.String(gatewayID), }}, @@ -893,35 +823,35 @@ func constructRouteTable(subnetID string, public bool) *ec2.RouteTable { func Test_findELBSubnets(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return } - subnetA0000001 := &ec2.Subnet{ + subnetA0000001 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2a"), SubnetId: aws.String("subnet-a0000001"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(TagNameSubnetPublicELB), Value: aws.String("1"), }, }, } - subnetA0000002 := &ec2.Subnet{ + subnetA0000002 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2a"), SubnetId: aws.String("subnet-a0000002"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(TagNameSubnetPublicELB), Value: aws.String("1"), }, }, } - subnetA0000003 := &ec2.Subnet{ + subnetA0000003 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2a"), SubnetId: aws.String("subnet-a0000003"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -932,10 +862,10 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetB0000001 := &ec2.Subnet{ + subnetB0000001 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2b"), SubnetId: aws.String("subnet-b0000001"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -946,10 +876,10 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetB0000002 := &ec2.Subnet{ + subnetB0000002 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2b"), SubnetId: aws.String("subnet-b0000002"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -960,10 +890,10 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetC0000001 := &ec2.Subnet{ + subnetC0000001 := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-c0000001"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -974,10 +904,10 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetOther := &ec2.Subnet{ + subnetOther := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-other"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(TagNameKubernetesClusterPrefix + "clusterid.other"), Value: aws.String("owned"), @@ -988,24 +918,24 @@ func Test_findELBSubnets(t *testing.T) { }, }, } - subnetNoTag := &ec2.Subnet{ + subnetNoTag := &ec2types.Subnet{ AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-notag"), } - subnetLocalZone := &ec2.Subnet{ + subnetLocalZone := &ec2types.Subnet{ AvailabilityZone: aws.String("az-local"), SubnetId: aws.String("subnet-in-local-zone"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), }, }, } - subnetWavelengthZone := &ec2.Subnet{ + subnetWavelengthZone := &ec2types.Subnet{ AvailabilityZone: aws.String("az-wavelength"), SubnetId: aws.String("subnet-in-wavelength-zone"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -1015,7 +945,7 @@ func Test_findELBSubnets(t *testing.T) { tests := []struct { name string - subnets []*ec2.Subnet + subnets []*ec2types.Subnet routeTables map[string]bool internal bool want []string @@ -1025,7 +955,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "single tagged subnet", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, }, routeTables: map[string]bool{ @@ -1036,7 +966,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "no matching public subnet", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000002, }, routeTables: map[string]bool{ @@ -1046,7 +976,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "prefer role over cluster tag", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetA0000003, }, @@ -1058,7 +988,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "prefer cluster tag", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetC0000001, subnetNoTag, }, @@ -1066,7 +996,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "include untagged", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetNoTag, }, @@ -1078,7 +1008,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "ignore some other cluster owned subnet", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetB0000001, subnetOther, }, @@ -1090,7 +1020,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "prefer matching role", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetB0000001, subnetB0000002, }, @@ -1103,7 +1033,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "choose lexicographic order", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetA0000002, }, @@ -1115,7 +1045,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "everything", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetA0000002, subnetB0000001, @@ -1137,7 +1067,7 @@ func Test_findELBSubnets(t *testing.T) { }, { name: "exclude subnets from local and wavelenght zones", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ subnetA0000001, subnetB0000001, subnetC0000001, @@ -1158,7 +1088,7 @@ func Test_findELBSubnets(t *testing.T) { for _, rt := range routeTables { awsServices.ec2.CreateRouteTable(rt) } - got, _ := c.findELBSubnets(tt.internal) + got, _ := c.findELBSubnets(context.TODO(), tt.internal) sort.Strings(tt.want) sort.Strings(got) assert.Equal(t, tt.want, got) @@ -1168,7 +1098,7 @@ func Test_findELBSubnets(t *testing.T) { func Test_getLoadBalancerSubnets(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -1176,7 +1106,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { tests := []struct { name string service *v1.Service - subnets []*ec2.Subnet + subnets []*ec2types.Subnet internalELB bool want []string wantErr error @@ -1198,7 +1128,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { }, { name: "subnet ids", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ { AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-a000001"), @@ -1219,7 +1149,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { }, { name: "subnet names", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ { AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-a000001"), @@ -1240,7 +1170,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { }, { name: "unable to find all subnets", - subnets: []*ec2.Subnet{ + subnets: []*ec2types.Subnet{ { AvailabilityZone: aws.String("us-west-2c"), SubnetId: aws.String("subnet-a000001"), @@ -1262,7 +1192,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { for _, subnet := range tt.subnets { awsServices.ec2.CreateSubnet(subnet) } - got, err := c.getLoadBalancerSubnets(tt.service, tt.internalELB) + got, err := c.getLoadBalancerSubnets(context.TODO(), tt.service, tt.internalELB) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) } else { @@ -1274,7 +1204,7 @@ func Test_getLoadBalancerSubnets(t *testing.T) { func TestSubnetIDsinVPC(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -1308,7 +1238,7 @@ func TestSubnetIDsinVPC(t *testing.T) { awsServices.ec2.CreateRouteTable(rt) } - result, err := c.findELBSubnets(false) + result, err := c.findELBSubnets(context.TODO(), false) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -1338,7 +1268,7 @@ func TestSubnetIDsinVPC(t *testing.T) { awsServices.ec2.CreateRouteTable(rt) } - result, err = c.findELBSubnets(false) + result, err = c.findELBSubnets(context.TODO(), false) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -1384,7 +1314,7 @@ func TestSubnetIDsinVPC(t *testing.T) { awsServices.ec2.CreateRouteTable(rt) } - result, err = c.findELBSubnets(false) + result, err = c.findELBSubnets(context.TODO(), false) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -1395,7 +1325,7 @@ func TestSubnetIDsinVPC(t *testing.T) { return } - expected := []*string{aws.String("subnet-a0000001"), aws.String("subnet-b0000001"), aws.String("subnet-c0000000")} + expected := []string{"subnet-a0000001", "subnet-b0000001", "subnet-c0000000"} for _, s := range result { if !contains(expected, s) { t.Errorf("Unexpected subnet '%s' found", s) @@ -1431,7 +1361,7 @@ func TestSubnetIDsinVPC(t *testing.T) { for _, rt := range constructedRouteTables { awsServices.ec2.CreateRouteTable(rt) } - result, err = c.findELBSubnets(false) + result, err = c.findELBSubnets(context.TODO(), false) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -1442,7 +1372,7 @@ func TestSubnetIDsinVPC(t *testing.T) { return } - expected = []*string{aws.String("subnet-c0000000"), aws.String("subnet-d0000001"), aws.String("subnet-d0000002")} + expected = []string{"subnet-c0000000", "subnet-d0000001", "subnet-d0000002"} for _, s := range result { if !contains(expected, s) { t.Errorf("Unexpected subnet '%s' found", s) @@ -1452,22 +1382,22 @@ func TestSubnetIDsinVPC(t *testing.T) { } func TestIpPermissionExistsHandlesMultipleGroupIds(t *testing.T) { - oldIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + oldIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("firstGroupId")}, {GroupId: aws.String("secondGroupId")}, {GroupId: aws.String("thirdGroupId")}, }, } - existingIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + existingIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("secondGroupId")}, }, } - newIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + newIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("fourthGroupId")}, }, } @@ -1483,8 +1413,8 @@ func TestIpPermissionExistsHandlesMultipleGroupIds(t *testing.T) { } // The first pair matches, but the second does not - newIPPermission2 := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + newIPPermission2 := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("firstGroupId")}, {GroupId: aws.String("fourthGroupId")}, }, @@ -1497,29 +1427,29 @@ func TestIpPermissionExistsHandlesMultipleGroupIds(t *testing.T) { func TestIpPermissionExistsHandlesRangeSubsets(t *testing.T) { // Two existing scenarios we'll test against - emptyIPPermission := ec2.IpPermission{} + emptyIPPermission := ec2types.IpPermission{} - oldIPPermission := ec2.IpPermission{ - IpRanges: []*ec2.IpRange{ + oldIPPermission := ec2types.IpPermission{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("10.0.0.0/8")}, {CidrIp: aws.String("192.168.1.0/24")}, }, } // Two already existing ranges and a new one - existingIPPermission := ec2.IpPermission{ - IpRanges: []*ec2.IpRange{ + existingIPPermission := ec2types.IpPermission{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("10.0.0.0/8")}, }, } - existingIPPermission2 := ec2.IpPermission{ - IpRanges: []*ec2.IpRange{ + existingIPPermission2 := ec2types.IpPermission{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("192.168.1.0/24")}, }, } - newIPPermission := ec2.IpPermission{ - IpRanges: []*ec2.IpRange{ + newIPPermission := ec2types.IpPermission{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("172.16.0.0/16")}, }, } @@ -1553,22 +1483,22 @@ func TestIpPermissionExistsHandlesRangeSubsets(t *testing.T) { } func TestIpPermissionExistsHandlesMultipleGroupIdsWithUserIds(t *testing.T) { - oldIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + oldIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("firstGroupId"), UserId: aws.String("firstUserId")}, {GroupId: aws.String("secondGroupId"), UserId: aws.String("secondUserId")}, {GroupId: aws.String("thirdGroupId"), UserId: aws.String("thirdUserId")}, }, } - existingIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + existingIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("secondGroupId"), UserId: aws.String("secondUserId")}, }, } - newIPPermission := ec2.IpPermission{ - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + newIPPermission := ec2types.IpPermission{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ {GroupId: aws.String("secondGroupId"), UserId: aws.String("anotherUserId")}, }, } @@ -1586,45 +1516,45 @@ func TestIpPermissionExistsHandlesMultipleGroupIdsWithUserIds(t *testing.T) { func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) { awsStates := []struct { - id int64 - state string + id int32 + state ec2types.InstanceStateName expected bool }{ - {0, ec2.InstanceStateNamePending, true}, - {16, ec2.InstanceStateNameRunning, true}, - {32, ec2.InstanceStateNameShuttingDown, true}, - {48, ec2.InstanceStateNameTerminated, false}, - {64, ec2.InstanceStateNameStopping, true}, - {80, ec2.InstanceStateNameStopped, true}, + {0, ec2types.InstanceStateNamePending, true}, + {16, ec2types.InstanceStateNameRunning, true}, + {32, ec2types.InstanceStateNameShuttingDown, true}, + {48, ec2types.InstanceStateNameTerminated, false}, + {64, ec2types.InstanceStateNameStopping, true}, + {80, ec2types.InstanceStateNameStopped, true}, } awsServices := newMockedFakeAWSServices(TestClusterID) nodeName := types.NodeName("my-dns.internal") - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterLegacy) tag.Value = aws.String(TestClusterID) - tags := []*ec2.Tag{&tag} + tags := []ec2types.Tag{tag} - var testInstance ec2.Instance + var testInstance ec2types.Instance testInstance.PrivateDnsName = aws.String(string(nodeName)) testInstance.Tags = tags awsDefaultInstances := awsServices.instances for _, awsState := range awsStates { - id := "i-" + awsState.state + id := string("i-" + awsState.state) testInstance.InstanceId = aws.String(id) - testInstance.State = &ec2.InstanceState{Code: aws.Int64(awsState.id), Name: aws.String(awsState.state)} + testInstance.State = &ec2types.InstanceState{Code: aws.Int32(awsState.id), Name: awsState.state} awsServices.instances = append(awsDefaultInstances, &testInstance) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return } - resultInstance, err := c.findInstanceByNodeName(nodeName) + resultInstance, err := c.findInstanceByNodeName(context.TODO(), nodeName) if awsState.expected { if err != nil || resultInstance == nil { @@ -1646,27 +1576,26 @@ func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) { func TestGetInstanceByNodeNameBatching(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) assert.Nil(t, err, "Error building aws cloud: %v", err) - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterPrefix + TestClusterID) tag.Value = aws.String("") - tags := []*ec2.Tag{&tag} + tags := []ec2types.Tag{tag} nodeNames := []string{} for i := 0; i < 200; i++ { nodeName := fmt.Sprintf("ip-171-20-42-%d.ec2.internal", i) nodeNames = append(nodeNames, nodeName) - ec2Instance := &ec2.Instance{} + ec2Instance := &ec2types.Instance{} instanceID := fmt.Sprintf("i-abcedf%d", i) ec2Instance.InstanceId = aws.String(instanceID) ec2Instance.PrivateDnsName = aws.String(nodeName) - ec2Instance.State = &ec2.InstanceState{Code: aws.Int64(48), Name: aws.String("running")} + ec2Instance.State = &ec2types.InstanceState{Code: aws.Int32(48), Name: ec2types.InstanceStateNameRunning} ec2Instance.Tags = tags awsServices.instances = append(awsServices.instances, ec2Instance) - } - instances, err := c.getInstancesByNodeNames(nodeNames) + instances, err := c.getInstancesByNodeNames(context.TODO(), nodeNames) assert.Nil(t, err, "Error getting instances by nodeNames %v: %v", nodeNames, err) assert.NotEmpty(t, instances) assert.Equal(t, 200, len(instances), "Expected 200 but got less") @@ -1674,18 +1603,18 @@ func TestGetInstanceByNodeNameBatching(t *testing.T) { func TestGetVolumeLabels(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) assert.Nil(t, err, "Error building aws cloud: %v", err) volumeID := EBSVolumeID("vol-VolumeId") - expectedVolumeRequest := &ec2.DescribeVolumesInput{VolumeIds: []*string{volumeID.awsString()}} - awsServices.ec2.(*MockedFakeEC2).On("DescribeVolumes", expectedVolumeRequest).Return([]*ec2.Volume{ + expectedVolumeRequest := &ec2.DescribeVolumesInput{VolumeIds: []string{string(volumeID)}} + awsServices.ec2.(*MockedFakeEC2).On("DescribeVolumes", context.TODO(), expectedVolumeRequest).Return([]ec2types.Volume{ { VolumeId: volumeID.awsString(), AvailabilityZone: aws.String("us-east-1a"), }, }) - labels, err := c.GetVolumeLabels(KubernetesVolumeID("aws:///" + string(volumeID))) + labels, err := c.GetVolumeLabels(context.TODO(), KubernetesVolumeID("aws:///"+string(volumeID))) assert.Nil(t, err, "Error creating Volume %v", err) assert.Equal(t, map[string]string{ @@ -1699,8 +1628,8 @@ func TestGetLabelsForVolume(t *testing.T) { tests := []struct { name string pv *v1.PersistentVolume - expectedVolumeID *string - expectedEC2Volumes []*ec2.Volume + expectedVolumeID string + expectedEC2Volumes []ec2types.Volume expectedLabels map[string]string expectedError error }{ @@ -1709,8 +1638,8 @@ func TestGetLabelsForVolume(t *testing.T) { &v1.PersistentVolume{ Spec: v1.PersistentVolumeSpec{}, }, - nil, - nil, + "", + []ec2types.Volume{}, nil, nil, }, @@ -1725,8 +1654,8 @@ func TestGetLabelsForVolume(t *testing.T) { }, }, }, - nil, - nil, + "", + []ec2types.Volume{}, nil, nil, }, @@ -1741,8 +1670,8 @@ func TestGetLabelsForVolume(t *testing.T) { }, }, }, - defaultVolume, - nil, + aws.ToString(defaultVolume), + []ec2types.Volume{}, nil, fmt.Errorf("no volumes found"), }, @@ -1757,8 +1686,8 @@ func TestGetLabelsForVolume(t *testing.T) { }, }, }, - defaultVolume, - []*ec2.Volume{{ + aws.ToString(defaultVolume), + []ec2types.Volume{{ VolumeId: defaultVolume, AvailabilityZone: aws.String("us-east-1a"), }}, @@ -1772,10 +1701,10 @@ func TestGetLabelsForVolume(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - expectedVolumeRequest := &ec2.DescribeVolumesInput{VolumeIds: []*string{test.expectedVolumeID}} - awsServices.ec2.(*MockedFakeEC2).On("DescribeVolumes", expectedVolumeRequest).Return(test.expectedEC2Volumes) + expectedVolumeRequest := &ec2.DescribeVolumesInput{VolumeIds: []string{test.expectedVolumeID}} + awsServices.ec2.(*MockedFakeEC2).On("DescribeVolumes", context.TODO(), expectedVolumeRequest).Return(test.expectedEC2Volumes) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) assert.Nil(t, err, "Error building aws cloud: %v", err) l, err := c.GetLabelsForVolume(context.TODO(), test.pv) @@ -1788,7 +1717,7 @@ func TestGetLabelsForVolume(t *testing.T) { func TestDescribeLoadBalancerOnDelete(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.EnsureLoadBalancerDeleted(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}) @@ -1796,7 +1725,7 @@ func TestDescribeLoadBalancerOnDelete(t *testing.T) { func TestDescribeLoadBalancerOnUpdate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.UpdateLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{}) @@ -1804,7 +1733,7 @@ func TestDescribeLoadBalancerOnUpdate(t *testing.T) { func TestDescribeLoadBalancerOnGet(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.GetLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}) @@ -1812,7 +1741,7 @@ func TestDescribeLoadBalancerOnGet(t *testing.T) { func TestDescribeLoadBalancerOnEnsure(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.EnsureLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{}) @@ -1928,9 +1857,9 @@ func TestBuildListener(t *testing.T) { tests := []struct { name string - lbPort int64 + lbPort int32 portName string - instancePort int64 + instancePort int32 backendProtocolAnnotation string certAnnotation string sslPortAnnotation string @@ -2040,10 +1969,10 @@ func TestBuildListener(t *testing.T) { if test.certID != "" { cert = &test.certID } - expected := &elb.Listener{ + expected := elbtypes.Listener{ InstancePort: &test.instancePort, InstanceProtocol: &test.instanceProtocol, - LoadBalancerPort: &test.lbPort, + LoadBalancerPort: test.lbPort, Protocol: &test.lbProtocol, SSLCertificateId: cert, } @@ -2057,27 +1986,25 @@ func TestBuildListener(t *testing.T) { } func TestProxyProtocolEnabled(t *testing.T) { - policies := sets.NewString(ProxyProtocolPolicyName, "FooBarFoo") - fakeBackend := &elb.BackendServerDescription{ - InstancePort: aws.Int64(80), - PolicyNames: stringSetToPointers(policies), + policies := []string{ProxyProtocolPolicyName, "FooBarFoo"} + fakeBackend := elbtypes.BackendServerDescription{ + InstancePort: aws.Int32(80), + PolicyNames: policies, } result := proxyProtocolEnabled(fakeBackend) assert.True(t, result, "expected to find %s in %s", ProxyProtocolPolicyName, policies) - policies = sets.NewString("FooBarFoo") - fakeBackend = &elb.BackendServerDescription{ - InstancePort: aws.Int64(80), - PolicyNames: []*string{ - aws.String("FooBarFoo"), - }, + policies = []string{"FooBarFoo"} + fakeBackend = elbtypes.BackendServerDescription{ + InstancePort: aws.Int32(80), + PolicyNames: []string{"FooBarFoo"}, } result = proxyProtocolEnabled(fakeBackend) assert.False(t, result, "did not expect to find %s in %s", ProxyProtocolPolicyName, policies) - policies = sets.NewString() - fakeBackend = &elb.BackendServerDescription{ - InstancePort: aws.Int64(80), + policies = []string{} + fakeBackend = elbtypes.BackendServerDescription{ + InstancePort: aws.Int32(80), } result = proxyProtocolEnabled(fakeBackend) assert.False(t, result, "did not expect to find %s in %s", ProxyProtocolPolicyName, policies) @@ -2150,7 +2077,7 @@ func TestGetKeyValuePropertiesFromAnnotation(t *testing.T) { func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) sg1 := map[string]string{ServiceAnnotationLoadBalancerExtraSecurityGroups: "sg-000001"} sg2 := map[string]string{ServiceAnnotationLoadBalancerExtraSecurityGroups: "sg-000002"} @@ -2174,7 +2101,7 @@ func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { t.Run(test.name, func(t *testing.T) { serviceName := types.NamespacedName{Namespace: "default", Name: "myservice"} - sgList, setupSg, err := c.buildELBSecurityGroupList(serviceName, "aid", test.annotations) + sgList, setupSg, err := c.buildELBSecurityGroupList(context.TODO(), serviceName, "aid", test.annotations) assert.NoError(t, err, "buildELBSecurityGroupList failed") extraSGs := sgList[1:] assert.True(t, sets.NewString(test.expectedSGs...).Equal(sets.NewString(extraSGs...)), @@ -2186,7 +2113,7 @@ func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { func TestLBSecurityGroupsAnnotation(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) sg1 := map[string]string{ServiceAnnotationLoadBalancerSecurityGroups: "sg-000001"} sg2 := map[string]string{ServiceAnnotationLoadBalancerSecurityGroups: "sg-000002"} @@ -2208,7 +2135,7 @@ func TestLBSecurityGroupsAnnotation(t *testing.T) { t.Run(test.name, func(t *testing.T) { serviceName := types.NamespacedName{Namespace: "default", Name: "myservice"} - sgList, setupSg, err := c.buildELBSecurityGroupList(serviceName, "aid", test.annotations) + sgList, setupSg, err := c.buildELBSecurityGroupList(context.TODO(), serviceName, "aid", test.annotations) assert.NoError(t, err, "buildELBSecurityGroupList failed") assert.True(t, sets.NewString(test.expectedSGs...).Equal(sets.NewString(sgList...)), "Security Groups expected=%q , returned=%q", test.expectedSGs, sgList) @@ -2221,23 +2148,23 @@ func TestLBSecurityGroupsAnnotation(t *testing.T) { func TestAddLoadBalancerTags(t *testing.T) { loadBalancerName := "test-elb" awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) want := make(map[string]string) want["tag1"] = "val1" expectedAddTagsRequest := &elb.AddTagsInput{ - LoadBalancerNames: []*string{&loadBalancerName}, - Tags: []*elb.Tag{ + LoadBalancerNames: []string{loadBalancerName}, + Tags: []elbtypes.Tag{ { Key: aws.String("tag1"), Value: aws.String("val1"), }, }, } - awsServices.elb.(*MockedFakeELB).On("AddTags", expectedAddTagsRequest).Return(&elb.AddTagsOutput{}) + awsServices.elb.(*MockedFakeELB).On("AddTags", context.TODO(), expectedAddTagsRequest).Return(&elb.AddTagsOutput{}) - err := c.addLoadBalancerTags(loadBalancerName, want) + err := c.addLoadBalancerTags(context.TODO(), loadBalancerName, want) assert.Nil(t, err, "Error adding load balancer tags: %v", err) awsServices.elb.(*MockedFakeELB).AssertExpectations(t) } @@ -2246,60 +2173,60 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { tests := []struct { name string annotations map[string]string - want elb.HealthCheck + want elbtypes.HealthCheck }{ { name: "falls back to HC defaults", annotations: map[string]string{}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, { name: "healthy threshold override", annotations: map[string]string{ServiceAnnotationLoadBalancerHCHealthyThreshold: "7"}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(7), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(7), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, { name: "unhealthy threshold override", annotations: map[string]string{ServiceAnnotationLoadBalancerHCUnhealthyThreshold: "7"}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(7), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(7), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, { name: "timeout override", annotations: map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "7"}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(7), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(7), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, { name: "interval override", annotations: map[string]string{ServiceAnnotationLoadBalancerHCInterval: "7"}, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(7), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(7), Target: aws.String("TCP:8080"), }, }, @@ -2308,11 +2235,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { annotations: map[string]string{ ServiceAnnotationLoadBalancerHealthCheckPort: "2122", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:2122"), }, }, @@ -2321,11 +2248,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { annotations: map[string]string{ ServiceAnnotationLoadBalancerHealthCheckProtocol: "HTTP", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("HTTP:8080/"), }, }, @@ -2336,11 +2263,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { ServiceAnnotationLoadBalancerHealthCheckPath: "/healthz", ServiceAnnotationLoadBalancerHealthCheckPort: "31224", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("HTTPS:31224/healthz"), }, }, @@ -2351,11 +2278,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { ServiceAnnotationLoadBalancerHealthCheckPath: "/healthz", ServiceAnnotationLoadBalancerHealthCheckPort: "3124", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("SSL:3124"), }, }, @@ -2365,11 +2292,11 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { ServiceAnnotationLoadBalancerHealthCheckProtocol: "TCP", ServiceAnnotationLoadBalancerHealthCheckPort: "traffic-port", }, - want: elb.HealthCheck{ - HealthyThreshold: aws.Int64(2), - UnhealthyThreshold: aws.Int64(6), - Timeout: aws.Int64(5), - Interval: aws.Int64(10), + want: elbtypes.HealthCheck{ + HealthyThreshold: aws.Int32(2), + UnhealthyThreshold: aws.Int32(6), + Timeout: aws.Int32(5), + Interval: aws.Int32(10), Target: aws.String("TCP:8080"), }, }, @@ -2377,15 +2304,15 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { lbName := "myLB" // this HC will always differ from the expected HC and thus it is expected an // API call will be made to update it - currentHC := &elb.HealthCheck{} - elbDesc := &elb.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: currentHC} - defaultHealthyThreshold := int64(2) - defaultUnhealthyThreshold := int64(6) - defaultTimeout := int64(5) - defaultInterval := int64(10) + currentHC := &elbtypes.HealthCheck{} + elbDesc := &elbtypes.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: currentHC} + defaultHealthyThreshold := int32(2) + defaultUnhealthyThreshold := int32(6) + defaultTimeout := int32(5) + defaultInterval := int32(10) protocol, path, port := "TCP", "", int32(8080) target := "TCP:8080" - defaultHC := &elb.HealthCheck{ + defaultHC := &elbtypes.HealthCheck{ HealthyThreshold: &defaultHealthyThreshold, UnhealthyThreshold: &defaultUnhealthyThreshold, Timeout: &defaultTimeout, @@ -2395,12 +2322,12 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := test.want awsServices.elb.(*MockedFakeELB).expectConfigureHealthCheck(&lbName, &expectedHC, nil) - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, test.annotations) + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, test.annotations) require.NoError(t, err) awsServices.elb.(*MockedFakeELB).AssertExpectations(t) @@ -2409,62 +2336,62 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { t.Run("does not make an API call if the current health check is the same", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC - timeout := int64(3) - expectedHC.Timeout = &timeout + timeout := int32(3) + expectedHC.Timeout = aws.Int32(timeout) annotations := map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "3"} - var currentHC elb.HealthCheck + var currentHC elbtypes.HealthCheck currentHC = expectedHC // NOTE no call expectations are set on the ELB mock // test default HC - elbDesc := &elb.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: defaultHC} - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, map[string]string{}) + elbDesc := &elbtypes.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: defaultHC} + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, map[string]string{}) assert.NoError(t, err) // test HC with override - elbDesc = &elb.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: ¤tHC} - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, annotations) + elbDesc = &elbtypes.LoadBalancerDescription{LoadBalancerName: &lbName, HealthCheck: ¤tHC} + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, annotations) assert.NoError(t, err) }) t.Run("validates resulting expected health check before making an API call", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC - invalidThreshold := int64(1) - expectedHC.HealthyThreshold = &invalidThreshold - require.Error(t, expectedHC.Validate()) // confirm test precondition + invalidThreshold := int32(1) + expectedHC.HealthyThreshold = aws.Int32(invalidThreshold) + require.Error(t, ValidateHealthCheck(&expectedHC)) // confirm test precondition annotations := map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "1"} // NOTE no call expectations are set on the ELB mock - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, annotations) + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, annotations) require.Error(t, err) }) t.Run("handles invalid override values", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) assert.Nil(t, err, "Error building aws cloud: %v", err) annotations := map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "3.3"} // NOTE no call expectations are set on the ELB mock - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, annotations) + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, annotations) require.Error(t, err) }) t.Run("returns error when updating the health check fails", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) assert.Nil(t, err, "Error building aws cloud: %v", err) returnErr := fmt.Errorf("throttling error") awsServices.elb.(*MockedFakeELB).expectConfigureHealthCheck(&lbName, defaultHC, returnErr) - err = c.ensureLoadBalancerHealthCheck(elbDesc, protocol, port, path, map[string]string{}) + err = c.ensureLoadBalancerHealthCheck(context.TODO(), elbDesc, protocol, port, path, map[string]string{}) require.Error(t, err) awsServices.elb.(*MockedFakeELB).AssertExpectations(t) @@ -2472,8 +2399,8 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { } func TestFindSecurityGroupForInstance(t *testing.T) { - groups := map[string]*ec2.SecurityGroup{"sg123": {GroupId: aws.String("sg123")}} - id, err := findSecurityGroupForInstance(&ec2.Instance{SecurityGroups: []*ec2.GroupIdentifier{{GroupId: aws.String("sg123"), GroupName: aws.String("my_group")}}}, groups) + groups := map[string]*ec2types.SecurityGroup{"sg123": {GroupId: aws.String("sg123")}} + id, err := findSecurityGroupForInstance(&ec2types.Instance{SecurityGroups: []ec2types.GroupIdentifier{{GroupId: aws.String("sg123"), GroupName: aws.String("my_group")}}}, groups) if err != nil { t.Error() } @@ -2482,9 +2409,9 @@ func TestFindSecurityGroupForInstance(t *testing.T) { } func TestFindSecurityGroupForInstanceMultipleTagged(t *testing.T) { - groups := map[string]*ec2.SecurityGroup{"sg123": {GroupId: aws.String("sg123")}} - _, err := findSecurityGroupForInstance(&ec2.Instance{ - SecurityGroups: []*ec2.GroupIdentifier{ + groups := map[string]*ec2types.SecurityGroup{"sg123": {GroupId: aws.String("sg123")}} + _, err := findSecurityGroupForInstance(&ec2types.Instance{ + SecurityGroups: []ec2types.GroupIdentifier{ {GroupId: aws.String("sg123"), GroupName: aws.String("my_group")}, {GroupId: aws.String("sg123"), GroupName: aws.String("another_group")}, }, @@ -2496,7 +2423,7 @@ func TestFindSecurityGroupForInstanceMultipleTagged(t *testing.T) { func TestCreateDisk(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) volumeOptions := &VolumeOptions{ AvailabilityZone: "us-east-1a", @@ -2505,10 +2432,10 @@ func TestCreateDisk(t *testing.T) { request := &ec2.CreateVolumeInput{ AvailabilityZone: aws.String("us-east-1a"), Encrypted: aws.Bool(false), - VolumeType: aws.String(DefaultVolumeType), - Size: aws.Int64(10), - TagSpecifications: []*ec2.TagSpecification{ - {ResourceType: aws.String(ec2.ResourceTypeVolume), Tags: []*ec2.Tag{ + VolumeType: DefaultVolumeType, + Size: aws.Int32(10), + TagSpecifications: []ec2types.TagSpecification{ + {ResourceType: ec2types.ResourceTypeVolume, Tags: []ec2types.Tag{ // CreateVolume from MockedFakeEC2 expects sorted tags, so we need to // always have these tags sorted: {Key: aws.String(TagNameKubernetesClusterLegacy), Value: aws.String(TestClusterID)}, @@ -2517,19 +2444,22 @@ func TestCreateDisk(t *testing.T) { }, } - volume := &ec2.Volume{ + awsServices.ec2.(*MockedFakeEC2).On("CreateVolume", context.TODO(), request).Return(&ec2.CreateVolumeOutput{ AvailabilityZone: aws.String("us-east-1a"), VolumeId: aws.String("vol-volumeId0"), - State: aws.String("available"), - } - awsServices.ec2.(*MockedFakeEC2).On("CreateVolume", request).Return(volume, nil) + State: ec2types.VolumeStateAvailable, + }, nil) describeVolumesRequest := &ec2.DescribeVolumesInput{ - VolumeIds: []*string{aws.String("vol-volumeId0")}, + VolumeIds: []string{"vol-volumeId0"}, } - awsServices.ec2.(*MockedFakeEC2).On("DescribeVolumes", describeVolumesRequest).Return([]*ec2.Volume{volume}, nil) + awsServices.ec2.(*MockedFakeEC2).On("DescribeVolumes", context.TODO(), describeVolumesRequest).Return([]ec2types.Volume{{ + AvailabilityZone: aws.String("us-east-1a"), + VolumeId: aws.String("vol-volumeId0"), + State: ec2types.VolumeStateAvailable, + }}, nil) - volumeID, err := c.CreateDisk(volumeOptions) + volumeID, err := c.CreateDisk(context.TODO(), volumeOptions) assert.Nil(t, err, "Error creating disk: %v", err) assert.Equal(t, volumeID, KubernetesVolumeID("aws://us-east-1a/vol-volumeId0")) awsServices.ec2.(*MockedFakeEC2).AssertExpectations(t) @@ -2537,7 +2467,7 @@ func TestCreateDisk(t *testing.T) { func TestCreateDiskFailDescribeVolume(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) volumeOptions := &VolumeOptions{ AvailabilityZone: "us-east-1a", @@ -2546,10 +2476,10 @@ func TestCreateDiskFailDescribeVolume(t *testing.T) { request := &ec2.CreateVolumeInput{ AvailabilityZone: aws.String("us-east-1a"), Encrypted: aws.Bool(false), - VolumeType: aws.String(DefaultVolumeType), - Size: aws.Int64(10), - TagSpecifications: []*ec2.TagSpecification{ - {ResourceType: aws.String(ec2.ResourceTypeVolume), Tags: []*ec2.Tag{ + VolumeType: DefaultVolumeType, + Size: aws.Int32(10), + TagSpecifications: []ec2types.TagSpecification{ + {ResourceType: ec2types.ResourceTypeVolume, Tags: []ec2types.Tag{ // CreateVolume from MockedFakeEC2 expects sorted tags, so we need to // always have these tags sorted: {Key: aws.String(TagNameKubernetesClusterLegacy), Value: aws.String(TestClusterID)}, @@ -2558,23 +2488,26 @@ func TestCreateDiskFailDescribeVolume(t *testing.T) { }, } - volume := &ec2.Volume{ + awsServices.ec2.(*MockedFakeEC2).On("CreateVolume", context.TODO(), request).Return(&ec2.CreateVolumeOutput{ AvailabilityZone: aws.String("us-east-1a"), VolumeId: aws.String("vol-volumeId0"), - State: aws.String("creating"), - } - awsServices.ec2.(*MockedFakeEC2).On("CreateVolume", request).Return(volume, nil) + State: ec2types.VolumeStateCreating, + }, nil) describeVolumesRequest := &ec2.DescribeVolumesInput{ - VolumeIds: []*string{aws.String("vol-volumeId0")}, + VolumeIds: []string{"vol-volumeId0"}, } deleteVolumeRequest := &ec2.DeleteVolumeInput{ VolumeId: aws.String("vol-volumeId0"), } - awsServices.ec2.(*MockedFakeEC2).On("DescribeVolumes", describeVolumesRequest).Return([]*ec2.Volume{volume}, nil) - awsServices.ec2.(*MockedFakeEC2).On("DeleteVolume", deleteVolumeRequest).Return(&ec2.DeleteVolumeOutput{}, nil) + awsServices.ec2.(*MockedFakeEC2).On("DescribeVolumes", context.TODO(), describeVolumesRequest).Return([]ec2types.Volume{{ + AvailabilityZone: aws.String("us-east-1a"), + VolumeId: aws.String("vol-volumeId0"), + State: ec2types.VolumeStateCreating, + }}, nil) + awsServices.ec2.(*MockedFakeEC2).On("DeleteVolume", context.TODO(), deleteVolumeRequest).Return(&ec2.DeleteVolumeOutput{}, nil) - volumeID, err := c.CreateDisk(volumeOptions) + volumeID, err := c.CreateDisk(context.TODO(), volumeOptions) assert.Error(t, err) assert.Equal(t, volumeID, KubernetesVolumeID("")) awsServices.ec2.(*MockedFakeEC2).AssertExpectations(t) @@ -2590,7 +2523,7 @@ const ( func TestNodeNameToInstanceID(t *testing.T) { fakeAWS := newMockedFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, fakeAWS) + c, err := newAWSCloud(CloudConfig{}, fakeAWS, nil) assert.NoError(t, err) fakeClient := &fake.Clientset{} @@ -2700,7 +2633,7 @@ func TestInstanceIDToNodeName(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - awsCloud, err := newAWSCloud(CloudConfig{}, awsServices) + awsCloud, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Fatalf("error creating mock cloud: %v", err) } @@ -2727,39 +2660,39 @@ func informerNotSynced() bool { } type MockedFakeELBV2 struct { - LoadBalancers []*elbv2.LoadBalancer - TargetGroups []*elbv2.TargetGroup - Listeners []*elbv2.Listener + LoadBalancers []elbv2types.LoadBalancer + TargetGroups []elbv2types.TargetGroup + Listeners []elbv2types.Listener // keys on all of these maps are ARNs LoadBalancerAttributes map[string]map[string]string - Tags map[string][]elbv2.Tag + Tags map[string][]elbv2types.Tag RegisteredInstances map[string][]string // value is list of instance IDs } -func (m *MockedFakeELBV2) AddTags(request *elbv2.AddTagsInput) (*elbv2.AddTagsOutput, error) { - for _, arn := range request.ResourceArns { - for _, tag := range request.Tags { - m.Tags[aws.StringValue(arn)] = append(m.Tags[aws.StringValue(arn)], *tag) +func (m *MockedFakeELBV2) AddTags(ctx context.Context, input *elbv2.AddTagsInput, optFns ...func(*elbv2.Options)) (*elbv2.AddTagsOutput, error) { + for _, arn := range input.ResourceArns { + for _, tag := range input.Tags { + m.Tags[arn] = append(m.Tags[arn], tag) } } return &elbv2.AddTagsOutput{}, nil } -func (m *MockedFakeELBV2) CreateLoadBalancer(request *elbv2.CreateLoadBalancerInput) (*elbv2.CreateLoadBalancerOutput, error) { +func (m *MockedFakeELBV2) CreateLoadBalancer(ctx context.Context, input *elbv2.CreateLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateLoadBalancerOutput, error) { accountID := 123456789 arn := fmt.Sprintf("arn:aws:elasticloadbalancing:us-east-1:%d:loadbalancer/net/%x/%x", accountID, rand.Uint64(), rand.Uint32()) - newLB := &elbv2.LoadBalancer{ + newLB := elbv2types.LoadBalancer{ LoadBalancerArn: aws.String(arn), - LoadBalancerName: request.Name, - Type: aws.String(elbv2.LoadBalancerTypeEnumNetwork), + LoadBalancerName: input.Name, + Type: elbv2types.LoadBalancerTypeEnumNetwork, VpcId: aws.String("vpc-abc123def456abc78"), - AvailabilityZones: []*elbv2.AvailabilityZone{ + AvailabilityZones: []elbv2types.AvailabilityZone{ { ZoneName: aws.String("us-west-2a"), SubnetId: aws.String("subnet-abc123de"), @@ -2769,35 +2702,35 @@ func (m *MockedFakeELBV2) CreateLoadBalancer(request *elbv2.CreateLoadBalancerIn m.LoadBalancers = append(m.LoadBalancers, newLB) return &elbv2.CreateLoadBalancerOutput{ - LoadBalancers: []*elbv2.LoadBalancer{newLB}, + LoadBalancers: []elbv2types.LoadBalancer{newLB}, }, nil } -func (m *MockedFakeELBV2) DescribeLoadBalancers(request *elbv2.DescribeLoadBalancersInput) (*elbv2.DescribeLoadBalancersOutput, error) { +func (m *MockedFakeELBV2) DescribeLoadBalancers(ctx context.Context, input *elbv2.DescribeLoadBalancersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancersOutput, error) { findMeNames := make(map[string]bool) - for _, name := range request.Names { - findMeNames[aws.StringValue(name)] = true + for _, name := range input.Names { + findMeNames[name] = true } findMeARNs := make(map[string]bool) - for _, arn := range request.LoadBalancerArns { - findMeARNs[aws.StringValue(arn)] = true + for _, arn := range input.LoadBalancerArns { + findMeARNs[arn] = true } - result := []*elbv2.LoadBalancer{} + result := []elbv2types.LoadBalancer{} for _, lb := range m.LoadBalancers { - if _, present := findMeNames[aws.StringValue(lb.LoadBalancerName)]; present { + if _, present := findMeNames[aws.ToString(lb.LoadBalancerName)]; present { result = append(result, lb) - delete(findMeNames, aws.StringValue(lb.LoadBalancerName)) - } else if _, present := findMeARNs[aws.StringValue(lb.LoadBalancerArn)]; present { + delete(findMeNames, aws.ToString(lb.LoadBalancerName)) + } else if _, present := findMeARNs[aws.ToString(lb.LoadBalancerArn)]; present { result = append(result, lb) - delete(findMeARNs, aws.StringValue(lb.LoadBalancerArn)) + delete(findMeARNs, aws.ToString(lb.LoadBalancerArn)) } } if len(findMeNames) > 0 || len(findMeARNs) > 0 { - return nil, awserr.New(elbv2.ErrCodeLoadBalancerNotFoundException, "not found", nil) + return nil, &elbv2types.LoadBalancerNotFoundException{Message: aws.String("not found")} } return &elbv2.DescribeLoadBalancersOutput{ @@ -2805,33 +2738,33 @@ func (m *MockedFakeELBV2) DescribeLoadBalancers(request *elbv2.DescribeLoadBalan }, nil } -func (m *MockedFakeELBV2) DeleteLoadBalancer(*elbv2.DeleteLoadBalancerInput) (*elbv2.DeleteLoadBalancerOutput, error) { +func (m *MockedFakeELBV2) DeleteLoadBalancer(ctx context.Context, input *elbv2.DeleteLoadBalancerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteLoadBalancerOutput, error) { panic("Not implemented") } -func (m *MockedFakeELBV2) ModifyLoadBalancerAttributes(request *elbv2.ModifyLoadBalancerAttributesInput) (*elbv2.ModifyLoadBalancerAttributesOutput, error) { - attrMap, present := m.LoadBalancerAttributes[aws.StringValue(request.LoadBalancerArn)] +func (m *MockedFakeELBV2) ModifyLoadBalancerAttributes(ctx context.Context, input *elbv2.ModifyLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyLoadBalancerAttributesOutput, error) { + attrMap, present := m.LoadBalancerAttributes[aws.ToString(input.LoadBalancerArn)] if !present { attrMap = make(map[string]string) - m.LoadBalancerAttributes[aws.StringValue(request.LoadBalancerArn)] = attrMap + m.LoadBalancerAttributes[aws.ToString(input.LoadBalancerArn)] = attrMap } - for _, attr := range request.Attributes { - attrMap[aws.StringValue(attr.Key)] = aws.StringValue(attr.Value) + for _, attr := range input.Attributes { + attrMap[aws.ToString(attr.Key)] = aws.ToString(attr.Value) } return &elbv2.ModifyLoadBalancerAttributesOutput{ - Attributes: request.Attributes, + Attributes: input.Attributes, }, nil } -func (m *MockedFakeELBV2) DescribeLoadBalancerAttributes(request *elbv2.DescribeLoadBalancerAttributesInput) (*elbv2.DescribeLoadBalancerAttributesOutput, error) { - attrs := []*elbv2.LoadBalancerAttribute{} +func (m *MockedFakeELBV2) DescribeLoadBalancerAttributes(ctx context.Context, input *elbv2.DescribeLoadBalancerAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeLoadBalancerAttributesOutput, error) { + attrs := []elbv2types.LoadBalancerAttribute{} - if lbAttrs, present := m.LoadBalancerAttributes[aws.StringValue(request.LoadBalancerArn)]; present { + if lbAttrs, present := m.LoadBalancerAttributes[aws.ToString(input.LoadBalancerArn)]; present { for key, value := range lbAttrs { - attrs = append(attrs, &elbv2.LoadBalancerAttribute{ + attrs = append(attrs, elbv2types.LoadBalancerAttribute{ Key: aws.String(key), Value: aws.String(value), }) @@ -2843,65 +2776,65 @@ func (m *MockedFakeELBV2) DescribeLoadBalancerAttributes(request *elbv2.Describe }, nil } -func (m *MockedFakeELBV2) CreateTargetGroup(request *elbv2.CreateTargetGroupInput) (*elbv2.CreateTargetGroupOutput, error) { +func (m *MockedFakeELBV2) CreateTargetGroup(ctx context.Context, input *elbv2.CreateTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateTargetGroupOutput, error) { accountID := 123456789 arn := fmt.Sprintf("arn:aws:elasticloadbalancing:us-east-1:%d:targetgroup/%x/%x", accountID, rand.Uint64(), rand.Uint32()) - newTG := &elbv2.TargetGroup{ + newTG := elbv2types.TargetGroup{ TargetGroupArn: aws.String(arn), - TargetGroupName: request.Name, - Port: request.Port, - Protocol: request.Protocol, - HealthCheckProtocol: request.HealthCheckProtocol, - HealthCheckPath: request.HealthCheckPath, - HealthCheckPort: request.HealthCheckPort, - HealthCheckTimeoutSeconds: request.HealthCheckTimeoutSeconds, - HealthCheckIntervalSeconds: request.HealthCheckIntervalSeconds, - HealthyThresholdCount: request.HealthyThresholdCount, - UnhealthyThresholdCount: request.UnhealthyThresholdCount, + TargetGroupName: input.Name, + Port: input.Port, + Protocol: input.Protocol, + HealthCheckProtocol: input.HealthCheckProtocol, + HealthCheckPath: input.HealthCheckPath, + HealthCheckPort: input.HealthCheckPort, + HealthCheckTimeoutSeconds: input.HealthCheckTimeoutSeconds, + HealthCheckIntervalSeconds: input.HealthCheckIntervalSeconds, + HealthyThresholdCount: input.HealthyThresholdCount, + UnhealthyThresholdCount: input.UnhealthyThresholdCount, } m.TargetGroups = append(m.TargetGroups, newTG) return &elbv2.CreateTargetGroupOutput{ - TargetGroups: []*elbv2.TargetGroup{newTG}, + TargetGroups: []elbv2types.TargetGroup{newTG}, }, nil } -func (m *MockedFakeELBV2) DescribeTargetGroups(request *elbv2.DescribeTargetGroupsInput) (*elbv2.DescribeTargetGroupsOutput, error) { - var targetGroups []*elbv2.TargetGroup +func (m *MockedFakeELBV2) DescribeTargetGroups(ctx context.Context, input *elbv2.DescribeTargetGroupsInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupsOutput, error) { + var targetGroups []elbv2types.TargetGroup - if request.LoadBalancerArn != nil { - targetGroups = []*elbv2.TargetGroup{} + if input.LoadBalancerArn != nil { + targetGroups = []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { for _, lbArn := range tg.LoadBalancerArns { - if aws.StringValue(lbArn) == aws.StringValue(request.LoadBalancerArn) { + if lbArn == aws.ToString(input.LoadBalancerArn) { targetGroups = append(targetGroups, tg) break } } } - } else if len(request.Names) != 0 { - targetGroups = []*elbv2.TargetGroup{} + } else if len(input.Names) != 0 { + targetGroups = []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { - for _, name := range request.Names { - if aws.StringValue(tg.TargetGroupName) == aws.StringValue(name) { + for _, name := range input.Names { + if aws.ToString(tg.TargetGroupName) == name { targetGroups = append(targetGroups, tg) break } } } - } else if len(request.TargetGroupArns) != 0 { - targetGroups = []*elbv2.TargetGroup{} + } else if len(input.TargetGroupArns) != 0 { + targetGroups = []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { - for _, arn := range request.TargetGroupArns { - if aws.StringValue(tg.TargetGroupArn) == aws.StringValue(arn) { + for _, arn := range input.TargetGroupArns { + if aws.ToString(tg.TargetGroupArn) == arn { targetGroups = append(targetGroups, tg) break } @@ -2916,46 +2849,46 @@ func (m *MockedFakeELBV2) DescribeTargetGroups(request *elbv2.DescribeTargetGrou }, nil } -func (m *MockedFakeELBV2) ModifyTargetGroup(request *elbv2.ModifyTargetGroupInput) (*elbv2.ModifyTargetGroupOutput, error) { - var matchingTargetGroup *elbv2.TargetGroup - dirtyGroups := []*elbv2.TargetGroup{} +func (m *MockedFakeELBV2) ModifyTargetGroup(ctx context.Context, input *elbv2.ModifyTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupOutput, error) { + var matchingTargetGroup *elbv2types.TargetGroup + dirtyGroups := []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { - if aws.StringValue(tg.TargetGroupArn) == aws.StringValue(request.TargetGroupArn) { - matchingTargetGroup = tg + if aws.ToString(tg.TargetGroupArn) == aws.ToString(input.TargetGroupArn) { + matchingTargetGroup = &tg break } } if matchingTargetGroup != nil { - dirtyGroups = append(dirtyGroups, matchingTargetGroup) + dirtyGroups = append(dirtyGroups, *matchingTargetGroup) - if request.HealthCheckEnabled != nil { - matchingTargetGroup.HealthCheckEnabled = request.HealthCheckEnabled + if input.HealthCheckEnabled != nil { + matchingTargetGroup.HealthCheckEnabled = input.HealthCheckEnabled } - if request.HealthCheckIntervalSeconds != nil { - matchingTargetGroup.HealthCheckIntervalSeconds = request.HealthCheckIntervalSeconds + if input.HealthCheckIntervalSeconds != nil { + matchingTargetGroup.HealthCheckIntervalSeconds = input.HealthCheckIntervalSeconds } - if request.HealthCheckPath != nil { - matchingTargetGroup.HealthCheckPath = request.HealthCheckPath + if input.HealthCheckPath != nil { + matchingTargetGroup.HealthCheckPath = input.HealthCheckPath } - if request.HealthCheckPort != nil { - matchingTargetGroup.HealthCheckPort = request.HealthCheckPort + if input.HealthCheckPort != nil { + matchingTargetGroup.HealthCheckPort = input.HealthCheckPort } - if request.HealthCheckProtocol != nil { - matchingTargetGroup.HealthCheckProtocol = request.HealthCheckProtocol + if string(input.HealthCheckProtocol) != "" { + matchingTargetGroup.HealthCheckProtocol = input.HealthCheckProtocol } - if request.HealthCheckTimeoutSeconds != nil { - matchingTargetGroup.HealthCheckTimeoutSeconds = request.HealthCheckTimeoutSeconds + if input.HealthCheckTimeoutSeconds != nil { + matchingTargetGroup.HealthCheckTimeoutSeconds = input.HealthCheckTimeoutSeconds } - if request.HealthyThresholdCount != nil { - matchingTargetGroup.HealthyThresholdCount = request.HealthyThresholdCount + if input.HealthyThresholdCount != nil { + matchingTargetGroup.HealthyThresholdCount = input.HealthyThresholdCount } - if request.Matcher != nil { - matchingTargetGroup.Matcher = request.Matcher + if input.Matcher != nil { + matchingTargetGroup.Matcher = input.Matcher } - if request.UnhealthyThresholdCount != nil { - matchingTargetGroup.UnhealthyThresholdCount = request.UnhealthyThresholdCount + if input.UnhealthyThresholdCount != nil { + matchingTargetGroup.UnhealthyThresholdCount = input.UnhealthyThresholdCount } } @@ -2964,44 +2897,44 @@ func (m *MockedFakeELBV2) ModifyTargetGroup(request *elbv2.ModifyTargetGroupInpu }, nil } -func (m *MockedFakeELBV2) DeleteTargetGroup(request *elbv2.DeleteTargetGroupInput) (*elbv2.DeleteTargetGroupOutput, error) { - newTargetGroups := []*elbv2.TargetGroup{} +func (m *MockedFakeELBV2) DeleteTargetGroup(ctx context.Context, input *elbv2.DeleteTargetGroupInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteTargetGroupOutput, error) { + newTargetGroups := []elbv2types.TargetGroup{} for _, tg := range m.TargetGroups { - if aws.StringValue(tg.TargetGroupArn) != aws.StringValue(request.TargetGroupArn) { + if aws.ToString(tg.TargetGroupArn) != aws.ToString(input.TargetGroupArn) { newTargetGroups = append(newTargetGroups, tg) } } m.TargetGroups = newTargetGroups - delete(m.RegisteredInstances, aws.StringValue(request.TargetGroupArn)) + delete(m.RegisteredInstances, aws.ToString(input.TargetGroupArn)) return &elbv2.DeleteTargetGroupOutput{}, nil } -func (m *MockedFakeELBV2) DescribeTargetHealth(request *elbv2.DescribeTargetHealthInput) (*elbv2.DescribeTargetHealthOutput, error) { - healthDescriptions := []*elbv2.TargetHealthDescription{} +func (m *MockedFakeELBV2) DescribeTargetHealth(ctx context.Context, input *elbv2.DescribeTargetHealthInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) { + healthDescriptions := []elbv2types.TargetHealthDescription{} - var matchingTargetGroup *elbv2.TargetGroup + var matchingTargetGroup elbv2types.TargetGroup for _, tg := range m.TargetGroups { - if aws.StringValue(tg.TargetGroupArn) == aws.StringValue(request.TargetGroupArn) { + if aws.ToString(tg.TargetGroupArn) == aws.ToString(input.TargetGroupArn) { matchingTargetGroup = tg break } } - if registeredTargets, present := m.RegisteredInstances[aws.StringValue(request.TargetGroupArn)]; present { + if registeredTargets, present := m.RegisteredInstances[aws.ToString(input.TargetGroupArn)]; present { for _, target := range registeredTargets { - healthDescriptions = append(healthDescriptions, &elbv2.TargetHealthDescription{ + healthDescriptions = append(healthDescriptions, elbv2types.TargetHealthDescription{ HealthCheckPort: matchingTargetGroup.HealthCheckPort, - Target: &elbv2.TargetDescription{ + Target: &elbv2types.TargetDescription{ Id: aws.String(target), Port: matchingTargetGroup.Port, }, - TargetHealth: &elbv2.TargetHealth{ - State: aws.String("healthy"), + TargetHealth: &elbv2types.TargetHealth{ + State: elbv2types.TargetHealthStateEnumHealthy, }, }) } @@ -3012,46 +2945,46 @@ func (m *MockedFakeELBV2) DescribeTargetHealth(request *elbv2.DescribeTargetHeal }, nil } -func (m *MockedFakeELBV2) DescribeTargetGroupAttributes(*elbv2.DescribeTargetGroupAttributesInput) (*elbv2.DescribeTargetGroupAttributesOutput, error) { +func (m *MockedFakeELBV2) DescribeTargetGroupAttributes(ctx context.Context, input *elbv2.DescribeTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeTargetGroupAttributesOutput, error) { panic("Not implemented") } -func (m *MockedFakeELBV2) ModifyTargetGroupAttributes(*elbv2.ModifyTargetGroupAttributesInput) (*elbv2.ModifyTargetGroupAttributesOutput, error) { +func (m *MockedFakeELBV2) ModifyTargetGroupAttributes(ctx context.Context, input *elbv2.ModifyTargetGroupAttributesInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyTargetGroupAttributesOutput, error) { panic("Not implemented") } -func (m *MockedFakeELBV2) RegisterTargets(request *elbv2.RegisterTargetsInput) (*elbv2.RegisterTargetsOutput, error) { - arn := aws.StringValue(request.TargetGroupArn) +func (m *MockedFakeELBV2) RegisterTargets(ctx context.Context, input *elbv2.RegisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) { + arn := aws.ToString(input.TargetGroupArn) alreadyExists := make(map[string]bool) for _, targetID := range m.RegisteredInstances[arn] { alreadyExists[targetID] = true } - for _, target := range request.Targets { - if !alreadyExists[aws.StringValue(target.Id)] { - m.RegisteredInstances[arn] = append(m.RegisteredInstances[arn], aws.StringValue(target.Id)) + for _, target := range input.Targets { + if !alreadyExists[aws.ToString(target.Id)] { + m.RegisteredInstances[arn] = append(m.RegisteredInstances[arn], aws.ToString(target.Id)) } } return &elbv2.RegisterTargetsOutput{}, nil } -func (m *MockedFakeELBV2) DeregisterTargets(request *elbv2.DeregisterTargetsInput) (*elbv2.DeregisterTargetsOutput, error) { +func (m *MockedFakeELBV2) DeregisterTargets(ctx context.Context, input *elbv2.DeregisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) { removeMe := make(map[string]bool) - for _, target := range request.Targets { - removeMe[aws.StringValue(target.Id)] = true + for _, target := range input.Targets { + removeMe[aws.ToString(target.Id)] = true } newRegisteredInstancesForArn := []string{} - for _, targetID := range m.RegisteredInstances[aws.StringValue(request.TargetGroupArn)] { + for _, targetID := range m.RegisteredInstances[aws.ToString(input.TargetGroupArn)] { if !removeMe[targetID] { newRegisteredInstancesForArn = append(newRegisteredInstancesForArn, targetID) } } - m.RegisteredInstances[aws.StringValue(request.TargetGroupArn)] = newRegisteredInstancesForArn + m.RegisteredInstances[aws.ToString(input.TargetGroupArn)] = newRegisteredInstancesForArn return &elbv2.DeregisterTargetsOutput{}, nil } -func (m *MockedFakeELBV2) CreateListener(request *elbv2.CreateListenerInput) (*elbv2.CreateListenerOutput, error) { +func (m *MockedFakeELBV2) CreateListener(ctx context.Context, input *elbv2.CreateListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.CreateListenerOutput, error) { accountID := 123456789 arn := fmt.Sprintf("arn:aws:elasticloadbalancing:us-east-1:%d:listener/net/%x/%x/%x", accountID, @@ -3059,40 +2992,40 @@ func (m *MockedFakeELBV2) CreateListener(request *elbv2.CreateListenerInput) (*e rand.Uint32(), rand.Uint32()) - newListener := &elbv2.Listener{ + newListener := elbv2types.Listener{ ListenerArn: aws.String(arn), - Port: request.Port, - Protocol: request.Protocol, - DefaultActions: request.DefaultActions, - LoadBalancerArn: request.LoadBalancerArn, + Port: input.Port, + Protocol: input.Protocol, + DefaultActions: input.DefaultActions, + LoadBalancerArn: input.LoadBalancerArn, } m.Listeners = append(m.Listeners, newListener) for _, tg := range m.TargetGroups { - for _, action := range request.DefaultActions { - if aws.StringValue(action.TargetGroupArn) == aws.StringValue(tg.TargetGroupArn) { - tg.LoadBalancerArns = append(tg.LoadBalancerArns, request.LoadBalancerArn) + for _, action := range input.DefaultActions { + if aws.ToString(action.TargetGroupArn) == aws.ToString(tg.TargetGroupArn) { + tg.LoadBalancerArns = append(tg.LoadBalancerArns, aws.ToString(input.LoadBalancerArn)) break } } } return &elbv2.CreateListenerOutput{ - Listeners: []*elbv2.Listener{newListener}, + Listeners: []elbv2types.Listener{newListener}, }, nil } -func (m *MockedFakeELBV2) DescribeListeners(request *elbv2.DescribeListenersInput) (*elbv2.DescribeListenersOutput, error) { - if len(request.ListenerArns) == 0 && request.LoadBalancerArn == nil { +func (m *MockedFakeELBV2) DescribeListeners(ctx context.Context, input *elbv2.DescribeListenersInput, optFns ...func(*elbv2.Options)) (*elbv2.DescribeListenersOutput, error) { + if len(input.ListenerArns) == 0 && input.LoadBalancerArn == nil { return &elbv2.DescribeListenersOutput{ Listeners: m.Listeners, }, nil - } else if len(request.ListenerArns) == 0 { - listeners := []*elbv2.Listener{} + } else if len(input.ListenerArns) == 0 { + listeners := []elbv2types.Listener{} for _, lb := range m.Listeners { - if aws.StringValue(lb.LoadBalancerArn) == aws.StringValue(request.LoadBalancerArn) { + if aws.ToString(lb.LoadBalancerArn) == aws.ToString(input.LoadBalancerArn) { listeners = append(listeners, lb) } } @@ -3104,31 +3037,32 @@ func (m *MockedFakeELBV2) DescribeListeners(request *elbv2.DescribeListenersInpu panic("Not implemented") } -func (m *MockedFakeELBV2) DeleteListener(*elbv2.DeleteListenerInput) (*elbv2.DeleteListenerOutput, error) { +func (m *MockedFakeELBV2) DeleteListener(ctx context.Context, input *elbv2.DeleteListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.DeleteListenerOutput, error) { panic("Not implemented") } -func (m *MockedFakeELBV2) ModifyListener(request *elbv2.ModifyListenerInput) (*elbv2.ModifyListenerOutput, error) { - modifiedListeners := []*elbv2.Listener{} +func (m *MockedFakeELBV2) ModifyListener(ctx context.Context, input *elbv2.ModifyListenerInput, optFns ...func(*elbv2.Options)) (*elbv2.ModifyListenerOutput, error) { - for _, listener := range m.Listeners { - if aws.StringValue(listener.ListenerArn) == aws.StringValue(request.ListenerArn) { - if request.DefaultActions != nil { + modifiedListeners := []elbv2types.Listener{} + for i := range m.Listeners { + listener := &m.Listeners[i] + if aws.ToString(listener.ListenerArn) == aws.ToString(input.ListenerArn) { + if input.DefaultActions != nil { // for each old action, find the corresponding target group, and remove the listener's LB ARN from its list for _, action := range listener.DefaultActions { - var targetGroupForAction *elbv2.TargetGroup + var targetGroupForAction *elbv2types.TargetGroup for _, tg := range m.TargetGroups { - if aws.StringValue(action.TargetGroupArn) == aws.StringValue(tg.TargetGroupArn) { - targetGroupForAction = tg + if aws.ToString(action.TargetGroupArn) == aws.ToString(tg.TargetGroupArn) { + targetGroupForAction = &tg break } } if targetGroupForAction != nil { - newLoadBalancerARNs := []*string{} + newLoadBalancerARNs := []string{} for _, lbArn := range targetGroupForAction.LoadBalancerArns { - if aws.StringValue(lbArn) != aws.StringValue(listener.LoadBalancerArn) { + if lbArn != aws.ToString(listener.LoadBalancerArn) { newLoadBalancerARNs = append(newLoadBalancerARNs, lbArn) } } @@ -3137,33 +3071,34 @@ func (m *MockedFakeELBV2) ModifyListener(request *elbv2.ModifyListenerInput) (*e } } - listener.DefaultActions = request.DefaultActions + listener.DefaultActions = input.DefaultActions // for each new action, add the listener's LB ARN to that action's target groups' lists - for _, action := range request.DefaultActions { - var targetGroupForAction *elbv2.TargetGroup + for _, action := range input.DefaultActions { + var targetGroupForAction *elbv2types.TargetGroup for _, tg := range m.TargetGroups { - if aws.StringValue(action.TargetGroupArn) == aws.StringValue(tg.TargetGroupArn) { - targetGroupForAction = tg + if aws.ToString(action.TargetGroupArn) == aws.ToString(tg.TargetGroupArn) { + targetGroupForAction = &tg break } } if targetGroupForAction != nil { - targetGroupForAction.LoadBalancerArns = append(targetGroupForAction.LoadBalancerArns, listener.LoadBalancerArn) + targetGroupForAction.LoadBalancerArns = append(targetGroupForAction.LoadBalancerArns, aws.ToString(listener.LoadBalancerArn)) } } } - if request.Port != nil { - listener.Port = request.Port + if input.Port != nil { + listener.Port = input.Port } - if request.Protocol != nil { - listener.Protocol = request.Protocol + if string(input.Protocol) != "" { + listener.Protocol = input.Protocol } - modifiedListeners = append(modifiedListeners, listener) + modifiedListeners = append(modifiedListeners, *listener) } + } return &elbv2.ModifyListenerOutput{ @@ -3171,34 +3106,30 @@ func (m *MockedFakeELBV2) ModifyListener(request *elbv2.ModifyListenerInput) (*e }, nil } -func (m *MockedFakeELBV2) WaitUntilLoadBalancersDeleted(*elbv2.DescribeLoadBalancersInput) error { - panic("Not implemented") -} - func (m *MockedFakeEC2) maybeExpectDescribeSecurityGroups(clusterID, groupName string) { - tags := []*ec2.Tag{ + tags := []ec2types.Tag{ {Key: aws.String(TagNameKubernetesClusterLegacy), Value: aws.String(clusterID)}, {Key: aws.String(fmt.Sprintf("%s%s", TagNameKubernetesClusterPrefix, clusterID)), Value: aws.String(ResourceLifecycleOwned)}, } - m.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{Filters: []*ec2.Filter{ + m.On("DescribeSecurityGroups", context.TODO(), &ec2.DescribeSecurityGroupsInput{Filters: []ec2types.Filter{ newEc2Filter("group-name", groupName), newEc2Filter("vpc-id", ""), - }}).Maybe().Return([]*ec2.SecurityGroup{{Tags: tags}}) + }}).Maybe().Return([]ec2types.SecurityGroup{{Tags: tags}}) - m.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{}).Maybe().Return([]*ec2.SecurityGroup{{Tags: tags}}) + m.On("DescribeSecurityGroups", context.TODO(), &ec2.DescribeSecurityGroupsInput{}).Maybe().Return([]ec2types.SecurityGroup{{Tags: tags}}) } func TestNLBNodeRegistration(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - awsServices.elbv2 = &MockedFakeELBV2{Tags: make(map[string][]elbv2.Tag), RegisteredInstances: make(map[string][]string), LoadBalancerAttributes: make(map[string]map[string]string)} - c, _ := newAWSCloud(CloudConfig{}, awsServices) + awsServices.elbv2 = &MockedFakeELBV2{Tags: make(map[string][]elbv2types.Tag), RegisteredInstances: make(map[string][]string), LoadBalancerAttributes: make(map[string]map[string]string)} + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) - awsServices.ec2.(*MockedFakeEC2).Subnets = []*ec2.Subnet{ + awsServices.ec2.(*MockedFakeEC2).Subnets = []ec2types.Subnet{ { AvailabilityZone: aws.String("us-west-2a"), SubnetId: aws.String("subnet-abc123de"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(c.tagging.clusterTagKey()), Value: aws.String("owned"), @@ -3207,9 +3138,9 @@ func TestNLBNodeRegistration(t *testing.T) { }, } - awsServices.ec2.(*MockedFakeEC2).RouteTables = []*ec2.RouteTable{ + awsServices.ec2.(*MockedFakeEC2).RouteTables = []ec2types.RouteTable{ { - Associations: []*ec2.RouteTableAssociation{ + Associations: []ec2types.RouteTableAssociation{ { Main: aws.Bool(true), RouteTableAssociationId: aws.String("rtbassoc-abc123def456abc78"), @@ -3218,11 +3149,11 @@ func TestNLBNodeRegistration(t *testing.T) { }, }, RouteTableId: aws.String("rtb-abc123def456abc78"), - Routes: []*ec2.Route{ + Routes: []ec2types.Route{ { DestinationCidrBlock: aws.String("0.0.0.0/0"), GatewayId: aws.String("igw-abc123def456abc78"), - State: aws.String("active"), + State: ec2types.RouteStateActive, }, }, }, @@ -3284,29 +3215,29 @@ func TestNLBNodeRegistration(t *testing.T) { } fauxService.Annotations[ServiceAnnotationLoadBalancerHealthCheckProtocol] = "http" - tgARN := aws.StringValue(awsServices.elbv2.(*MockedFakeELBV2).Listeners[0].DefaultActions[0].TargetGroupArn) + tgARN := aws.ToString(awsServices.elbv2.(*MockedFakeELBV2).Listeners[0].DefaultActions[0].TargetGroupArn) _, err = c.EnsureLoadBalancer(context.TODO(), TestClusterName, fauxService, nodes) if err != nil { t.Errorf("EnsureLoadBalancer returned an error: %v", err) } assert.Equal(t, 1, len(awsServices.elbv2.(*MockedFakeELBV2).Listeners)) - assert.NotEqual(t, tgARN, aws.StringValue(awsServices.elbv2.(*MockedFakeELBV2).Listeners[0].DefaultActions[0].TargetGroupArn)) + assert.NotEqual(t, tgARN, aws.ToString(awsServices.elbv2.(*MockedFakeELBV2).Listeners[0].DefaultActions[0].TargetGroupArn)) } func makeNamedNode(s *FakeAWSServices, offset int, name string) *v1.Node { instanceID := fmt.Sprintf("i-%x", int64(0x02bce90670bb0c7cd)+int64(offset)) - instance := &ec2.Instance{} + instance := &ec2types.Instance{} instance.InstanceId = aws.String(instanceID) - instance.Placement = &ec2.Placement{ + instance.Placement = &ec2types.Placement{ AvailabilityZone: aws.String("us-east-1c"), } instance.PrivateDnsName = aws.String(fmt.Sprintf("ip-172-20-0-%d.ec2.internal", 101+offset)) instance.PrivateIpAddress = aws.String(fmt.Sprintf("192.168.0.%d", 1+offset)) - var tag ec2.Tag + var tag ec2types.Tag tag.Key = aws.String(TagNameKubernetesClusterLegacy) tag.Value = aws.String(TestClusterID) - instance.Tags = []*ec2.Tag{&tag} + instance.Tags = []ec2types.Tag{tag} s.instances = append(s.instances, instance) @@ -3437,7 +3368,7 @@ func TestCloud_buildNLBHealthCheckConfiguration(t *testing.T) { }, want: healthCheckConfig{ Port: "traffic-port", - Protocol: elbv2.ProtocolEnumTcp, + Protocol: elbv2types.ProtocolEnumTcp, Interval: 30, Timeout: 10, HealthyThreshold: 3, @@ -3471,7 +3402,7 @@ func TestCloud_buildNLBHealthCheckConfiguration(t *testing.T) { want: healthCheckConfig{ Port: "32213", Path: "/healthz", - Protocol: elbv2.ProtocolEnumHttp, + Protocol: elbv2types.ProtocolEnumHttp, Interval: 10, Timeout: 10, HealthyThreshold: 2, @@ -3616,7 +3547,7 @@ func TestCloud_buildNLBHealthCheckConfiguration(t *testing.T) { }, want: healthCheckConfig{ Port: "traffic-port", - Protocol: elbv2.ProtocolEnumTcp, + Protocol: elbv2types.ProtocolEnumTcp, Interval: 23, Timeout: 10, HealthyThreshold: 3, @@ -3674,7 +3605,7 @@ func TestCloud_buildNLBHealthCheckConfiguration(t *testing.T) { }, want: healthCheckConfig{ Port: "traffic-port", - Protocol: elbv2.ProtocolEnumTcp, + Protocol: elbv2types.ProtocolEnumTcp, Interval: 30, Timeout: 10, HealthyThreshold: 7, @@ -3750,7 +3681,7 @@ func Test_parseStringSliceAnnotation(t *testing.T) { func TestNodeAddressesForFargate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) nodeAddresses, _ := c.NodeAddressesByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-ip-return-private-dns-name.us-west-2.compute.internal") verifyNodeAddressesForFargate(t, "IPV4", true, nodeAddresses) @@ -3758,7 +3689,7 @@ func TestNodeAddressesForFargate(t *testing.T) { func TestNodeAddressesForFargateIPV6Family(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) c.cfg.Global.NodeIPFamilies = []string{"ipv6"} nodeAddresses, _ := c.NodeAddressesByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-ip-return-private-dns-name-ipv6.us-west-2.compute.internal") @@ -3767,7 +3698,7 @@ func TestNodeAddressesForFargateIPV6Family(t *testing.T) { func TestNodeAddressesForFargatePrivateIP(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) nodeAddresses, _ := c.NodeAddressesByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-192.168.164.88") verifyNodeAddressesForFargate(t, "IPV4", false, nodeAddresses) @@ -3792,7 +3723,7 @@ func verifyNodeAddressesForFargate(t *testing.T, ipFamily string, verifyPublicIP func TestNodeAddressesOrderedByDeviceIndex(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) nodeAddresses, _ := c.NodeAddressesByProviderID(context.TODO(), "aws:///us-west-2a/i-self") expectedAddresses := []v1.NodeAddress{ @@ -3809,7 +3740,7 @@ func TestNodeAddressesOrderedByDeviceIndex(t *testing.T) { func TestInstanceExistsByProviderIDForFargate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) instanceExist, err := c.InstanceExistsByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-192.168.164.88") assert.Nil(t, err) @@ -3818,7 +3749,7 @@ func TestInstanceExistsByProviderIDForFargate(t *testing.T) { func TestInstanceExistsByProviderIDWithNodeNameForFargate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) instanceExist, err := c.InstanceExistsByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-ip-192-168-164-88.us-west-2.compute.internal") assert.Nil(t, err) @@ -3829,7 +3760,7 @@ func TestInstanceExistsByProviderIDForInstanceNotFound(t *testing.T) { mockedEC2API := newMockedEC2API() c := &Cloud{ec2: &awsSdkEC2{ec2: mockedEC2API}} - mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil)) + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, errors.New("InvalidInstanceID.NotFound: Instance not found")) instanceExists, err := c.InstanceExistsByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/i-not-found") assert.Nil(t, err) @@ -3838,7 +3769,7 @@ func TestInstanceExistsByProviderIDForInstanceNotFound(t *testing.T) { func TestInstanceNotExistsByProviderIDForFargate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) instanceExist, err := c.InstanceExistsByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-not-found") assert.Nil(t, err) @@ -3847,7 +3778,7 @@ func TestInstanceNotExistsByProviderIDForFargate(t *testing.T) { func TestInstanceShutdownByProviderIDForFargate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) instanceExist, err := c.InstanceShutdownByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-192.168.164.88") assert.Nil(t, err) @@ -3856,7 +3787,7 @@ func TestInstanceShutdownByProviderIDForFargate(t *testing.T) { func TestInstanceShutdownNotExistsByProviderIDForFargate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) instanceExist, err := c.InstanceShutdownByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-not-found") assert.Nil(t, err) @@ -3865,7 +3796,7 @@ func TestInstanceShutdownNotExistsByProviderIDForFargate(t *testing.T) { func TestInstanceTypeByProviderIDForFargate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) instanceType, err := c.InstanceTypeByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-not-found") assert.Nil(t, err) @@ -3874,7 +3805,7 @@ func TestInstanceTypeByProviderIDForFargate(t *testing.T) { func TestGetZoneByProviderIDForFargate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) - c, _ := newAWSCloud(CloudConfig{}, awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices, nil) zoneDetails, err := c.GetZoneByProviderID(context.TODO(), "aws:///us-west-2c/1abc-2def/fargate-192.168.164.88") assert.Nil(t, err) @@ -3886,27 +3817,27 @@ func TestGetRegionFromMetadata(t *testing.T) { // Returns region from zone if set cfg := CloudConfig{} cfg.Global.Zone = "us-west-2a" - region, err := getRegionFromMetadata(cfg, awsServices.metadata) + region, err := getRegionFromMetadata(context.TODO(), cfg, awsServices.metadata) assert.NoError(t, err) assert.Equal(t, "us-west-2", region) // Returns error if can map to region cfg = CloudConfig{} cfg.Global.Zone = "some-fake-zone" - _, err = getRegionFromMetadata(cfg, awsServices.metadata) + _, err = getRegionFromMetadata(context.TODO(), cfg, awsServices.metadata) assert.Error(t, err) // Returns region from metadata if zone unset cfg = CloudConfig{} - region, err = getRegionFromMetadata(cfg, awsServices.metadata) + region, err = getRegionFromMetadata(context.TODO(), cfg, awsServices.metadata) assert.NoError(t, err) assert.Equal(t, "us-east-1", region) } type MockedEC2API struct { - ec2iface.EC2API + EC2API mock.Mock } -func (m *MockedEC2API) DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { +func (m *MockedEC2API) DescribeInstances(ctx context.Context, input *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { args := m.Called(input) return args.Get(0).(*ec2.DescribeInstancesOutput), args.Error(1) } @@ -3919,17 +3850,17 @@ func TestDescribeInstances(t *testing.T) { tests := []struct { name string input *ec2.DescribeInstancesInput - expect func(ec2iface.EC2API) + expect func(EC2API) isError bool }{ { "MaxResults set on empty DescribeInstancesInput and NextToken respected.", &ec2.DescribeInstancesInput{}, - func(mockedEc2 ec2iface.EC2API) { + func(mockedEc2 EC2API) { m := mockedEc2.(*MockedEC2API) m.On("DescribeInstances", &ec2.DescribeInstancesInput{ - MaxResults: aws.Int64(1000), + MaxResults: aws.Int32(1000), }, ).Return( &ec2.DescribeInstancesOutput{ @@ -3939,7 +3870,7 @@ func TestDescribeInstances(t *testing.T) { ) m.On("DescribeInstances", &ec2.DescribeInstancesInput{ - MaxResults: aws.Int64(1000), + MaxResults: aws.Int32(1000), NextToken: aws.String("asdf"), }, ).Return( @@ -3952,13 +3883,13 @@ func TestDescribeInstances(t *testing.T) { { "MaxResults only set if empty DescribeInstancesInput", &ec2.DescribeInstancesInput{ - MaxResults: aws.Int64(3), + MaxResults: aws.Int32(3), }, - func(mockedEc2 ec2iface.EC2API) { + func(mockedEc2 EC2API) { m := mockedEc2.(*MockedEC2API) m.On("DescribeInstances", &ec2.DescribeInstancesInput{ - MaxResults: aws.Int64(3), + MaxResults: aws.Int32(3), }, ).Return( &ec2.DescribeInstancesOutput{}, @@ -3970,13 +3901,13 @@ func TestDescribeInstances(t *testing.T) { { "MaxResults not set if instance IDs are provided", &ec2.DescribeInstancesInput{ - InstanceIds: []*string{aws.String("i-1234")}, + InstanceIds: []string{"i-1234"}, }, - func(mockedEc2 ec2iface.EC2API) { + func(mockedEc2 EC2API) { m := mockedEc2.(*MockedEC2API) m.On("DescribeInstances", &ec2.DescribeInstancesInput{ - InstanceIds: []*string{aws.String("i-1234")}, + InstanceIds: []string{"i-1234"}, }, ).Return( &ec2.DescribeInstancesOutput{}, @@ -3994,7 +3925,7 @@ func TestDescribeInstances(t *testing.T) { fakeEC2 := awsSdkEC2{ ec2: mockedEC2API, } - _, err := fakeEC2.DescribeInstances(test.input) + _, err := fakeEC2.DescribeInstances(context.TODO(), test.input) if !test.isError { assert.NoError(t, err) } @@ -4057,3 +3988,32 @@ func TestInstanceIDIndexFunc(t *testing.T) { }) } } + +func TestIsAWSErrorInstanceNotFound(t *testing.T) { + mockedEC2API := newMockedEC2API() + ec2Client := &awsSdkEC2{ + ec2: mockedEC2API, + } + + // API error + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, error(&smithy.GenericAPIError{ + Code: string(ec2types.UnsuccessfulInstanceCreditSpecificationErrorCodeInstanceNotFound), + Message: "test", + })) + _, err := ec2Client.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + assert.True(t, IsAWSErrorInstanceNotFound(err)) + + // Wrapped error + _, err = ec2Client.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + err = fmt.Errorf("error listing AWS instances: %q", err) + assert.True(t, IsAWSErrorInstanceNotFound(err)) + + // Expect false for nil and any other errors + assert.False(t, IsAWSErrorInstanceNotFound(nil)) + + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesInput{}, &smithy.GenericAPIError{ + Code: string(ec2types.UnsuccessfulInstanceCreditSpecificationErrorCodeIncorrectInstanceState), + }) + _, err = ec2Client.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + assert.False(t, IsAWSErrorInstanceNotFound(nil)) +} diff --git a/pkg/providers/v1/aws_utils.go b/pkg/providers/v1/aws_utils.go index 621731ed1c..44cea6378e 100644 --- a/pkg/providers/v1/aws_utils.go +++ b/pkg/providers/v1/aws_utils.go @@ -19,30 +19,29 @@ package aws import ( "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go-v2/aws/arn" "k8s.io/apimachinery/pkg/util/sets" ) -func stringSetToPointers(in sets.String) []*string { +func stringSetToList(in sets.Set[string]) []string { if in == nil { return nil } - out := make([]*string, 0, len(in)) + out := make([]string, 0, len(in)) for k := range in { - out = append(out, aws.String(k)) + out = append(out, k) } return out } -func stringSetFromPointers(in []*string) sets.String { +func stringSetFromList(in []string) sets.Set[string] { if in == nil { return nil } - out := sets.NewString() + out := sets.New[string]() for i := range in { - out.Insert(aws.StringValue(in[i])) + out.Insert(in[i]) } return out } diff --git a/pkg/providers/v1/instances.go b/pkg/providers/v1/instances.go index ed3aad13f9..001438ac95 100644 --- a/pkg/providers/v1/instances.go +++ b/pkg/providers/v1/instances.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "context" "fmt" "net/url" "regexp" @@ -24,8 +25,9 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" ) @@ -122,12 +124,12 @@ func mapToAWSInstanceIDsTolerant(nodes []*v1.Node) []InstanceID { } // Gets the full information about this instance from the EC2 API -func describeInstance(ec2Client EC2, instanceID InstanceID) (*ec2.Instance, error) { +func describeInstance(ctx context.Context, ec2Client EC2, instanceID InstanceID) (*ec2types.Instance, error) { request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []string{string(instanceID)}, } - instances, err := ec2Client.DescribeInstances(request) + instances, err := ec2Client.DescribeInstances(ctx, request) if err != nil { return nil, err } @@ -137,7 +139,7 @@ func describeInstance(ec2Client EC2, instanceID InstanceID) (*ec2.Instance, erro if len(instances) > 1 { return nil, fmt.Errorf("multiple instances found for instance: %s", instanceID) } - return instances[0], nil + return &instances[0], nil } // instanceCache manages the cache of DescribeInstances @@ -151,20 +153,20 @@ type instanceCache struct { // Gets the full information about these instance from the EC2 API. Caller must have acquired c.mutex before // calling describeAllInstancesUncached. -func (c *instanceCache) describeAllInstancesUncached() (*allInstancesSnapshot, error) { +func (c *instanceCache) describeAllInstancesUncached(ctx context.Context) (*allInstancesSnapshot, error) { now := time.Now() klog.V(4).Infof("EC2 DescribeInstances - fetching all instances") - var filters []*ec2.Filter - instances, err := c.cloud.describeInstances(filters) + var filters []ec2types.Filter + instances, err := c.cloud.describeInstances(ctx, filters) if err != nil { return nil, err } - m := make(map[InstanceID]*ec2.Instance) + m := make(map[InstanceID]*ec2types.Instance) for _, i := range instances { - id := InstanceID(aws.StringValue(i.InstanceId)) + id := InstanceID(aws.ToString(i.InstanceId)) m[id] = i } @@ -191,7 +193,7 @@ type cacheCriteria struct { } // describeAllInstancesCached returns all instances, using cached results if applicable -func (c *instanceCache) describeAllInstancesCached(criteria cacheCriteria) (*allInstancesSnapshot, error) { +func (c *instanceCache) describeAllInstancesCached(ctx context.Context, criteria cacheCriteria) (*allInstancesSnapshot, error) { c.mutex.Lock() defer c.mutex.Unlock() if c.snapshot != nil && c.snapshot.MeetsCriteria(criteria) { @@ -199,7 +201,7 @@ func (c *instanceCache) describeAllInstancesCached(criteria cacheCriteria) (*all return c.snapshot, nil } - return c.describeAllInstancesUncached() + return c.describeAllInstancesUncached(ctx) } // olderThan is a simple helper to encapsulate timestamp comparison @@ -235,12 +237,12 @@ func (s *allInstancesSnapshot) MeetsCriteria(criteria cacheCriteria) bool { // along with the timestamp for cache-invalidation purposes type allInstancesSnapshot struct { timestamp time.Time - instances map[InstanceID]*ec2.Instance + instances map[InstanceID]*ec2types.Instance } // FindInstances returns the instances corresponding to the specified ids. If an id is not found, it is ignored. -func (s *allInstancesSnapshot) FindInstances(ids []InstanceID) map[InstanceID]*ec2.Instance { - m := make(map[InstanceID]*ec2.Instance) +func (s *allInstancesSnapshot) FindInstances(ids []InstanceID) map[InstanceID]*ec2types.Instance { + m := make(map[InstanceID]*ec2types.Instance) for _, id := range ids { instance := s.instances[id] if instance != nil { diff --git a/pkg/providers/v1/instances_test.go b/pkg/providers/v1/instances_test.go index ac431c6cf6..9a9866946f 100644 --- a/pkg/providers/v1/instances_test.go +++ b/pkg/providers/v1/instances_test.go @@ -20,8 +20,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/stretchr/testify/assert" v1 "k8s.io/api/core/v1" ) @@ -150,8 +150,8 @@ func TestSnapshotMeetsCriteria(t *testing.T) { t.Errorf("Snapshot did not honor HasInstances with missing instances") } - snapshot.instances = make(map[InstanceID]*ec2.Instance) - snapshot.instances[InstanceID("i-12345678")] = &ec2.Instance{} + snapshot.instances = make(map[InstanceID]*ec2types.Instance) + snapshot.instances[InstanceID("i-12345678")] = &ec2types.Instance{} if !snapshot.MeetsCriteria(cacheCriteria{HasInstances: []InstanceID{InstanceID("i-12345678")}}) { t.Errorf("Snapshot did not honor HasInstances with matching instances") @@ -177,14 +177,14 @@ func TestOlderThan(t *testing.T) { func TestSnapshotFindInstances(t *testing.T) { snapshot := &allInstancesSnapshot{} - snapshot.instances = make(map[InstanceID]*ec2.Instance) + snapshot.instances = make(map[InstanceID]*ec2types.Instance) { id := InstanceID("i-12345678") - snapshot.instances[id] = &ec2.Instance{InstanceId: id.awsString()} + snapshot.instances[id] = &ec2types.Instance{InstanceId: id.awsString()} } { id := InstanceID("i-23456789") - snapshot.instances[id] = &ec2.Instance{InstanceId: id.awsString()} + snapshot.instances[id] = &ec2types.Instance{InstanceId: id.awsString()} } instances := snapshot.FindInstances([]InstanceID{InstanceID("i-12345678"), InstanceID("i-23456789"), InstanceID("i-00000000")}) @@ -198,7 +198,7 @@ func TestSnapshotFindInstances(t *testing.T) { t.Errorf("findInstances did not return %s", id) continue } - if aws.StringValue(i.InstanceId) != string(id) { + if aws.ToString(i.InstanceId) != string(id) { t.Errorf("findInstances did not return expected instanceId for %s", id) } if i != snapshot.instances[id] { diff --git a/pkg/providers/v1/log_handler.go b/pkg/providers/v1/log_handler.go index bf0e45664a..ff15339ebe 100644 --- a/pkg/providers/v1/log_handler.go +++ b/pkg/providers/v1/log_handler.go @@ -17,32 +17,70 @@ limitations under the License. package aws import ( - "github.com/aws/aws-sdk-go/aws/request" + "context" + "fmt" + + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/middleware" + "github.com/aws/smithy-go/transport/http" "k8s.io/klog/v2" ) -// Handler for aws-sdk-go that logs all requests -func awsHandlerLogger(req *request.Request) { - service, name := awsServiceAndName(req) - klog.V(4).Infof("AWS request: %s %s", service, name) +// Middleware for AWS SDK Go V2 clients. Logs requests at the Finalize stage. +func awsHandlerLoggerMiddleware() middleware.FinalizeMiddleware { + return middleware.FinalizeMiddlewareFunc( + "k8s/logger", + func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + service, name := awsServiceAndName(ctx) + + klog.V(4).Infof("AWS request: %s %s", service, name) + return next.HandleFinalize(ctx, in) + }, + ) } -func awsSendHandlerLogger(req *request.Request) { - service, name := awsServiceAndName(req) - klog.V(4).Infof("AWS API Send: %s %s %v %v", service, name, req.Operation, req.Params) +// Logs details about the response at the Deserialization stage +func awsValidateResponseHandlerLoggerMiddleware() middleware.DeserializeMiddleware { + return middleware.DeserializeMiddlewareFunc( + "k8s/api-validate-response", + func(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, + ) { + out, metadata, err = next.HandleDeserialize(ctx, in) + response, ok := out.RawResponse.(*http.Response) + if !ok { + return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("unknown transport type %T", out.RawResponse)} + } + service, name := awsServiceAndName(ctx) + klog.V(4).Infof("AWS API ValidateResponse: %s %s %d", service, name, response.StatusCode) + return out, metadata, err + }, + ) } -func awsValidateResponseHandlerLogger(req *request.Request) { - service, name := awsServiceAndName(req) - klog.V(4).Infof("AWS API ValidateResponse: %s %s %v %v %s", service, name, req.Operation, req.Params, req.HTTPResponse.Status) +// Logs details about the request at the Serialize stage +func awsSendHandlerLoggerMiddleware() middleware.SerializeMiddleware { + return middleware.SerializeMiddlewareFunc( + "k8s/api-request", + func(ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler) ( + out middleware.SerializeOutput, metadata middleware.Metadata, err error, + ) { + service, name := awsServiceAndName(ctx) + klog.V(4).Infof("AWS API Send: %s %s %v", service, name, in.Parameters) + return next.HandleSerialize(ctx, in) + }, + ) } -func awsServiceAndName(req *request.Request) (string, string) { - service := req.ClientInfo.ServiceName +// Gets the service and operation name from AWS SDK Go V2 client requests. +func awsServiceAndName(ctx context.Context) (string, string) { + service := middleware.GetServiceID(ctx) name := "?" - if req.Operation != nil { - name = req.Operation.Name + if opName := middleware.GetOperationName(ctx); opName != "" { + name = opName } return service, name } diff --git a/pkg/providers/v1/retry_handler.go b/pkg/providers/v1/retry_handler.go index 8023596dad..18643c2dd9 100644 --- a/pkg/providers/v1/retry_handler.go +++ b/pkg/providers/v1/retry_handler.go @@ -17,16 +17,27 @@ limitations under the License. package aws import ( + "context" + "errors" "math" + "strings" "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/middleware" + "github.com/aws/smithy-go/transport/http" "k8s.io/klog/v2" ) +// nonRetryableError is the code for errors coming from API requests that should not be retried. This +// exists to replicate behavior from AWS SDK Go V1, where requests were marked as non-retryable +// in certain cases. +// In AWS SDK Go V2, an error with this error code is thrown in those same cases, and then +// caught during the IsErrorRetryable check by customRetryer. +var nonRetryableError = "non-retryable error" + const ( decayIntervalSeconds = 20 decayFraction = 0.8 @@ -47,60 +58,6 @@ func NewCrossRequestRetryDelay() *CrossRequestRetryDelay { return c } -// BeforeSign is added to the Sign chain; called before each request -func (c *CrossRequestRetryDelay) BeforeSign(r *request.Request) { - now := time.Now() - delay := c.backoff.ComputeDelayForRequest(now) - if delay > 0 { - klog.Warningf("Inserting delay before AWS request (%s) to avoid RequestLimitExceeded: %s", - describeRequest(r), delay.String()) - - if sleepFn := r.Config.SleepDelay; sleepFn != nil { - // Support SleepDelay for backwards compatibility - sleepFn(delay) - } else if err := aws.SleepWithContext(r.Context(), delay); err != nil { - r.Error = awserr.New(request.CanceledErrorCode, "request context canceled", err) - r.Retryable = aws.Bool(false) - return - } - - // Avoid clock skew problems - r.Time = now - } -} - -// Return the operation name, for use in log messages and metrics -func operationName(r *request.Request) string { - name := "?" - if r.Operation != nil { - name = r.Operation.Name - } - return name -} - -// Return a user-friendly string describing the request, for use in log messages -func describeRequest(r *request.Request) string { - service := r.ClientInfo.ServiceName - return service + "::" + operationName(r) -} - -// AfterRetry is added to the AfterRetry chain; called after any error -func (c *CrossRequestRetryDelay) AfterRetry(r *request.Request) { - if r.Error == nil { - return - } - awsError, ok := r.Error.(awserr.Error) - if !ok { - return - } - if awsError.Code() == "RequestLimitExceeded" { - c.backoff.ReportError() - recordAWSThrottlesMetric(operationName(r)) - klog.Warningf("Got RequestLimitExceeded error on AWS request (%s)", - describeRequest(r)) - } -} - // Backoff manages a backoff that varies based on the recently observed failures type Backoff struct { decayIntervalSeconds int64 @@ -170,6 +127,104 @@ func (b *Backoff) ComputeDelayForRequest(now time.Time) time.Duration { func (b *Backoff) ReportError() { b.mutex.Lock() defer b.mutex.Unlock() - b.countErrorsRequestLimit += 1.0 } + +// Standard retry implementation, except that it doesn't retry NON_RETRYABLE_ERROR errors. +// This works in tandem with (l *delayPrerequest) HandleFinalize, which will throw the error +// in certain cases as part of the middleware. +type customRetryer struct { + aws.Retryer +} + +func (r customRetryer) IsErrorRetryable(err error) bool { + if strings.Contains(err.Error(), nonRetryableError) { + return false + } + return r.Retryer.IsErrorRetryable(err) +} + +// Middleware for AWS SDK Go V2 clients +// Throws nonRetryableError if the request context was canceled, to preserve behavior from AWS +// SDK Go V1, where requests were marked as non-retryable under the same conditions. +// This works in tandem with customRetryer, which will not retry nonRetryableErrors. +func delayPreSign(delayer *CrossRequestRetryDelay) middleware.FinalizeMiddleware { + return middleware.FinalizeMiddlewareFunc( + "k8s/delay-presign", + func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + now := time.Now() + delay := delayer.backoff.ComputeDelayForRequest(now) + + if delay > 0 { + klog.Warningf("Inserting delay before AWS request (%s) to avoid RequestLimitExceeded: %s", + describeRequest(ctx), delay.String()) + + if err := sleepWithContext(ctx, delay); err != nil { + return middleware.FinalizeOutput{}, middleware.Metadata{}, errors.New(nonRetryableError) + } + } + + service, name := awsServiceAndName(ctx) + request, ok := in.Request.(*http.Request) + if ok { + klog.V(4).Infof("AWS API Send: %s %s %s %s", service, name, request.Request.Method, request.Request.URL.Path) + } + return next.HandleFinalize(ctx, in) + }, + ) +} + +func delayAfterRetry(delayer *CrossRequestRetryDelay) middleware.FinalizeMiddleware { + return middleware.FinalizeMiddlewareFunc( + "k8s/delay-afterretry", + func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + finOutput, finMetadata, finErr := next.HandleFinalize(ctx, in) + if finErr == nil { + return finOutput, finMetadata, finErr + } + + var ae smithy.APIError + if errors.As(finErr, &ae) && strings.Contains(ae.Error(), "RequestLimitExceeded") { + delayer.backoff.ReportError() + recordAWSThrottlesMetric(operationName(ctx)) + klog.Warningf("Got RequestLimitExceeded error on AWS request (%s)", + describeRequest(ctx)) + } + return finOutput, finMetadata, finErr + }, + ) +} + +// Return the operation name, for use in log messages and metrics +func operationName(ctx context.Context) string { + name := "?" + if opName := middleware.GetOperationName(ctx); opName != "" { + name = opName + } + return name +} + +// Return a user-friendly string describing the request, for use in log messages. +func describeRequest(ctx context.Context) string { + service := middleware.GetServiceID(ctx) + + return service + "::" + operationName(ctx) +} + +func sleepWithContext(ctx context.Context, dur time.Duration) error { + t := time.NewTimer(dur) + defer t.Stop() + + select { + case <-t.C: + break + case <-ctx.Done(): + return ctx.Err() + } + + return nil +} diff --git a/pkg/providers/v1/retry_handler_test.go b/pkg/providers/v1/retry_handler_test.go index 27b18c6005..2c3b0a3578 100644 --- a/pkg/providers/v1/retry_handler_test.go +++ b/pkg/providers/v1/retry_handler_test.go @@ -17,8 +17,20 @@ limitations under the License. package aws import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" "time" + + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + // "k8s.io/cloud-provider-aws/pkg/providers/v1/config" ) // There follows a group of tests for the backoff logic. There's nothing @@ -133,3 +145,205 @@ func TestBackoffRecovers(t *testing.T) { now = now.Add(time.Second) } } + +// Make sure that nonRetryableErrors, which are thrown by AWS SDK Go V2 clients +// when the request context is canceled, are not retried with customRetryer is used. +func TestNonRetryableError(t *testing.T) { + mockedEC2API := newMockedEC2API() + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{}, errors.New(nonRetryableError)) + + ec2Client := &awsSdkEC2{ + ec2: mockedEC2API, + } + _, err := ec2Client.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + + // Verify that the custom retryer can recognize when a nonRetryableError is thrown + retryer := &customRetryer{ + retry.NewStandard(), + } + if retryer.IsErrorRetryable(err) { + t.Errorf("Expected nonRetryableError error to be non-retryable") + } +} + +// Tests delayPresign to ensure that it delays the request +func TestDelayPresign(t *testing.T) { + // This test forces certain results from ComputeDelayForRequest() and sleepWithContext() + // to trigger a delay from delayPresign(). + // Dummy server to make sure the client request doesn't actually hit the API. + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + cfgWithServiceOverride := CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: "EC2", + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + // Create a dummy delayer that sets a delay of 1 second for ComputeDelayForRequest() + delayer := NewCrossRequestRetryDelay() + delayer.backoff.countRequests = 1 + delayer.backoff.countErrorsRequestLimit = 20000 + delayer.backoff.maxDelay = 100000 + regionDelayersMap := make(map[string]*CrossRequestRetryDelay) + regionDelayersMap["us-west-2"] = delayer + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: regionDelayersMap, + } + + ec2Client, err := mockProvider.Compute(context.Background(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + startTime := time.Now() + _, _ = ec2Client.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + endTime := time.Now() + diff := endTime.Sub(startTime).Seconds() + assert.True(t, diff > 1, fmt.Sprintf("expected a delay of at least 1 second, got %f", diff)) +} + +// Tests that delayAfterRetry() recognizes RequestLimitExceeded errors and counts them towards the backoff +func TestDelayAfterRetry(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/xml") + w.WriteHeader(http.StatusBadRequest) + + // Insert the RequestLimitExceeded error message + errorXML := fmt.Sprintf(` + + + + %d + %s + + + 12345678-1234-1234-1234-123456789012 + `, http.StatusBadRequest, "RequestLimitExceeded") + + w.Write([]byte(errorXML)) + })) + defer testServer.Close() + + cfgWithServiceOverride := CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: "EC2", + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + delayer := NewCrossRequestRetryDelay() + regionDelayersMap := make(map[string]*CrossRequestRetryDelay) + regionDelayersMap["us-west-2"] = delayer + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: regionDelayersMap, + } + + ec2Client, err := mockProvider.Compute(context.Background(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + preDelayErrorCount := delayer.backoff.countErrorsRequestLimit + _, err = ec2Client.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + postDelayErrorCount := delayer.backoff.countErrorsRequestLimit + + // Verify that a RequestLimitExceeded error was thrown + assert.Error(t, err) + assert.Contains(t, err.Error(), "RequestLimitExceeded") + + // In the event that delayAfterRetry() catches a RequestLimitExceeded error, it will + // update the error count in the delayer. This count is used to verify that this case + // was entered. + diff := (int)(postDelayErrorCount - preDelayErrorCount) + assert.True(t, diff == 1, fmt.Sprintf("expected an update to the backoff count of %d, got %d", 1, diff)) +} + +// Tests that delayAfterRetry() does not update the backoff in case of an error other than RequestLimitExceeded +func TestDelayAfterRetryNoDelay(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/xml") + w.WriteHeader(http.StatusBadRequest) + + // Insert a dummy error message that's not RequestLimitExceeded + errorXML := fmt.Sprintf(` + + + + %d + %s + + + 12345678-1234-1234-1234-123456789012 + `, http.StatusBadRequest, "DummyError") + + w.Write([]byte(errorXML)) + })) + defer testServer.Close() + + cfgWithServiceOverride := CloudConfig{ + ServiceOverride: map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + }{ + "1": { + Service: "EC2", + Region: "us-west-2", + URL: testServer.URL, + SigningRegion: "signingRegion", + SigningName: "signingName", + }, + }, + } + delayer := NewCrossRequestRetryDelay() + regionDelayersMap := make(map[string]*CrossRequestRetryDelay) + regionDelayersMap["us-west-2"] = delayer + mockProvider := &awsSDKProvider{ + cfg: &cfgWithServiceOverride, + regionDelayers: regionDelayersMap, + } + + ec2Client, err := mockProvider.Compute(context.Background(), "us-west-2", nil) + if err != nil { + t.Errorf("error creating client, %v", err) + } + preDelayErrorCount := delayer.backoff.countErrorsRequestLimit + _, err = ec2Client.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{}) + postDelayErrorCount := delayer.backoff.countErrorsRequestLimit + + // Verify that a RequestLimitExceeded error wasn't thrown + assert.Error(t, err) + assert.NotContains(t, err.Error(), "RequestLimitExceeded") + + // In the event that delayAfterRetry() catches a RequestLimitExceeded error, it will + // update the error count in the delayer. This count is used to verify that this case + // was not entered. + diff := (int)(postDelayErrorCount - preDelayErrorCount) + assert.True(t, diff == 0, fmt.Sprintf("expected an update to the backoff count of %d, got %d", 0, diff)) +} diff --git a/pkg/providers/v1/sets_ippermissions.go b/pkg/providers/v1/sets_ippermissions.go index a304deedd5..72e99ec2ea 100644 --- a/pkg/providers/v1/sets_ippermissions.go +++ b/pkg/providers/v1/sets_ippermissions.go @@ -20,21 +20,21 @@ import ( "encoding/json" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) // IPPermissionSet maps IP strings of strings to EC2 IpPermissions -type IPPermissionSet map[string]*ec2.IpPermission +type IPPermissionSet map[string]ec2types.IpPermission // IPPermissionPredicate is an predicate to test whether IPPermission matches some condition. type IPPermissionPredicate interface { // Test checks whether specified IPPermission matches condition. - Test(perm *ec2.IpPermission) bool + Test(perm ec2types.IpPermission) bool } // NewIPPermissionSet creates a new IPPermissionSet -func NewIPPermissionSet(items ...*ec2.IpPermission) IPPermissionSet { +func NewIPPermissionSet(items ...ec2types.IpPermission) IPPermissionSet { s := make(IPPermissionSet) s.Insert(items...) return s @@ -44,44 +44,44 @@ func NewIPPermissionSet(items ...*ec2.IpPermission) IPPermissionSet { // EC2 will combine permissions with the same port but different SourceRanges together, for example // We ungroup them so we can process them func (s IPPermissionSet) Ungroup() IPPermissionSet { - l := []*ec2.IpPermission{} + l := []ec2types.IpPermission{} for _, p := range s.List() { if len(p.IpRanges) <= 1 { l = append(l, p) continue } for _, ipRange := range p.IpRanges { - c := &ec2.IpPermission{} - *c = *p - c.IpRanges = []*ec2.IpRange{ipRange} + c := ec2types.IpPermission{} + c = p + c.IpRanges = []ec2types.IpRange{ipRange} l = append(l, c) } } - l2 := []*ec2.IpPermission{} + l2 := []ec2types.IpPermission{} for _, p := range l { if len(p.UserIdGroupPairs) <= 1 { l2 = append(l2, p) continue } for _, u := range p.UserIdGroupPairs { - c := &ec2.IpPermission{} - *c = *p - c.UserIdGroupPairs = []*ec2.UserIdGroupPair{u} + c := ec2types.IpPermission{} + c = p + c.UserIdGroupPairs = []ec2types.UserIdGroupPair{u} l2 = append(l2, c) } } - l3 := []*ec2.IpPermission{} + l3 := []ec2types.IpPermission{} for _, p := range l2 { if len(p.PrefixListIds) <= 1 { l3 = append(l3, p) continue } for _, v := range p.PrefixListIds { - c := &ec2.IpPermission{} - *c = *p - c.PrefixListIds = []*ec2.PrefixListId{v} + c := ec2types.IpPermission{} + c = p + c.PrefixListIds = []ec2types.PrefixListId{v} l3 = append(l3, c) } } @@ -90,7 +90,7 @@ func (s IPPermissionSet) Ungroup() IPPermissionSet { } // Insert adds items to the set. -func (s IPPermissionSet) Insert(items ...*ec2.IpPermission) { +func (s IPPermissionSet) Insert(items ...ec2types.IpPermission) { for _, p := range items { k := keyForIPPermission(p) s[k] = p @@ -98,7 +98,7 @@ func (s IPPermissionSet) Insert(items ...*ec2.IpPermission) { } // Delete delete permission from the set. -func (s IPPermissionSet) Delete(items ...*ec2.IpPermission) { +func (s IPPermissionSet) Delete(items ...ec2types.IpPermission) { for _, p := range items { k := keyForIPPermission(p) delete(s, k) @@ -115,8 +115,8 @@ func (s IPPermissionSet) DeleteIf(predicate IPPermissionPredicate) { } // List returns the contents as a slice. Order is not defined. -func (s IPPermissionSet) List() []*ec2.IpPermission { - res := make([]*ec2.IpPermission, 0, len(s)) +func (s IPPermissionSet) List() []ec2types.IpPermission { + res := make([]ec2types.IpPermission, 0, len(s)) for _, v := range s { res = append(res, v) } @@ -163,7 +163,7 @@ func (s IPPermissionSet) Len() int { return len(s) } -func keyForIPPermission(p *ec2.IpPermission) string { +func keyForIPPermission(p ec2types.IpPermission) string { v, err := json.Marshal(p) if err != nil { panic(fmt.Sprintf("error building JSON representation of ec2.IpPermission: %v", err)) @@ -179,24 +179,24 @@ type IPPermissionMatchDesc struct { } // Test whether specific IPPermission contains description. -func (p IPPermissionMatchDesc) Test(perm *ec2.IpPermission) bool { +func (p IPPermissionMatchDesc) Test(perm ec2types.IpPermission) bool { for _, v4Range := range perm.IpRanges { - if aws.StringValue(v4Range.Description) == p.Description { + if aws.ToString(v4Range.Description) == p.Description { return true } } for _, v6Range := range perm.Ipv6Ranges { - if aws.StringValue(v6Range.Description) == p.Description { + if aws.ToString(v6Range.Description) == p.Description { return true } } for _, prefixListID := range perm.PrefixListIds { - if aws.StringValue(prefixListID.Description) == p.Description { + if aws.ToString(prefixListID.Description) == p.Description { return true } } for _, group := range perm.UserIdGroupPairs { - if aws.StringValue(group.Description) == p.Description { + if aws.ToString(group.Description) == p.Description { return true } } @@ -211,6 +211,6 @@ type IPPermissionNotMatch struct { } // Test whether specific IPPermission not match the embed predicate. -func (p IPPermissionNotMatch) Test(perm *ec2.IpPermission) bool { +func (p IPPermissionNotMatch) Test(perm ec2types.IpPermission) bool { return !p.Predicate.Test(perm) } diff --git a/pkg/providers/v1/sets_ippermissions_test.go b/pkg/providers/v1/sets_ippermissions_test.go index 0680b29b1e..4e4f3a54fd 100644 --- a/pkg/providers/v1/sets_ippermissions_test.go +++ b/pkg/providers/v1/sets_ippermissions_test.go @@ -3,8 +3,8 @@ package aws import ( "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) func TestUngroup(t *testing.T) { @@ -17,67 +17,67 @@ func TestUngroup(t *testing.T) { { "Single IP range in input set", NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + ToPort: aws.Int32(2), }, ), NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + ToPort: aws.Int32(2), }, ), }, { "Three ip ranges in input set", NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("10.0.0.0/16")}, {CidrIp: aws.String("10.1.0.0/16")}, {CidrIp: aws.String("10.2.0.0/16")}, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, ), NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + ToPort: aws.Int32(2), }, - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.1.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.1.0.0/16")}}, + ToPort: aws.Int32(2), }, - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.2.0.0/16")}}, - ToPort: aws.Int64(2), + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.2.0.0/16")}}, + ToPort: aws.Int32(2), }, ), }, { "Three UserIdGroupPairs in input set", NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{ + IpRanges: []ec2types.IpRange{ {CidrIp: aws.String("10.0.0.0/16")}, }, - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + UserIdGroupPairs: []ec2types.UserIdGroupPair{ { GroupId: aws.String("1"), GroupName: aws.String("group-1"), @@ -97,15 +97,15 @@ func TestUngroup(t *testing.T) { VpcId: aws.String("123"), }, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, ), NewIPPermissionSet( - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + UserIdGroupPairs: []ec2types.UserIdGroupPair{ { GroupId: aws.String("1"), GroupName: aws.String("group-1"), @@ -113,13 +113,13 @@ func TestUngroup(t *testing.T) { VpcId: aws.String("123"), }, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + UserIdGroupPairs: []ec2types.UserIdGroupPair{ { GroupId: aws.String("2"), GroupName: aws.String("group-2"), @@ -127,13 +127,13 @@ func TestUngroup(t *testing.T) { VpcId: aws.String("123"), }, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, - &ec2.IpPermission{ - FromPort: aws.Int64(1), + ec2types.IpPermission{ + FromPort: aws.Int32(1), IpProtocol: aws.String("tcp"), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, - UserIdGroupPairs: []*ec2.UserIdGroupPair{ + IpRanges: []ec2types.IpRange{{CidrIp: aws.String("10.0.0.0/16")}}, + UserIdGroupPairs: []ec2types.UserIdGroupPair{ { GroupId: aws.String("3"), GroupName: aws.String("group-3"), @@ -141,7 +141,7 @@ func TestUngroup(t *testing.T) { VpcId: aws.String("123"), }, }, - ToPort: aws.Int64(2), + ToPort: aws.Int32(2), }, ), }, diff --git a/pkg/providers/v1/tags.go b/pkg/providers/v1/tags.go index 9b82436097..5b4a761c16 100644 --- a/pkg/providers/v1/tags.go +++ b/pkg/providers/v1/tags.go @@ -17,11 +17,13 @@ limitations under the License. package aws import ( + "context" "fmt" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "k8s.io/klog/v2" "k8s.io/apimachinery/pkg/util/wait" @@ -85,7 +87,7 @@ func (t *awsTagging) init(legacyClusterID string, clusterID string) error { // Extracts a clusterID from the given tags, if one is present // If no clusterID is found, returns "", nil // If multiple (different) clusterIDs are found, returns an error -func (t *awsTagging) initFromTags(tags []*ec2.Tag) error { +func (t *awsTagging) initFromTags(tags []ec2types.Tag) error { legacyClusterID, newClusterID, err := findClusterIDs(tags) if err != nil { return err @@ -100,12 +102,12 @@ func (t *awsTagging) initFromTags(tags []*ec2.Tag) error { // Extracts the legacy & new cluster ids from the given tags, if they are present // If duplicate tags are found, returns an error -func findClusterIDs(tags []*ec2.Tag) (string, string, error) { +func findClusterIDs(tags []ec2types.Tag) (string, string, error) { legacyClusterID := "" newClusterID := "" for _, tag := range tags { - tagKey := aws.StringValue(tag.Key) + tagKey := aws.ToString(tag.Key) if strings.HasPrefix(tagKey, TagNameKubernetesClusterPrefix) { id := strings.TrimPrefix(tagKey, TagNameKubernetesClusterPrefix) if newClusterID != "" { @@ -115,7 +117,7 @@ func findClusterIDs(tags []*ec2.Tag) (string, string, error) { } if tagKey == TagNameKubernetesClusterLegacy { - id := aws.StringValue(tag.Value) + id := aws.ToString(tag.Value) if legacyClusterID != "" { return "", "", fmt.Errorf("Found multiple %s tags (%q and %q)", TagNameKubernetesClusterLegacy, legacyClusterID, id) } @@ -130,17 +132,17 @@ func (t *awsTagging) clusterTagKey() string { return TagNameKubernetesClusterPrefix + t.ClusterID } -func (t *awsTagging) hasClusterTag(tags []*ec2.Tag) bool { +func (t *awsTagging) hasClusterTag(tags []ec2types.Tag) bool { // if the clusterID is not configured -- we consider all instances. if len(t.ClusterID) == 0 { return true } clusterTagKey := t.clusterTagKey() for _, tag := range tags { - tagKey := aws.StringValue(tag.Key) + tagKey := aws.ToString(tag.Key) // For 1.6, we continue to recognize the legacy tags, for the 1.5 -> 1.6 upgrade // Note that we want to continue traversing tag list if we see a legacy tag with value != ClusterID - if (tagKey == TagNameKubernetesClusterLegacy) && (aws.StringValue(tag.Value) == t.ClusterID) { + if (tagKey == TagNameKubernetesClusterLegacy) && (aws.ToString(tag.Value) == t.ClusterID) { return true } if tagKey == clusterTagKey { @@ -150,9 +152,9 @@ func (t *awsTagging) hasClusterTag(tags []*ec2.Tag) bool { return false } -func (t *awsTagging) hasNoClusterPrefixTag(tags []*ec2.Tag) bool { +func (t *awsTagging) hasNoClusterPrefixTag(tags []ec2types.Tag) bool { for _, tag := range tags { - if strings.HasPrefix(aws.StringValue(tag.Key), TagNameKubernetesClusterPrefix) { + if strings.HasPrefix(aws.ToString(tag.Key), TagNameKubernetesClusterPrefix) { return false } } @@ -162,10 +164,10 @@ func (t *awsTagging) hasNoClusterPrefixTag(tags []*ec2.Tag) bool { // Ensure that a resource has the correct tags // If it has no tags, we assume that this was a problem caused by an error in between creation and tagging, // and we add the tags. If it has a different cluster's tags, that is an error. -func (t *awsTagging) readRepairClusterTags(client EC2, resourceID string, lifecycle ResourceLifecycle, additionalTags map[string]string, observedTags []*ec2.Tag) error { +func (t *awsTagging) readRepairClusterTags(ctx context.Context, client EC2, resourceID string, lifecycle ResourceLifecycle, additionalTags map[string]string, observedTags []ec2types.Tag) error { actualTagMap := make(map[string]string) for _, tag := range observedTags { - actualTagMap[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + actualTagMap[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } expectedTags := t.buildTags(lifecycle, additionalTags) @@ -188,7 +190,7 @@ func (t *awsTagging) readRepairClusterTags(client EC2, resourceID string, lifecy return nil } - if err := t.createTags(client, resourceID, lifecycle, addTags); err != nil { + if err := t.createTags(ctx, client, resourceID, lifecycle, addTags); err != nil { return fmt.Errorf("error adding missing tags to resource %q: %q", resourceID, err) } @@ -198,16 +200,16 @@ func (t *awsTagging) readRepairClusterTags(client EC2, resourceID string, lifecy // createTags calls EC2 CreateTags, but adds retry-on-failure logic // We retry mainly because if we create an object, we cannot tag it until it is "fully created" (eventual consistency) // The error code varies though (depending on what we are tagging), so we simply retry on all errors -func (t *awsTagging) createTags(client EC2, resourceID string, lifecycle ResourceLifecycle, additionalTags map[string]string) error { +func (t *awsTagging) createTags(ctx context.Context, client EC2, resourceID string, lifecycle ResourceLifecycle, additionalTags map[string]string) error { tags := t.buildTags(lifecycle, additionalTags) if tags == nil || len(tags) == 0 { return nil } - var awsTags []*ec2.Tag + var awsTags []ec2types.Tag for k, v := range tags { - tag := &ec2.Tag{ + tag := ec2types.Tag{ Key: aws.String(k), Value: aws.String(v), } @@ -220,12 +222,12 @@ func (t *awsTagging) createTags(client EC2, resourceID string, lifecycle Resourc Steps: createTagSteps, } request := &ec2.CreateTagsInput{} - request.Resources = []*string{&resourceID} + request.Resources = []string{resourceID} request.Tags = awsTags var lastErr error err := wait.ExponentialBackoff(backoff, func() (bool, error) { - _, err := client.CreateTags(request) + _, err := client.CreateTags(ctx, request) if err == nil { return true, nil } @@ -245,7 +247,7 @@ func (t *awsTagging) createTags(client EC2, resourceID string, lifecycle Resourc // Add additional filters, to match on our tags // This lets us run multiple k8s clusters in a single EC2 AZ -func (t *awsTagging) addFilters(filters []*ec2.Filter) []*ec2.Filter { +func (t *awsTagging) addFilters(filters []ec2types.Filter) []ec2types.Filter { // if there are no clusterID configured - no filtering by special tag names // should be applied to revert to legacy behaviour. if len(t.ClusterID) == 0 { @@ -266,7 +268,7 @@ func (t *awsTagging) addFilters(filters []*ec2.Filter) []*ec2.Filter { // 1.5 -> 1.6 clusters and exists for backwards compatibility // // This lets us run multiple k8s clusters in a single EC2 AZ -func (t *awsTagging) addLegacyFilters(filters []*ec2.Filter) []*ec2.Filter { +func (t *awsTagging) addLegacyFilters(filters []ec2types.Filter) []ec2types.Filter { // if there are no clusterID configured - no filtering by special tag names // should be applied to revert to legacy behaviour. if len(t.ClusterID) == 0 { @@ -313,13 +315,13 @@ func (t *awsTagging) clusterID() string { // TagResource calls EC2 and tag the resource associated to resourceID // with the supplied tags -func (c *Cloud) TagResource(resourceID string, tags map[string]string) error { +func (c *Cloud) TagResource(ctx context.Context, resourceID string, tags map[string]string) error { request := &ec2.CreateTagsInput{ - Resources: []*string{aws.String(resourceID)}, + Resources: []string{resourceID}, Tags: buildAwsTags(tags), } - output, err := c.ec2.CreateTags(request) + output, err := c.ec2.CreateTags(ctx, request) if err != nil { klog.Errorf("Error occurred trying to tag resources, %v", err) @@ -333,13 +335,13 @@ func (c *Cloud) TagResource(resourceID string, tags map[string]string) error { // UntagResource calls EC2 and tag the resource associated to resourceID // with the supplied tags -func (c *Cloud) UntagResource(resourceID string, tags map[string]string) error { +func (c *Cloud) UntagResource(ctx context.Context, resourceID string, tags map[string]string) error { request := &ec2.DeleteTagsInput{ - Resources: []*string{aws.String(resourceID)}, + Resources: []string{resourceID}, Tags: buildAwsTags(tags), } - output, err := c.ec2.DeleteTags(request) + output, err := c.ec2.DeleteTags(ctx, request) if err != nil { // An instance not found should not fail the untagging workflow as it @@ -357,10 +359,10 @@ func (c *Cloud) UntagResource(resourceID string, tags map[string]string) error { return nil } -func buildAwsTags(tags map[string]string) []*ec2.Tag { - var awsTags []*ec2.Tag +func buildAwsTags(tags map[string]string) []ec2types.Tag { + var awsTags []ec2types.Tag for k, v := range tags { - newTag := &ec2.Tag{ + newTag := ec2types.Tag{ Key: aws.String(k), Value: aws.String(v), } diff --git a/pkg/providers/v1/tags_test.go b/pkg/providers/v1/tags_test.go index af2e70927e..3e9d44d914 100644 --- a/pkg/providers/v1/tags_test.go +++ b/pkg/providers/v1/tags_test.go @@ -18,21 +18,22 @@ package aws import ( "bytes" + "context" "errors" "flag" "os" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/stretchr/testify/assert" "k8s.io/klog/v2" ) func TestFilterTags(t *testing.T) { awsServices := NewFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -94,9 +95,9 @@ func TestFindClusterID(t *testing.T) { }, } for _, g := range grid { - var ec2Tags []*ec2.Tag + var ec2Tags []ec2types.Tag for k, v := range g.Tags { - ec2Tags = append(ec2Tags, &ec2.Tag{Key: aws.String(k), Value: aws.String(v)}) + ec2Tags = append(ec2Tags, ec2types.Tag{Key: aws.String(k), Value: aws.String(v)}) } actualLegacy, actualNew, err := findClusterIDs(ec2Tags) if g.ExpectError { @@ -125,7 +126,7 @@ func TestFindClusterID(t *testing.T) { func TestHasClusterTag(t *testing.T) { awsServices := NewFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -177,9 +178,9 @@ func TestHasClusterTag(t *testing.T) { }, } for _, g := range grid { - var ec2Tags []*ec2.Tag + var ec2Tags []ec2types.Tag for k, v := range g.Tags { - ec2Tags = append(ec2Tags, &ec2.Tag{Key: aws.String(k), Value: aws.String(v)}) + ec2Tags = append(ec2Tags, ec2types.Tag{Key: aws.String(k), Value: aws.String(v)}) } result := c.tagging.hasClusterTag(ec2Tags) if result != g.Expected { @@ -190,14 +191,14 @@ func TestHasClusterTag(t *testing.T) { func TestHasNoClusterPrefixTag(t *testing.T) { awsServices := NewFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return } tests := []struct { name string - tags []*ec2.Tag + tags []ec2types.Tag want bool }{ { @@ -206,7 +207,7 @@ func TestHasNoClusterPrefixTag(t *testing.T) { }, { name: "no cluster tags", - tags: []*ec2.Tag{ + tags: []ec2types.Tag{ { Key: aws.String("not a cluster tag"), Value: aws.String("true"), @@ -216,7 +217,7 @@ func TestHasNoClusterPrefixTag(t *testing.T) { }, { name: "contains cluster tags", - tags: []*ec2.Tag{ + tags: []ec2types.Tag{ { Key: aws.String("tag1"), Value: aws.String("value1"), @@ -242,7 +243,7 @@ func TestTagResource(t *testing.T) { klog.InitFlags(testFlags) testFlags.Parse([]string{"--logtostderr=false"}) awsServices := NewFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -269,7 +270,7 @@ func TestTagResource(t *testing.T) { { name: "tagging failed due to resource not found error", instanceID: "i-not-found", - err: awserr.New("InvalidInstanceID.NotFound", "Instance not found", nil), + err: errors.New("InvalidInstanceID.NotFound: Instance not found"), expectedMessage: "Error occurred trying to tag resources", }, } @@ -282,7 +283,7 @@ func TestTagResource(t *testing.T) { klog.SetOutput(os.Stderr) }() - err := c.TagResource(tt.instanceID, nil) + err := c.TagResource(context.TODO(), tt.instanceID, nil) assert.Equal(t, tt.err, err) assert.Contains(t, logBuf.String(), tt.expectedMessage) }) @@ -294,7 +295,7 @@ func TestUntagResource(t *testing.T) { klog.InitFlags(testFlags) testFlags.Parse([]string{"--logtostderr=false"}) awsServices := NewFakeAWSServices(TestClusterID) - c, err := newAWSCloud(CloudConfig{}, awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices, nil) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -334,7 +335,7 @@ func TestUntagResource(t *testing.T) { klog.SetOutput(os.Stderr) }() - err := c.UntagResource(tt.instanceID, nil) + err := c.UntagResource(context.TODO(), tt.instanceID, nil) assert.Equal(t, tt.err, err) assert.Contains(t, logBuf.String(), tt.expectedMessage) }) diff --git a/pkg/providers/v1/volumes.go b/pkg/providers/v1/volumes.go index 2e913fc7eb..3f17e7572d 100644 --- a/pkg/providers/v1/volumes.go +++ b/pkg/providers/v1/volumes.go @@ -17,10 +17,11 @@ limitations under the License. package aws import ( + "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" csimigration "k8s.io/csi-translation-lib/plugins" "k8s.io/klog/v2" @@ -47,7 +48,7 @@ type KubernetesVolumeID string // DiskInfo returns aws disk information in easy to use manner type diskInfo struct { - ec2Instance *ec2.Instance + ec2Instance *ec2types.Instance nodeName types.NodeName volumeState string attachmentState string @@ -71,7 +72,7 @@ func GetAWSVolumeID(kubeVolumeID string) (string, error) { return string(awsID), err } -func (c *Cloud) checkIfAttachedToNode(diskName KubernetesVolumeID, nodeName types.NodeName) (*diskInfo, bool, error) { +func (c *Cloud) checkIfAttachedToNode(ctx context.Context, diskName KubernetesVolumeID, nodeName types.NodeName) (*diskInfo, bool, error) { disk, err := newAWSDisk(c, diskName) if err != nil { @@ -82,7 +83,7 @@ func (c *Cloud) checkIfAttachedToNode(diskName KubernetesVolumeID, nodeName type disk: disk, } - info, err := disk.describeVolume() + info, err := disk.describeVolume(ctx) if err != nil { klog.Warningf("Error describing volume %s with %v", diskName, err) @@ -90,13 +91,13 @@ func (c *Cloud) checkIfAttachedToNode(diskName KubernetesVolumeID, nodeName type return awsDiskInfo, false, err } - awsDiskInfo.volumeState = aws.StringValue(info.State) + awsDiskInfo.volumeState = string(info.State) if len(info.Attachments) > 0 { attachment := info.Attachments[0] - awsDiskInfo.attachmentState = aws.StringValue(attachment.State) - instanceID := aws.StringValue(attachment.InstanceId) - instanceInfo, err := c.getInstanceByID(instanceID) + awsDiskInfo.attachmentState = string(attachment.State) + instanceID := *attachment.InstanceId + instanceInfo, err := c.getInstanceByID(ctx, instanceID) // This should never happen but if it does it could mean there was a race and instance // has been deleted diff --git a/pkg/services/aws_sts.go b/pkg/services/aws_sts.go new file mode 100644 index 0000000000..f62cf01e90 --- /dev/null +++ b/pkg/services/aws_sts.go @@ -0,0 +1,97 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package services + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "k8s.io/klog/v2" +) + +const headerSourceArn = "x-amz-source-arn" +const headerSourceAccount = "x-amz-source-account" + +type withStsHeadersMiddleware struct { + headers map[string]string +} + +func (*withStsHeadersMiddleware) ID() string { + return "withStsHeadersMiddleware" +} + +func (m *withStsHeadersMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unrecognized transport type %T", in.Request) + } + + for k, v := range m.headers { + req.Header.Set(k, v) + } + return next.HandleBuild(ctx, in) +} + +// WithStsHeadersMiddleware provides middleware to set custom headers for STS calls +func WithStsHeadersMiddleware(headers map[string]string) func(*sts.Options) { + return func(o *sts.Options) { + o.APIOptions = append(o.APIOptions, func(s *middleware.Stack) error { + return s.Build.Add(&withStsHeadersMiddleware{ + headers: headers, + }, middleware.After) + }) + } +} + +// NewStsClient provides a new STS client. +func NewStsClient(ctx context.Context, region, roleARN, sourceARN string) (*sts.Client, error) { + klog.Infof("Using AWS assumed role %v", roleARN) + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return nil, err + } + + parsedSourceArn, err := arn.Parse(roleARN) + if err != nil { + return nil, err + } + + sourceAcct := parsedSourceArn.AccountID + + reqHeaders := map[string]string{ + headerSourceAccount: sourceAcct, + } + if sourceARN != "" { + reqHeaders[headerSourceArn] = sourceARN + } + + // Create the STS client with the custom middleware + // svc := s3.NewFromConfig(cfg, WithHeader("x-user-header", "...")) + stsClient := sts.NewFromConfig(cfg, func(o *sts.Options) { + o.Region = region + }, WithStsHeadersMiddleware(reqHeaders)) + + klog.V(4).Infof("configuring STS client with extra headers, %v", reqHeaders) + return stsClient, nil +} diff --git a/version.txt b/version.txt index 0ca398d6d6..28d8420a18 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -1.28.11 +1.28.12