Skip to content

Commit d05f2c1

Browse files
author
David Mather
committed
Allow multiple security group filter matches
1 parent ad99a1d commit d05f2c1

File tree

4 files changed

+76
-30
lines changed

4 files changed

+76
-30
lines changed

controllers/awsmachine_controller_unit_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ func TestAWSMachineReconciler(t *testing.T) {
914914
g.Expect(ms.AWSMachine.Finalizers).To(ContainElement(infrav1.MachineFinalizer))
915915
expectConditions(g, ms.AWSMachine, []conditionAssertion{{infrav1.SecurityGroupsReadyCondition, corev1.ConditionFalse, clusterv1.ConditionSeverityError, infrav1.SecurityGroupsFailedReason}})
916916
})
917-
t.Run("Should return silently if ensureSecurityGroups fails to fetch additional security groups", func(t *testing.T) {
917+
t.Run("Should fail if ensureSecurityGroups fails to fetch additional security groups", func(t *testing.T) {
918918
g := NewWithT(t)
919919
awsMachine := getAWSMachine()
920920
setup(t, g, awsMachine)
@@ -940,9 +940,9 @@ func TestAWSMachineReconciler(t *testing.T) {
940940
ec2Svc.EXPECT().GetAdditionalSecurityGroupsIDs(gomock.Any()).Return([]string{"sg-1"}, errors.New("failed to get filtered SGs"))
941941

942942
_, err := reconciler.reconcileNormal(context.Background(), ms, cs, cs, cs, cs)
943-
g.Expect(err).To(BeNil())
943+
g.Expect(err).ToNot(BeNil())
944944
g.Expect(ms.AWSMachine.Finalizers).To(ContainElement(infrav1.MachineFinalizer))
945-
expectConditions(g, ms.AWSMachine, []conditionAssertion{{infrav1.SecurityGroupsReadyCondition, corev1.ConditionTrue, "", ""}})
945+
expectConditions(g, ms.AWSMachine, []conditionAssertion{{infrav1.SecurityGroupsReadyCondition, corev1.ConditionFalse, clusterv1.ConditionSeverityError, infrav1.SecurityGroupsFailedReason}})
946946
})
947947
t.Run("Should fail to update security group", func(t *testing.T) {
948948
g := NewWithT(t)

controllers/awsmachine_security_groups.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func (r *AWSMachineReconciler) ensureSecurityGroups(ec2svc service.EC2Interface,
5151

5252
additionalSecurityGroupsIDs, err := ec2svc.GetAdditionalSecurityGroupsIDs(additional)
5353
if err != nil {
54-
return false, nil //nolint:nilerr
54+
return false, err
5555
}
5656

5757
changed, ids := r.securityGroupsChanged(annotation, core, additionalSecurityGroupsIDs, existing)

pkg/cloud/services/ec2/instances_test.go

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3079,10 +3079,10 @@ func TestGetFilteredSecurityGroupID(t *testing.T) {
30793079
name string
30803080
securityGroup infrav1.AWSResourceReference
30813081
expect func(m *mocks.MockEC2APIMockRecorder)
3082-
check func(id string, err error)
3082+
check func(ids []string, err error)
30833083
}{
30843084
{
3085-
name: "successfully return security group id",
3085+
name: "successfully return single security group id",
30863086
securityGroup: infrav1.AWSResourceReference{
30873087
Filters: []infrav1.Filter{
30883088
{
@@ -3107,27 +3107,71 @@ func TestGetFilteredSecurityGroupID(t *testing.T) {
31073107
},
31083108
}, nil)
31093109
},
3110-
check: func(id string, err error) {
3110+
check: func(ids []string, err error) {
31113111
if err != nil {
31123112
t.Fatalf("did not expect error: %v", err)
31133113
}
31143114

3115-
if id != securityGroupID {
3116-
t.Fatalf("expected security group id %v but got: %v", securityGroupID, id)
3115+
if ids[0] != securityGroupID {
3116+
t.Fatalf("expected security group id %v but got: %v", securityGroupID, ids[0])
3117+
}
3118+
},
3119+
},
3120+
{
3121+
name: "allow returning multiple security groups",
3122+
securityGroup: infrav1.AWSResourceReference{
3123+
Filters: []infrav1.Filter{
3124+
{
3125+
Name: securityGroupFilterName, Values: securityGroupFilterValues,
3126+
},
3127+
},
3128+
},
3129+
expect: func(m *mocks.MockEC2APIMockRecorder) {
3130+
m.DescribeSecurityGroups(gomock.Eq(&ec2.DescribeSecurityGroupsInput{
3131+
Filters: []*ec2.Filter{
3132+
{
3133+
Name: aws.String(securityGroupFilterName),
3134+
Values: aws.StringSlice(securityGroupFilterValues),
3135+
},
3136+
},
3137+
})).Return(
3138+
&ec2.DescribeSecurityGroupsOutput{
3139+
SecurityGroups: []*ec2.SecurityGroup{
3140+
{
3141+
GroupId: aws.String(securityGroupID),
3142+
},
3143+
{
3144+
GroupId: aws.String(securityGroupID),
3145+
},
3146+
{
3147+
GroupId: aws.String(securityGroupID),
3148+
},
3149+
},
3150+
}, nil)
3151+
},
3152+
check: func(ids []string, err error) {
3153+
if err != nil {
3154+
t.Fatalf("did not expect error: %v", err)
3155+
}
3156+
3157+
for _, id := range ids {
3158+
if id != securityGroupID {
3159+
t.Fatalf("expected security group id %v but got: %v", securityGroupID, id)
3160+
}
31173161
}
31183162
},
31193163
},
31203164
{
31213165
name: "return early when filters are missing",
31223166
securityGroup: infrav1.AWSResourceReference{},
31233167
expect: func(m *mocks.MockEC2APIMockRecorder) {},
3124-
check: func(id string, err error) {
3168+
check: func(ids []string, err error) {
31253169
if err != nil {
31263170
t.Fatalf("did not expect error: %v", err)
31273171
}
31283172

3129-
if id != "" {
3130-
t.Fatalf("didn't expect secutity group id %v", id)
3173+
if len(ids) > 0 {
3174+
t.Fatalf("didn't expect security group ids %v", ids)
31313175
}
31323176
},
31333177
},
@@ -3150,14 +3194,14 @@ func TestGetFilteredSecurityGroupID(t *testing.T) {
31503194
},
31513195
})).Return(nil, errors.New("some error"))
31523196
},
3153-
check: func(id string, err error) {
3197+
check: func(_ []string, err error) {
31543198
if err == nil {
31553199
t.Fatalf("expected error but got none.")
31563200
}
31573201
},
31583202
},
31593203
{
3160-
name: "error when no security groups found",
3204+
name: "no error when no security groups found",
31613205
securityGroup: infrav1.AWSResourceReference{
31623206
Filters: []infrav1.Filter{
31633207
{
@@ -3178,9 +3222,12 @@ func TestGetFilteredSecurityGroupID(t *testing.T) {
31783222
SecurityGroups: []*ec2.SecurityGroup{},
31793223
}, nil)
31803224
},
3181-
check: func(id string, err error) {
3182-
if err == nil {
3183-
t.Fatalf("expected error but got none.")
3225+
check: func(ids []string, err error) {
3226+
if err != nil {
3227+
t.Fatalf("did not expect error: %v", err)
3228+
}
3229+
if len(ids) > 0 {
3230+
t.Fatalf("didn't expect security group ids %v", ids)
31843231
}
31853232
},
31863233
},
@@ -3195,8 +3242,8 @@ func TestGetFilteredSecurityGroupID(t *testing.T) {
31953242
EC2Client: ec2Mock,
31963243
}
31973244

3198-
id, err := s.getFilteredSecurityGroupID(tc.securityGroup)
3199-
tc.check(id, err)
3245+
ids, err := s.getFilteredSecurityGroupIDs(tc.securityGroup)
3246+
tc.check(ids, err)
32003247
})
32013248
}
32023249
}

pkg/cloud/services/ec2/launchtemplate.go

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package ec2
1919
import (
2020
"encoding/base64"
2121
"encoding/json"
22-
"fmt"
2322
"sort"
2423
"strconv"
2524
"strings"
@@ -772,12 +771,12 @@ func (s *Service) GetAdditionalSecurityGroupsIDs(securityGroups []infrav1.AWSRes
772771
if sg.ID != nil {
773772
additionalSecurityGroupsIDs = append(additionalSecurityGroupsIDs, *sg.ID)
774773
} else if sg.Filters != nil {
775-
id, err := s.getFilteredSecurityGroupID(sg)
774+
ids, err := s.getFilteredSecurityGroupIDs(sg)
776775
if err != nil {
777776
return nil, err
778777
}
779778

780-
additionalSecurityGroupsIDs = append(additionalSecurityGroupsIDs, id)
779+
additionalSecurityGroupsIDs = append(additionalSecurityGroupsIDs, ids...)
781780
}
782781
}
783782

@@ -822,10 +821,10 @@ func (s *Service) buildLaunchTemplateTagSpecificationRequest(scope scope.LaunchT
822821
return tagSpecifications
823822
}
824823

825-
// getFilteredSecurityGroupID get security group ID using filters.
826-
func (s *Service) getFilteredSecurityGroupID(securityGroup infrav1.AWSResourceReference) (string, error) {
824+
// getFilteredSecurityGroupIDs get security group IDs using filters.
825+
func (s *Service) getFilteredSecurityGroupIDs(securityGroup infrav1.AWSResourceReference) ([]string, error) {
827826
if securityGroup.Filters == nil {
828-
return "", nil
827+
return nil, nil
829828
}
830829

831830
filters := []*ec2.Filter{}
@@ -835,14 +834,14 @@ func (s *Service) getFilteredSecurityGroupID(securityGroup infrav1.AWSResourceRe
835834

836835
sgs, err := s.EC2Client.DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{Filters: filters})
837836
if err != nil {
838-
return "", err
837+
return nil, err
839838
}
840-
841-
if len(sgs.SecurityGroups) == 0 {
842-
return "", fmt.Errorf("failed to find security group matching filters: %q, reason: %w", filters, err)
839+
ids := make([]string, 0, len(sgs.SecurityGroups))
840+
for _, sg := range sgs.SecurityGroups {
841+
ids = append(ids, *sg.GroupId)
843842
}
844843

845-
return *sgs.SecurityGroups[0].GroupId, nil
844+
return ids, nil
846845
}
847846

848847
func getLaunchTemplateInstanceMarketOptionsRequest(spotMarketOptions *infrav1.SpotMarketOptions) *ec2.LaunchTemplateInstanceMarketOptionsRequest {

0 commit comments

Comments
 (0)