Skip to content

Commit 3e96bda

Browse files
QxBytessivakami
authored andcommitted
feat: add cns iptables reconciliation (#3885)
* add iptables reconciliation to cns * add test to check if test case flushes chain * rename chain variable to reflect its value * address linter * nolint logger usage for now * address lll linter * remove code from startup as reconcile nnc always runs on cns startup
1 parent 5b2cc72 commit 3e96bda

File tree

4 files changed

+441
-90
lines changed

4 files changed

+441
-90
lines changed

cns/fakes/iptablesfake.go

Lines changed: 101 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package fakes
22

33
import (
44
"errors"
5+
"fmt"
56
"strings"
67

78
"github.com/Azure/azure-container-networking/iptables"
@@ -11,10 +12,13 @@ var (
1112
errChainExists = errors.New("chain already exists")
1213
errChainNotFound = errors.New("chain not found")
1314
errRuleExists = errors.New("rule already exists")
15+
errRuleNotFound = errors.New("rule not found")
16+
errIndexBounds = errors.New("index out of bounds")
1417
)
1518

1619
type IPTablesMock struct {
17-
state map[string]map[string][]string
20+
state map[string]map[string][]string
21+
clearChainCallCount int
1822
}
1923

2024
func NewIPTablesMock() *IPTablesMock {
@@ -83,21 +87,110 @@ func (c *IPTablesMock) Exists(table, chain string, rulespec ...string) (bool, er
8387
func (c *IPTablesMock) Append(table, chain string, rulespec ...string) error {
8488
c.ensureTableExists(table)
8589

90+
chainRules := c.state[table][chain]
91+
return c.Insert(table, chain, len(chainRules)+1, rulespec...)
92+
}
93+
94+
func (c *IPTablesMock) Insert(table, chain string, pos int, rulespec ...string) error {
95+
c.ensureTableExists(table)
96+
8697
chainExists, _ := c.ChainExists(table, chain)
8798
if !chainExists {
8899
return errChainNotFound
89100
}
90101

91-
ruleExists, _ := c.Exists(table, chain, rulespec...)
92-
if ruleExists {
93-
return errRuleExists
102+
targetRule := strings.Join(rulespec, " ")
103+
chainRules := c.state[table][chain]
104+
105+
// convert 1-based position to 0-based index
106+
index := pos - 1
107+
if index < 0 {
108+
index = 0
109+
}
110+
111+
switch {
112+
case index == len(chainRules):
113+
c.state[table][chain] = append(chainRules, targetRule)
114+
case index > len(chainRules):
115+
return errIndexBounds
116+
default:
117+
c.state[table][chain] = append(chainRules[:index], append([]string{targetRule}, chainRules[index:]...)...)
94118
}
95119

96-
targetRule := strings.Join(rulespec, " ")
97-
c.state[table][chain] = append(c.state[table][chain], targetRule)
98120
return nil
99121
}
100122

101-
func (c *IPTablesMock) Insert(table, chain string, _ int, rulespec ...string) error {
102-
return c.Append(table, chain, rulespec...)
123+
func (c *IPTablesMock) List(table, chain string) ([]string, error) {
124+
c.ensureTableExists(table)
125+
126+
chainExists, _ := c.ChainExists(table, chain)
127+
if !chainExists {
128+
return nil, errChainNotFound
129+
}
130+
131+
chainRules := c.state[table][chain]
132+
// preallocate: 1 for chain header + number of rules
133+
result := make([]string, 0, 1+len(chainRules))
134+
135+
// for built-in chains, start with policy -P, otherwise start with definition -N
136+
builtins := []string{iptables.Input, iptables.Output, iptables.Prerouting, iptables.Postrouting, iptables.Forward}
137+
isBuiltIn := false
138+
for _, builtin := range builtins {
139+
if chain == builtin {
140+
isBuiltIn = true
141+
break
142+
}
143+
}
144+
145+
if isBuiltIn {
146+
result = append(result, fmt.Sprintf("-P %s ACCEPT", chain))
147+
} else {
148+
result = append(result, "-N "+chain)
149+
}
150+
151+
// iptables with -S always outputs the rules in -A format
152+
for _, rule := range chainRules {
153+
result = append(result, fmt.Sprintf("-A %s %s", chain, rule))
154+
}
155+
156+
return result, nil
157+
}
158+
159+
func (c *IPTablesMock) ClearChain(table, chain string) error {
160+
c.clearChainCallCount++
161+
c.ensureTableExists(table)
162+
163+
chainExists, _ := c.ChainExists(table, chain)
164+
if !chainExists {
165+
return errChainNotFound
166+
}
167+
168+
c.state[table][chain] = []string{}
169+
return nil
170+
}
171+
172+
func (c *IPTablesMock) Delete(table, chain string, rulespec ...string) error {
173+
c.ensureTableExists(table)
174+
175+
chainExists, _ := c.ChainExists(table, chain)
176+
if !chainExists {
177+
return errChainNotFound
178+
}
179+
180+
targetRule := strings.Join(rulespec, " ")
181+
chainRules := c.state[table][chain]
182+
183+
// delete first match
184+
for i, rule := range chainRules {
185+
if rule == targetRule {
186+
c.state[table][chain] = append(chainRules[:i], chainRules[i+1:]...)
187+
return nil
188+
}
189+
}
190+
191+
return errRuleNotFound
192+
}
193+
194+
func (c *IPTablesMock) ClearChainCallCount() int {
195+
return c.clearChainCallCount
103196
}

cns/restserver/internalapi_linux.go

Lines changed: 78 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"github.com/pkg/errors"
1515
)
1616

17-
const SWIFT = "SWIFT-POSTROUTING"
17+
const SWIFTPOSTROUTING = "SWIFT-POSTROUTING"
1818

1919
type IPtablesProvider struct{}
2020

@@ -37,32 +37,62 @@ func (service *HTTPRestService) programSNATRules(req *cns.CreateNetworkContainer
3737
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to create iptables interface : %v", err)
3838
}
3939

40-
chainExist, err := ipt.ChainExists(iptables.Nat, SWIFT)
40+
chainExist, err := ipt.ChainExists(iptables.Nat, SWIFTPOSTROUTING)
4141
if err != nil {
42-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT chain: %v", err)
42+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT-POSTROUTING chain: %v", err)
4343
}
4444
if !chainExist { // create and append chain if it doesn't exist
45-
logger.Printf("[Azure CNS] Creating SWIFT Chain ...")
46-
err = ipt.NewChain(iptables.Nat, SWIFT)
45+
logger.Printf("[Azure CNS] Creating SWIFT-POSTROUTING Chain ...")
46+
err = ipt.NewChain(iptables.Nat, SWIFTPOSTROUTING)
4747
if err != nil {
48-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to create SWIFT chain : " + err.Error()
49-
}
50-
logger.Printf("[Azure CNS] Append SWIFT Chain to POSTROUTING ...")
51-
err = ipt.Append(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
52-
if err != nil {
53-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append SWIFT chain : " + err.Error()
48+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to create SWIFT-POSTROUTING chain : " + err.Error()
5449
}
5550
}
5651

57-
postroutingToSwiftJumpexist, err := ipt.Exists(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
52+
// reconcile jump to SWIFT-POSTROUTING chain
53+
rules, err := ipt.List(iptables.Nat, iptables.Postrouting)
5854
if err != nil {
59-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of POSTROUTING to SWIFT chain jump: %v", err)
55+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check rules in postrouting chain of nat table: %v", err)
56+
}
57+
swiftRuleIndex := len(rules) // append if neither jump rule from POSTROUTING is found
58+
// one time migration from old SWIFT chain
59+
// previously, CNI may have a jump to the SWIFT chain-- our jump to SWIFT-POSTROUTING needs to happen first
60+
for index, rule := range rules {
61+
if rule == "-A POSTROUTING -j SWIFT" {
62+
// jump to SWIFT comes before jump to SWIFT-POSTROUTING, so potential reordering required
63+
swiftRuleIndex = index
64+
break
65+
}
66+
if rule == "-A POSTROUTING -j SWIFT-POSTROUTING" {
67+
// jump to SWIFT-POSTROUTING comes before jump to SWIFT, which requires no further action
68+
swiftRuleIndex = -1
69+
break
70+
}
6071
}
61-
if !postroutingToSwiftJumpexist {
62-
logger.Printf("[Azure CNS] Append SWIFT Chain to POSTROUTING ...")
63-
err = ipt.Append(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
72+
if swiftRuleIndex != -1 {
73+
// jump SWIFT rule exists, insert SWIFT-POSTROUTING rule at the same position so it ends up running first
74+
// first, remove any existing SWIFT-POSTROUTING rules to avoid duplicates
75+
// note: inserting at len(rules) and deleting a jump to SWIFT-POSTROUTING is mutually exclusive
76+
swiftPostroutingExists, err := ipt.Exists(iptables.Nat, iptables.Postrouting, "-j", SWIFTPOSTROUTING)
77+
if err != nil {
78+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT-POSTROUTING rule: %v", err)
79+
}
80+
if swiftPostroutingExists {
81+
err = ipt.Delete(iptables.Nat, iptables.Postrouting, "-j", SWIFTPOSTROUTING)
82+
if err != nil {
83+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to delete existing SWIFT-POSTROUTING rule : " + err.Error()
84+
}
85+
}
86+
87+
// slice index is 0-based, iptables insert is 1-based, but list also gives us the -P POSTROUTING ACCEPT
88+
// as the first rule so swiftRuleIndex gives us the correct 1-indexed iptables position.
89+
// Example:
90+
// -P POSTROUTING ACCEPT is at swiftRuleIndex 0
91+
// -A POSTROUTING -j SWIFT is at swiftRuleIndex 1, and iptables index 1
92+
logger.Printf("[Azure CNS] Inserting SWIFT-POSTROUTING Chain at iptables position %d", swiftRuleIndex)
93+
err = ipt.Insert(iptables.Nat, iptables.Postrouting, swiftRuleIndex, "-j", SWIFTPOSTROUTING)
6494
if err != nil {
65-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append SWIFT chain : " + err.Error()
95+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert SWIFT-POSTROUTING chain : " + err.Error()
6696
}
6797
}
6898

@@ -71,39 +101,47 @@ func (service *HTTPRestService) programSNATRules(req *cns.CreateNetworkContainer
71101
// put the ip address in standard cidr form (where we zero out the parts that are not relevant)
72102
_, podSubnet, _ := net.ParseCIDR(v.IPAddress + "/" + fmt.Sprintf("%d", req.IPConfiguration.IPSubnet.PrefixLength))
73103

74-
snatUDPRuleExists, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
75-
if err != nil {
76-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT UDP rule : %v", err)
104+
// define all rules we want in the chain
105+
rules := [][]string{
106+
{"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()},
107+
{"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()},
108+
{"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP},
77109
}
78-
if !snatUDPRuleExists {
79-
logger.Printf("[Azure CNS] Inserting pod SNAT UDP rule ...")
80-
err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
110+
111+
// check if all rules exist
112+
allRulesExist := true
113+
for _, rule := range rules {
114+
exists, err := ipt.Exists(iptables.Nat, SWIFTPOSTROUTING, rule...)
81115
if err != nil {
82-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT UDP rule : " + err.Error()
116+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of rule: %v", err)
117+
}
118+
if !exists {
119+
allRulesExist = false
120+
break
83121
}
84122
}
85123

86-
snatPodTCPRuleExists, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
124+
// get current rule count in SWIFT-POSTROUTING chain
125+
currentRules, err := ipt.List(iptables.Nat, SWIFTPOSTROUTING)
87126
if err != nil {
88-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT TCP rule : %v", err)
127+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to list rules in SWIFT-POSTROUTING chain: %v", err)
89128
}
90-
if !snatPodTCPRuleExists {
91-
logger.Printf("[Azure CNS] Inserting pod SNAT TCP rule ...")
92-
err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
129+
130+
// if rule count doesn't match or not all rules exist, reconcile
131+
// add one because there is always a singular starting rule in the chain, in addition to the ones we add
132+
if len(currentRules) != len(rules)+1 || !allRulesExist {
133+
logger.Printf("[Azure CNS] Reconciling SWIFT-POSTROUTING chain rules")
134+
135+
err = ipt.ClearChain(iptables.Nat, SWIFTPOSTROUTING)
93136
if err != nil {
94-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT TCP rule : " + err.Error()
137+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to flush SWIFT-POSTROUTING chain : " + err.Error()
95138
}
96-
}
97139

98-
snatIMDSRuleexist, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP)
99-
if err != nil {
100-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT IMDS rule : %v", err)
101-
}
102-
if !snatIMDSRuleexist {
103-
logger.Printf("[Azure CNS] Inserting pod SNAT IMDS rule ...")
104-
err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP)
105-
if err != nil {
106-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT IMDS rule : " + err.Error()
140+
for _, rule := range rules {
141+
err = ipt.Append(iptables.Nat, SWIFTPOSTROUTING, rule...)
142+
if err != nil {
143+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append rule to SWIFT-POSTROUTING chain : " + err.Error()
144+
}
107145
}
108146
}
109147

0 commit comments

Comments
 (0)