Skip to content

Commit a2760be

Browse files
authored
Merge pull request #592 from spectrocloud/NSGUpdateIssue
Update NSG only if default rules are not present, or else skip the update
2 parents 2f3e9dd + 5d80ea5 commit a2760be

File tree

3 files changed

+117
-49
lines changed

3 files changed

+117
-49
lines changed

cloud/services/securitygroups/client.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package securitygroups
1818

1919
import (
2020
"context"
21-
2221
"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2019-06-01/network"
2322
"github.com/Azure/go-autorest/autorest"
2423
azure "sigs.k8s.io/cluster-api-provider-azure/cloud"
@@ -59,10 +58,25 @@ func (ac *AzureClient) Get(ctx context.Context, resourceGroupName, sgName string
5958

6059
// CreateOrUpdate creates or updates a network security group in the specified resource group.
6160
func (ac *AzureClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, sgName string, sg network.SecurityGroup) error {
62-
future, err := ac.securitygroups.CreateOrUpdate(ctx, resourceGroupName, sgName, sg)
61+
var etag string
62+
if sg.Etag != nil {
63+
etag = *sg.Etag
64+
}
65+
req, err := ac.securitygroups.CreateOrUpdatePreparer(ctx, resourceGroupName, sgName, sg)
66+
if err != nil {
67+
err = autorest.NewErrorWithError(err, "network.SecurityGroupsClient", "CreateOrUpdate", nil, "Failure preparing request")
68+
return err
69+
}
70+
if etag != "" {
71+
req.Header.Add("If-Match", etag)
72+
}
73+
74+
future, err := ac.securitygroups.CreateOrUpdateSender(req)
6375
if err != nil {
76+
err = autorest.NewErrorWithError(err, "network.SecurityGroupsClient", "CreateOrUpdate", future.Response(), "Failure sending request")
6477
return err
6578
}
79+
6680
err = future.WaitForCompletionRef(ctx, ac.securitygroups.Client)
6781
if err != nil {
6882
return err

cloud/services/securitygroups/securitygroups.go

Lines changed: 92 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package securitygroups
1919
import (
2020
"context"
2121
"strconv"
22+
"strings"
2223

2324
"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2019-06-01/network"
2425
"github.com/Azure/go-autorest/autorest/to"
@@ -27,6 +28,11 @@ import (
2728
azure "sigs.k8s.io/cluster-api-provider-azure/cloud"
2829
)
2930

31+
const (
32+
apiServerRule = "apiServerRule"
33+
sshRule = "sshRule"
34+
)
35+
3036
// Spec specification for network security groups
3137
type Spec struct {
3238
Name string
@@ -59,52 +65,59 @@ func (s *Service) Reconcile(ctx context.Context, spec interface{}) error {
5965
return errors.New("invalid security groups specification")
6066
}
6167

62-
securityRules := &[]network.SecurityRule{}
68+
securityGroup, err := s.Client.Get(ctx, s.Scope.ResourceGroup(), nsgSpec.Name)
69+
if err != nil && !azure.ResourceNotFound(err) {
70+
return errors.Wrapf(err, "failed to get NSG %s in %s", nsgSpec.Name, s.Scope.ResourceGroup())
71+
}
72+
73+
nsgExists := false
74+
securityRules := make([]network.SecurityRule, 0)
75+
if securityGroup.Name != nil {
76+
nsgExists = true
77+
securityRules = *securityGroup.SecurityRules
78+
}
79+
80+
defaultRules := make(map[string]network.SecurityRule, 0)
81+
defaultRules[sshRule] = getRule("allow_ssh", "22", 100)
82+
defaultRules[apiServerRule] = getRule("allow_6443", strconv.Itoa(int(s.Scope.APIServerPort())), 101)
6383

6484
if nsgSpec.IsControlPlane {
65-
klog.V(2).Infof("using additional rules for control plane %s", nsgSpec.Name)
66-
securityRules = &[]network.SecurityRule{
67-
{
68-
Name: to.StringPtr("allow_ssh"),
69-
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
70-
Protocol: network.SecurityRuleProtocolTCP,
71-
SourceAddressPrefix: to.StringPtr("*"),
72-
SourcePortRange: to.StringPtr("*"),
73-
DestinationAddressPrefix: to.StringPtr("*"),
74-
DestinationPortRange: to.StringPtr("22"),
75-
Access: network.SecurityRuleAccessAllow,
76-
Direction: network.SecurityRuleDirectionInbound,
77-
Priority: to.Int32Ptr(100),
78-
},
79-
},
80-
{
81-
Name: to.StringPtr("allow_6443"),
82-
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
83-
Protocol: network.SecurityRuleProtocolTCP,
84-
SourceAddressPrefix: to.StringPtr("*"),
85-
SourcePortRange: to.StringPtr("*"),
86-
DestinationAddressPrefix: to.StringPtr("*"),
87-
DestinationPortRange: to.StringPtr(strconv.Itoa(int(s.Scope.APIServerPort()))),
88-
Access: network.SecurityRuleAccessAllow,
89-
Direction: network.SecurityRuleDirectionInbound,
90-
Priority: to.Int32Ptr(101),
91-
},
92-
},
85+
if nsgExists {
86+
// Check if the expected rules are present
87+
update := false
88+
for _, rule := range defaultRules {
89+
if !ruleExists(securityRules, rule) {
90+
update = true
91+
securityRules = append(securityRules, rule)
92+
}
93+
}
94+
if !update {
95+
// Skip update for control-plane NSG as the required default rules are present
96+
klog.V(2).Infof("security group %s exists and no default rules are missing, skipping update", nsgSpec.Name)
97+
return nil
98+
}
99+
} else {
100+
klog.V(2).Infof("applying missing default rules for control plane NSG %s", nsgSpec.Name)
101+
securityRules = append(securityRules, defaultRules[sshRule], defaultRules[apiServerRule])
93102
}
103+
} else if nsgExists {
104+
// Skip update for node NSG as no default rules are required
105+
klog.V(2).Infof("security group %s exists and no default rules are required, skipping update", nsgSpec.Name)
106+
return nil
94107
}
95108

96-
klog.V(2).Infof("creating security group %s", nsgSpec.Name)
97-
err := s.Client.CreateOrUpdate(
98-
ctx,
99-
s.Scope.ResourceGroup(),
100-
nsgSpec.Name,
101-
network.SecurityGroup{
102-
Location: to.StringPtr(s.Scope.Location()),
103-
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{
104-
SecurityRules: securityRules,
105-
},
109+
sg := network.SecurityGroup{
110+
Location: to.StringPtr(s.Scope.Location()),
111+
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{
112+
SecurityRules: &securityRules,
106113
},
107-
)
114+
}
115+
if nsgExists {
116+
// We append the existing NSG etag to the header to ensure we only apply the updates if the NSG has not been modified.
117+
sg.Etag = securityGroup.Etag
118+
}
119+
klog.V(2).Infof("creating security group %s", nsgSpec.Name)
120+
err = s.Client.CreateOrUpdate(ctx, s.Scope.ResourceGroup(), nsgSpec.Name, sg)
108121
if err != nil {
109122
return errors.Wrapf(err, "failed to create security group %s in resource group %s", nsgSpec.Name, s.Scope.ResourceGroup())
110123
}
@@ -113,6 +126,45 @@ func (s *Service) Reconcile(ctx context.Context, spec interface{}) error {
113126
return err
114127
}
115128

129+
func ruleExists(rules []network.SecurityRule, rule network.SecurityRule) bool {
130+
for _, existingRule := range rules {
131+
if !strings.EqualFold(to.String(existingRule.Name), to.String(rule.Name)) {
132+
continue
133+
}
134+
if !strings.EqualFold(to.String(existingRule.DestinationPortRange), to.String(rule.DestinationPortRange)) {
135+
continue
136+
}
137+
if existingRule.Protocol != network.SecurityRuleProtocolTCP &&
138+
existingRule.Access != network.SecurityRuleAccessAllow &&
139+
existingRule.Direction != network.SecurityRuleDirectionInbound {
140+
continue
141+
}
142+
if !strings.EqualFold(to.String(existingRule.SourcePortRange), "*") &&
143+
!strings.EqualFold(to.String(existingRule.SourceAddressPrefix), "*") &&
144+
!strings.EqualFold(to.String(existingRule.DestinationAddressPrefix), "*") {
145+
continue
146+
}
147+
return true
148+
}
149+
return false
150+
}
151+
152+
func getRule(name, destinationPort string, priority int32) network.SecurityRule {
153+
return network.SecurityRule{
154+
Name: to.StringPtr(name),
155+
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
156+
Protocol: network.SecurityRuleProtocolTCP,
157+
SourceAddressPrefix: to.StringPtr("*"),
158+
SourcePortRange: to.StringPtr("*"),
159+
DestinationAddressPrefix: to.StringPtr("*"),
160+
DestinationPortRange: to.StringPtr(destinationPort),
161+
Access: network.SecurityRuleAccessAllow,
162+
Direction: network.SecurityRuleDirectionInbound,
163+
Priority: to.Int32Ptr(priority),
164+
},
165+
}
166+
}
167+
116168
// Delete deletes the network security group with the provided name.
117169
func (s *Service) Delete(ctx context.Context, spec interface{}) error {
118170
nsgSpec, ok := spec.(*Spec)

cloud/services/securitygroups/securitygroups_test.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,30 +52,32 @@ func TestReconcileSecurityGroups(t *testing.T) {
5252
sgName string
5353
isControlPlane bool
5454
vnetSpec *infrav1.VnetSpec
55-
expect func(m *mock_securitygroups.MockClientMockRecorder)
55+
expect func(m *mock_securitygroups.MockClientMockRecorder, m1 *mock_securitygroups.MockClientMockRecorder)
5656
}{
5757
{
5858
name: "security group does not exists",
5959
sgName: "my-sg",
6060
isControlPlane: true,
6161
vnetSpec: &infrav1.VnetSpec{},
62-
expect: func(m *mock_securitygroups.MockClientMockRecorder) {
63-
m.CreateOrUpdate(context.TODO(), "my-rg", "my-sg", gomock.AssignableToTypeOf(network.SecurityGroup{}))
62+
expect: func(m *mock_securitygroups.MockClientMockRecorder, m1 *mock_securitygroups.MockClientMockRecorder) {
63+
m.Get(context.TODO(), "my-rg", "my-sg")
64+
m1.CreateOrUpdate(context.TODO(), "my-rg", "my-sg", gomock.AssignableToTypeOf(network.SecurityGroup{}))
6465
},
6566
}, {
6667
name: "security group does not exist and it's not for a control plane",
6768
sgName: "my-sg",
6869
isControlPlane: false,
6970
vnetSpec: &infrav1.VnetSpec{},
70-
expect: func(m *mock_securitygroups.MockClientMockRecorder) {
71-
m.CreateOrUpdate(context.TODO(), "my-rg", "my-sg", gomock.AssignableToTypeOf(network.SecurityGroup{}))
71+
expect: func(m *mock_securitygroups.MockClientMockRecorder, m1 *mock_securitygroups.MockClientMockRecorder) {
72+
m.Get(context.TODO(), "my-rg", "my-sg")
73+
m1.CreateOrUpdate(context.TODO(), "my-rg", "my-sg", gomock.AssignableToTypeOf(network.SecurityGroup{}))
7274
},
7375
}, {
7476
name: "skipping network security group reconcile in custom vnet mode",
7577
sgName: "my-sg",
7678
isControlPlane: false,
7779
vnetSpec: &infrav1.VnetSpec{ResourceGroup: "custom-vnet-rg", Name: "custom-vnet", ID: "id1"},
78-
expect: func(m *mock_securitygroups.MockClientMockRecorder) {
80+
expect: func(m *mock_securitygroups.MockClientMockRecorder, m1 *mock_securitygroups.MockClientMockRecorder) {
7981

8082
},
8183
},
@@ -91,7 +93,7 @@ func TestReconcileSecurityGroups(t *testing.T) {
9193

9294
client := fake.NewFakeClient(cluster)
9395

94-
tc.expect(sgMock.EXPECT())
96+
tc.expect(sgMock.EXPECT(), sgMock.EXPECT())
9597

9698
clusterScope, err := scope.NewClusterScope(scope.ClusterScopeParams{
9799
AzureClients: scope.AzureClients{

0 commit comments

Comments
 (0)