Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 15 additions & 37 deletions pkg/cloud/services/network/subnets.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,16 @@ func (s *Service) reconcileSubnets() error {
existing infrav1.Subnets
)

// Describing the VPC Subnets tags the resources.
if s.scope.TagUnmanagedNetworkResources() {
// Describe subnets in the vpc.
if existing, err = s.describeVpcSubnets(); err != nil {
return err
}
}

unmanagedVPC := s.scope.VPC().IsUnmanaged(s.scope.Name())

if len(subnets) == 0 {
if unmanagedVPC {
// If we have a unmanaged VPC then subnets must be specified
errMsg := "no subnets specified, you must specify the subnets when using an umanaged vpc"
record.Warnf(s.scope.InfraCluster(), "FailedNoSubnets", errMsg)
return errors.New(errMsg)
}

// If we a managed VPC and have no subnets then create subnets. There will be 1 public and 1 private subnet
// for each az in a region up to a maximum of 3 azs
// If we have a unmanaged VPC then subnets must be specified
if unmanagedVPC && len(subnets) == 0 {
errMsg := "no subnets specified, you must specify the subnets when using an umanaged vpc"
record.Warnf(s.scope.InfraCluster(), "FailedNoSubnets", errMsg)
return errors.New(errMsg)
} else if len(subnets) == 0 {
// If we have a managed VPC and have no subnets then create some default subnets.
// There will be 1 public and 1 private subnet for each az in a region up to a maximum of 3 azs.
s.scope.Info("no subnets specified, setting defaults")

subnets, err = s.getDefaultSubnets()
Expand All @@ -93,12 +83,9 @@ func (s *Service) reconcileSubnets() error {
}
}

// Describing the VPC Subnets tags the resources.
if !s.scope.TagUnmanagedNetworkResources() {
// Describe subnets in the vpc.
if existing, err = s.describeVpcSubnets(); err != nil {
return err
}
// Describe subnets in the vpc.
if existing, err = s.describeVpcSubnets(); err != nil {
return err
}

if s.scope.SecondaryCidrBlock() != nil {
Expand Down Expand Up @@ -132,13 +119,10 @@ func (s *Service) reconcileSubnets() error {
sub := &subnets[i]
existingSubnet := existing.FindEqual(sub)
if existingSubnet != nil {
if len(sub.ID) > 0 {
// NOTE: Describing subnets assumes the subnet.ID is the same as the subnet's identifier (i.e. subnet-<xyz>),
// if we have a subnet ID specified in the spec, we need to restore it.
existingSubnet.ID = sub.ID
}

// Update subnet spec with the existing subnet details
// Update subnet spec with the existing subnet details, but we want to keep the subnet ID and tags defined in the spec.
// We don't want to mess with tags that exist only on AWS.
existingSubnet.ID = sub.ID
existingSubnet.Tags = sub.Tags
existingSubnet.DeepCopyInto(sub)

// Make sure tags are up-to-date.
Expand Down Expand Up @@ -168,12 +152,6 @@ func (s *Service) reconcileSubnets() error {
}
}

// If we have an unmanaged VPC, require that the user has specified at least 1 subnet.
if unmanagedVPC && len(subnets) < 1 {
record.Warnf(s.scope.InfraCluster(), "FailedNoSubnet", "Expected at least 1 subnet but got 0")
return errors.New("expected at least 1 subnet but got 0")
}

// Reconciling the zone information for the subnets. Subnets are grouped
// by regular zones (availability zones) or edge zones (local zones or wavelength zones)
// based in the zone-type attribute for zone.
Expand Down
86 changes: 12 additions & 74 deletions pkg/cloud/services/network/subnets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import (
"context"
"encoding/json"
"fmt"
"reflect"
"slices"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -63,10 +63,7 @@
{ID: "subnet-private-us-east-1-wl1-nyc-wlz-1", AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", CidrBlock: "10.0.7.0/24", IsPublic: false},
{ID: "subnet-public-us-east-1-wl1-nyc-wlz-1", AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", CidrBlock: "10.0.8.0/24", IsPublic: true},
}
// TODO(mtulio): replace by slices.Concat(...) on go 1.22+
stubSubnetsAllZones := stubSubnetsAvailabilityZone
stubSubnetsAllZones = append(stubSubnetsAllZones, stubSubnetsLocalZone...)
stubSubnetsAllZones = append(stubSubnetsAllZones, stubSubnetsWavelengthZone...)
stubSubnetsAllZones := slices.Concat(stubSubnetsAvailabilityZone, stubSubnetsLocalZone, stubSubnetsWavelengthZone)

Check failure on line 66 in pkg/cloud/services/network/subnets_test.go

View workflow job for this annotation

GitHub Actions / lint

undefined: slices.Concat (typecheck)

Check failure on line 66 in pkg/cloud/services/network/subnets_test.go

View workflow job for this annotation

GitHub Actions / lint

undefined: slices.Concat (typecheck)

Check failure on line 66 in pkg/cloud/services/network/subnets_test.go

View workflow job for this annotation

GitHub Actions / lint

undefined: slices.Concat (typecheck)

// NetworkSpec with subnets in zone type availability-zone
stubNetworkSpecWithSubnets := &infrav1.NetworkSpec{
Expand Down Expand Up @@ -678,7 +675,6 @@
AvailabilityZone: "us-east-1a",
CidrBlock: "10.0.10.0/24",
IsPublic: true,
Tags: infrav1.Tags{},
},
},
expect: func(m *mocks.MockEC2APIMockRecorder) {
Expand All @@ -702,6 +698,12 @@
AvailabilityZone: aws.String("us-east-1a"),
CidrBlock: aws.String("10.0.10.0/24"),
MapPublicIpOnLaunch: aws.Bool(false),
Tags: []*ec2.Tag{
{
Key: aws.String("company-policy"),
Value: aws.String("enabled"),
},
},
},
},
}, nil)
Expand Down Expand Up @@ -784,15 +786,13 @@
AvailabilityZone: "us-east-1a",
CidrBlock: "10.0.10.0/24",
IsPublic: true,
Tags: infrav1.Tags{},
},
{
ID: "subnet-2",
ResourceID: "subnet-2",
AvailabilityZone: "us-east-1b",
CidrBlock: "10.0.11.0/24",
IsPublic: true,
Tags: infrav1.Tags{},
},
},
expect: func(m *mocks.MockEC2APIMockRecorder) {
Expand Down Expand Up @@ -1057,55 +1057,7 @@
},
Subnets: []infrav1.SubnetSpec{},
}).WithTagUnmanagedNetworkResources(true),
expect: func(m *mocks.MockEC2APIMockRecorder) {
m.DescribeSubnetsWithContext(context.TODO(), gomock.Eq(&ec2.DescribeSubnetsInput{
Filters: []*ec2.Filter{
{
Name: aws.String("state"),
Values: []*string{aws.String("pending"), aws.String("available")},
},
{
Name: aws.String("vpc-id"),
Values: []*string{aws.String(subnetsVPCID)},
},
},
})).
Return(&ec2.DescribeSubnetsOutput{
Subnets: []*ec2.Subnet{
{
VpcId: aws.String(subnetsVPCID),
SubnetId: aws.String("subnet-1"),
AvailabilityZone: aws.String("us-east-1a"),
CidrBlock: aws.String("10.0.10.0/24"),
MapPublicIpOnLaunch: aws.Bool(false),
},
{
VpcId: aws.String(subnetsVPCID),
SubnetId: aws.String("subnet-2"),
AvailabilityZone: aws.String("us-east-1a"),
CidrBlock: aws.String("10.0.20.0/24"),
MapPublicIpOnLaunch: aws.Bool(false),
},
},
}, nil)
m.DescribeRouteTablesWithContext(context.TODO(), gomock.AssignableToTypeOf(&ec2.DescribeRouteTablesInput{})).
Return(&ec2.DescribeRouteTablesOutput{}, nil)

m.DescribeNatGatewaysPagesWithContext(context.TODO(),
gomock.Eq(&ec2.DescribeNatGatewaysInput{
Filter: []*ec2.Filter{
{
Name: aws.String("vpc-id"),
Values: []*string{aws.String(subnetsVPCID)},
},
{
Name: aws.String("state"),
Values: []*string{aws.String("pending"), aws.String("available")},
},
},
}),
gomock.Any()).Return(nil)
},
expect: func(m *mocks.MockEC2APIMockRecorder) {},
errorExpected: true,
tagUnmanagedNetworkResources: true,
},
Expand Down Expand Up @@ -3261,10 +3213,7 @@
CidrBlock: "10.0.10.0/24",
IsPublic: true,
RouteTableID: aws.String("rtb-1"),
Tags: infrav1.Tags{
"Name": "provided-subnet-public",
},
ZoneType: ptr.To[infrav1.ZoneType]("availability-zone"),
ZoneType: ptr.To[infrav1.ZoneType]("availability-zone"),
},
{
ID: "subnet-2",
Expand All @@ -3273,10 +3222,7 @@
CidrBlock: "10.0.11.0/24",
IsPublic: false,
RouteTableID: aws.String("rtb-2"),
Tags: infrav1.Tags{
"Name": "provided-subnet-private",
},
ZoneType: ptr.To[infrav1.ZoneType]("availability-zone"),
ZoneType: ptr.To[infrav1.ZoneType]("availability-zone"),
},
},
},
Expand Down Expand Up @@ -3328,15 +3274,7 @@
}

if !cmp.Equal(sn, exp) {
expected, err := json.MarshalIndent(exp, "", "\t")
if err != nil {
t.Fatalf("got an unexpected error: %v", err)
}
actual, err := json.MarshalIndent(sn, "", "\t")
if err != nil {
t.Fatalf("got an unexpected error: %v", err)
}
t.Errorf("Expected %s, got %s", string(expected), string(actual))
t.Errorf("Expected subnets to be equal. Diff %s", cmp.Diff(sn, exp))
}
delete(out, exp.ID)
}
Expand Down
Loading