diff --git a/pkg/cloud/services/network/subnets.go b/pkg/cloud/services/network/subnets.go index f6406bd833..7cb259669f 100644 --- a/pkg/cloud/services/network/subnets.go +++ b/pkg/cloud/services/network/subnets.go @@ -58,26 +58,16 @@ func (s *Service) reconcileSubnets() error { existing infrav1.Subnets ) - // Describing the VPC Subnets tags the resources. - if s.scope.TagUnmanagedNetworkResources() { - // Describe subnets in the vpc. - if existing, err = s.describeVpcSubnets(); err != nil { - return err - } - } - unmanagedVPC := s.scope.VPC().IsUnmanaged(s.scope.Name()) - if len(subnets) == 0 { - if unmanagedVPC { - // If we have a unmanaged VPC then subnets must be specified - errMsg := "no subnets specified, you must specify the subnets when using an umanaged vpc" - record.Warnf(s.scope.InfraCluster(), "FailedNoSubnets", errMsg) - return errors.New(errMsg) - } - - // If we a managed VPC and have no subnets then create subnets. There will be 1 public and 1 private subnet - // for each az in a region up to a maximum of 3 azs + // If we have a unmanaged VPC then subnets must be specified + if unmanagedVPC && len(subnets) == 0 { + errMsg := "no subnets specified, you must specify the subnets when using an umanaged vpc" + record.Warnf(s.scope.InfraCluster(), "FailedNoSubnets", errMsg) + return errors.New(errMsg) + } else if len(subnets) == 0 { + // If we have a managed VPC and have no subnets then create some default subnets. + // There will be 1 public and 1 private subnet for each az in a region up to a maximum of 3 azs. s.scope.Info("no subnets specified, setting defaults") subnets, err = s.getDefaultSubnets() @@ -93,12 +83,9 @@ func (s *Service) reconcileSubnets() error { } } - // Describing the VPC Subnets tags the resources. - if !s.scope.TagUnmanagedNetworkResources() { - // Describe subnets in the vpc. - if existing, err = s.describeVpcSubnets(); err != nil { - return err - } + // Describe subnets in the vpc. + if existing, err = s.describeVpcSubnets(); err != nil { + return err } if s.scope.SecondaryCidrBlock() != nil { @@ -132,13 +119,10 @@ func (s *Service) reconcileSubnets() error { sub := &subnets[i] existingSubnet := existing.FindEqual(sub) if existingSubnet != nil { - if len(sub.ID) > 0 { - // NOTE: Describing subnets assumes the subnet.ID is the same as the subnet's identifier (i.e. subnet-), - // if we have a subnet ID specified in the spec, we need to restore it. - existingSubnet.ID = sub.ID - } - - // Update subnet spec with the existing subnet details + // Update subnet spec with the existing subnet details, but we want to keep the subnet ID and tags defined in the spec. + // We don't want to mess with tags that exist only on AWS. + existingSubnet.ID = sub.ID + existingSubnet.Tags = sub.Tags existingSubnet.DeepCopyInto(sub) // Make sure tags are up-to-date. @@ -168,12 +152,6 @@ func (s *Service) reconcileSubnets() error { } } - // If we have an unmanaged VPC, require that the user has specified at least 1 subnet. - if unmanagedVPC && len(subnets) < 1 { - record.Warnf(s.scope.InfraCluster(), "FailedNoSubnet", "Expected at least 1 subnet but got 0") - return errors.New("expected at least 1 subnet but got 0") - } - // Reconciling the zone information for the subnets. Subnets are grouped // by regular zones (availability zones) or edge zones (local zones or wavelength zones) // based in the zone-type attribute for zone. diff --git a/pkg/cloud/services/network/subnets_test.go b/pkg/cloud/services/network/subnets_test.go index 6daa99c9ca..23d4ea091a 100644 --- a/pkg/cloud/services/network/subnets_test.go +++ b/pkg/cloud/services/network/subnets_test.go @@ -18,9 +18,9 @@ package network import ( "context" - "encoding/json" "fmt" "reflect" + "slices" "testing" "github.com/aws/aws-sdk-go/aws" @@ -63,10 +63,7 @@ func TestReconcileSubnets(t *testing.T) { {ID: "subnet-private-us-east-1-wl1-nyc-wlz-1", AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", CidrBlock: "10.0.7.0/24", IsPublic: false}, {ID: "subnet-public-us-east-1-wl1-nyc-wlz-1", AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", CidrBlock: "10.0.8.0/24", IsPublic: true}, } - // TODO(mtulio): replace by slices.Concat(...) on go 1.22+ - stubSubnetsAllZones := stubSubnetsAvailabilityZone - stubSubnetsAllZones = append(stubSubnetsAllZones, stubSubnetsLocalZone...) - stubSubnetsAllZones = append(stubSubnetsAllZones, stubSubnetsWavelengthZone...) + stubSubnetsAllZones := slices.Concat(stubSubnetsAvailabilityZone, stubSubnetsLocalZone, stubSubnetsWavelengthZone) // NetworkSpec with subnets in zone type availability-zone stubNetworkSpecWithSubnets := &infrav1.NetworkSpec{ @@ -678,7 +675,6 @@ func TestReconcileSubnets(t *testing.T) { AvailabilityZone: "us-east-1a", CidrBlock: "10.0.10.0/24", IsPublic: true, - Tags: infrav1.Tags{}, }, }, expect: func(m *mocks.MockEC2APIMockRecorder) { @@ -702,6 +698,12 @@ func TestReconcileSubnets(t *testing.T) { AvailabilityZone: aws.String("us-east-1a"), CidrBlock: aws.String("10.0.10.0/24"), MapPublicIpOnLaunch: aws.Bool(false), + Tags: []*ec2.Tag{ + { + Key: aws.String("company-policy"), + Value: aws.String("enabled"), + }, + }, }, }, }, nil) @@ -784,7 +786,6 @@ func TestReconcileSubnets(t *testing.T) { AvailabilityZone: "us-east-1a", CidrBlock: "10.0.10.0/24", IsPublic: true, - Tags: infrav1.Tags{}, }, { ID: "subnet-2", @@ -792,7 +793,6 @@ func TestReconcileSubnets(t *testing.T) { AvailabilityZone: "us-east-1b", CidrBlock: "10.0.11.0/24", IsPublic: true, - Tags: infrav1.Tags{}, }, }, expect: func(m *mocks.MockEC2APIMockRecorder) { @@ -1057,55 +1057,7 @@ func TestReconcileSubnets(t *testing.T) { }, Subnets: []infrav1.SubnetSpec{}, }).WithTagUnmanagedNetworkResources(true), - expect: func(m *mocks.MockEC2APIMockRecorder) { - m.DescribeSubnetsWithContext(context.TODO(), gomock.Eq(&ec2.DescribeSubnetsInput{ - Filters: []*ec2.Filter{ - { - Name: aws.String("state"), - Values: []*string{aws.String("pending"), aws.String("available")}, - }, - { - Name: aws.String("vpc-id"), - Values: []*string{aws.String(subnetsVPCID)}, - }, - }, - })). - Return(&ec2.DescribeSubnetsOutput{ - Subnets: []*ec2.Subnet{ - { - VpcId: aws.String(subnetsVPCID), - SubnetId: aws.String("subnet-1"), - AvailabilityZone: aws.String("us-east-1a"), - CidrBlock: aws.String("10.0.10.0/24"), - MapPublicIpOnLaunch: aws.Bool(false), - }, - { - VpcId: aws.String(subnetsVPCID), - SubnetId: aws.String("subnet-2"), - AvailabilityZone: aws.String("us-east-1a"), - CidrBlock: aws.String("10.0.20.0/24"), - MapPublicIpOnLaunch: aws.Bool(false), - }, - }, - }, nil) - m.DescribeRouteTablesWithContext(context.TODO(), gomock.AssignableToTypeOf(&ec2.DescribeRouteTablesInput{})). - Return(&ec2.DescribeRouteTablesOutput{}, nil) - - m.DescribeNatGatewaysPagesWithContext(context.TODO(), - gomock.Eq(&ec2.DescribeNatGatewaysInput{ - Filter: []*ec2.Filter{ - { - Name: aws.String("vpc-id"), - Values: []*string{aws.String(subnetsVPCID)}, - }, - { - Name: aws.String("state"), - Values: []*string{aws.String("pending"), aws.String("available")}, - }, - }, - }), - gomock.Any()).Return(nil) - }, + expect: func(m *mocks.MockEC2APIMockRecorder) {}, errorExpected: true, tagUnmanagedNetworkResources: true, }, @@ -3261,10 +3213,7 @@ func TestDiscoverSubnets(t *testing.T) { CidrBlock: "10.0.10.0/24", IsPublic: true, RouteTableID: aws.String("rtb-1"), - Tags: infrav1.Tags{ - "Name": "provided-subnet-public", - }, - ZoneType: ptr.To[infrav1.ZoneType]("availability-zone"), + ZoneType: ptr.To[infrav1.ZoneType]("availability-zone"), }, { ID: "subnet-2", @@ -3273,10 +3222,7 @@ func TestDiscoverSubnets(t *testing.T) { CidrBlock: "10.0.11.0/24", IsPublic: false, RouteTableID: aws.String("rtb-2"), - Tags: infrav1.Tags{ - "Name": "provided-subnet-private", - }, - ZoneType: ptr.To[infrav1.ZoneType]("availability-zone"), + ZoneType: ptr.To[infrav1.ZoneType]("availability-zone"), }, }, }, @@ -3328,15 +3274,7 @@ func TestDiscoverSubnets(t *testing.T) { } if !cmp.Equal(sn, exp) { - expected, err := json.MarshalIndent(exp, "", "\t") - if err != nil { - t.Fatalf("got an unexpected error: %v", err) - } - actual, err := json.MarshalIndent(sn, "", "\t") - if err != nil { - t.Fatalf("got an unexpected error: %v", err) - } - t.Errorf("Expected %s, got %s", string(expected), string(actual)) + t.Errorf("Expected subnets to be equal. Diff %s", cmp.Diff(sn, exp)) } delete(out, exp.ID) }