diff --git a/pkg/cloud/services/network/egress_only_gateways.go b/pkg/cloud/services/network/egress_only_gateways.go index e710adecf7..1827957ac2 100644 --- a/pkg/cloud/services/network/egress_only_gateways.go +++ b/pkg/cloud/services/network/egress_only_gateways.go @@ -136,9 +136,12 @@ func (s *Service) createEgressOnlyInternetGateway() (*types.EgressOnlyInternetGa } func (s *Service) describeEgressOnlyVpcInternetGateways() ([]types.EgressOnlyInternetGateway, error) { + // The API for DescribeEgressOnlyInternetGateways does not support filtering by VPC ID attachment. + // More details: https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeEgressOnlyInternetGateways.html + // Since the eigw is managed by CAPA, we can filter by the kubernetes cluster tag. out, err := s.EC2Client.DescribeEgressOnlyInternetGateways(context.TODO(), &ec2.DescribeEgressOnlyInternetGatewaysInput{ Filters: []types.Filter{ - filter.EC2.VPCAttachment(s.scope.VPC().ID), + filter.EC2.Cluster(s.scope.Name()), }, }) if err != nil { @@ -146,11 +149,28 @@ func (s *Service) describeEgressOnlyVpcInternetGateways() ([]types.EgressOnlyInt return nil, errors.Wrapf(err, "failed to describe egress only internet gateways in vpc %q", s.scope.VPC().ID) } - if len(out.EgressOnlyInternetGateways) == 0 { + // For safeguarding, we collect only egress-only internet gateways + // that are attached to the VPC. + eigws := make([]types.EgressOnlyInternetGateway, 0) + for _, eigw := range out.EgressOnlyInternetGateways { + for _, attachment := range eigw.Attachments { + if aws.ToString(attachment.VpcId) == s.scope.VPC().ID { + eigws = append(eigws, eigw) + } + } + } + + if len(eigws) == 0 { return nil, awserrors.NewNotFound(fmt.Sprintf("no egress only internet gateways found in vpc %q", s.scope.VPC().ID)) + } else if len(eigws) > 1 { + eigwIDs := make([]string, len(eigws)) + for i, eigw := range eigws { + eigwIDs[i] = aws.ToString(eigw.EgressOnlyInternetGatewayId) + } + return nil, awserrors.NewConflict(fmt.Sprintf("expected 1 egress only internet gateway in vpc %q, but found %v: %v", s.scope.VPC().ID, len(eigws), eigwIDs)) } - return out.EgressOnlyInternetGateways, nil + return eigws, nil } func (s *Service) getEgressOnlyGatewayTagParams(id string) infrav1.BuildParams { diff --git a/pkg/cloud/services/network/egress_only_gateways_test.go b/pkg/cloud/services/network/egress_only_gateways_test.go index fbd859ab80..ff12058f12 100644 --- a/pkg/cloud/services/network/egress_only_gateways_test.go +++ b/pkg/cloud/services/network/egress_only_gateways_test.go @@ -40,9 +40,10 @@ func TestReconcileEgressOnlyInternetGateways(t *testing.T) { defer mockCtrl.Finish() testCases := []struct { - name string - input *infrav1.NetworkSpec - expect func(m *mocks.MockEC2APIMockRecorder) + name string + input *infrav1.NetworkSpec + expect func(m *mocks.MockEC2APIMockRecorder) + wantErrContaining *string }{ { name: "has eigw", @@ -75,6 +76,44 @@ func TestReconcileEgressOnlyInternetGateways(t *testing.T) { Return(nil, nil) }, }, + { + name: "has more than 1 eigw, should return error", + input: &infrav1.NetworkSpec{ + VPC: infrav1.VPCSpec{ + ID: "vpc-egress-only-gateways", + IPv6: &infrav1.IPv6{}, + Tags: infrav1.Tags{ + infrav1.ClusterTagKey("test-cluster"): "owned", + }, + }, + }, + wantErrContaining: aws.String("expected 1 egress only internet gateway in vpc \"vpc-egress-only-gateways\", but found 2: [eigw-0 eigw-1]"), + expect: func(m *mocks.MockEC2APIMockRecorder) { + m.DescribeEgressOnlyInternetGateways(context.TODO(), gomock.AssignableToTypeOf(&ec2.DescribeEgressOnlyInternetGatewaysInput{})). + Return(&ec2.DescribeEgressOnlyInternetGatewaysOutput{ + EgressOnlyInternetGateways: []types.EgressOnlyInternetGateway{ + { + EgressOnlyInternetGatewayId: aws.String("eigw-0"), + Attachments: []types.InternetGatewayAttachment{ + { + State: types.AttachmentStatusAttached, + VpcId: aws.String("vpc-egress-only-gateways"), + }, + }, + }, + { + EgressOnlyInternetGatewayId: aws.String("eigw-1"), + Attachments: []types.InternetGatewayAttachment{ + { + State: types.AttachmentStatusAttached, + VpcId: aws.String("vpc-egress-only-gateways"), + }, + }, + }, + }, + }, nil) + }, + }, { name: "no eigw attached, creates one", input: &infrav1.NetworkSpec{ @@ -122,10 +161,13 @@ func TestReconcileEgressOnlyInternetGateways(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) ec2Mock := mocks.NewMockEC2API(mockCtrl) scheme := runtime.NewScheme() - _ = infrav1.AddToScheme(scheme) + err := infrav1.AddToScheme(scheme) + g.Expect(err).NotTo(HaveOccurred()) + client := fake.NewClientBuilder().WithScheme(scheme).Build() scope, err := scope.NewClusterScope(scope.ClusterScopeParams{ Client: client, @@ -139,18 +181,20 @@ func TestReconcileEgressOnlyInternetGateways(t *testing.T) { }, }, }) - if err != nil { - t.Fatalf("Failed to create test context: %v", err) - } + g.Expect(err).NotTo(HaveOccurred()) tc.expect(ec2Mock.EXPECT()) s := NewService(scope) s.EC2Client = ec2Mock - if err := s.reconcileEgressOnlyInternetGateways(); err != nil { - t.Fatalf("got an unexpected error: %v", err) + err = s.reconcileEgressOnlyInternetGateways() + if tc.wantErrContaining != nil { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring(*tc.wantErrContaining)) + return } + g.Expect(err).NotTo(HaveOccurred()) }) } } @@ -199,8 +243,8 @@ func TestDeleteEgressOnlyInternetGateways(t *testing.T) { m.DescribeEgressOnlyInternetGateways(context.TODO(), gomock.Eq(&ec2.DescribeEgressOnlyInternetGatewaysInput{ Filters: []types.Filter{ { - Name: aws.String("attachment.vpc-id"), - Values: []string{"vpc-gateways"}, + Name: aws.String("tag-key"), + Values: []string{infrav1.ClusterTagKey("test-cluster")}, }, }, })).Return(&ec2.DescribeEgressOnlyInternetGatewaysOutput{}, nil)