Skip to content

Commit f974fb3

Browse files
committed
add iptables reconciliation to cns
1 parent ef97f2a commit f974fb3

File tree

5 files changed

+408
-85
lines changed

5 files changed

+408
-85
lines changed

cns/fakes/iptablesfake.go

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package fakes
22

33
import (
44
"errors"
5+
"fmt"
56
"strings"
67

78
"github.com/Azure/azure-container-networking/iptables"
@@ -11,6 +12,7 @@ var (
1112
errChainExists = errors.New("chain already exists")
1213
errChainNotFound = errors.New("chain not found")
1314
errRuleExists = errors.New("rule already exists")
15+
errRuleNotFound = errors.New("rule not found")
1416
)
1517

1618
type IPTablesMock struct {
@@ -83,21 +85,101 @@ func (c *IPTablesMock) Exists(table, chain string, rulespec ...string) (bool, er
8385
func (c *IPTablesMock) Append(table, chain string, rulespec ...string) error {
8486
c.ensureTableExists(table)
8587

88+
chainRules := c.state[table][chain]
89+
return c.Insert(table, chain, len(chainRules)+1, rulespec...)
90+
}
91+
92+
func (c *IPTablesMock) Insert(table, chain string, pos int, rulespec ...string) error {
93+
c.ensureTableExists(table)
94+
8695
chainExists, _ := c.ChainExists(table, chain)
8796
if !chainExists {
8897
return errChainNotFound
8998
}
9099

91-
ruleExists, _ := c.Exists(table, chain, rulespec...)
92-
if ruleExists {
93-
return errRuleExists
100+
targetRule := strings.Join(rulespec, " ")
101+
chainRules := c.state[table][chain]
102+
103+
// convert 1-based position to 0-based index
104+
index := pos - 1
105+
if index < 0 {
106+
index = 0
107+
}
108+
109+
if index >= len(chainRules) {
110+
c.state[table][chain] = append(chainRules, targetRule)
111+
} else {
112+
c.state[table][chain] = append(chainRules[:index], append([]string{targetRule}, chainRules[index:]...)...)
94113
}
95114

96-
targetRule := strings.Join(rulespec, " ")
97-
c.state[table][chain] = append(c.state[table][chain], targetRule)
98115
return nil
99116
}
100117

101-
func (c *IPTablesMock) Insert(table, chain string, _ int, rulespec ...string) error {
102-
return c.Append(table, chain, rulespec...)
118+
func (c *IPTablesMock) List(table, chain string) ([]string, error) {
119+
c.ensureTableExists(table)
120+
121+
chainExists, _ := c.ChainExists(table, chain)
122+
if !chainExists {
123+
return nil, errChainNotFound
124+
}
125+
126+
var result []string
127+
128+
// for built-in chains, start with policy -P, otherwise start with definition -N
129+
builtins := []string{iptables.Input, iptables.Output, iptables.Prerouting, iptables.Postrouting, iptables.Forward}
130+
isBuiltIn := false
131+
for _, builtin := range builtins {
132+
if chain == builtin {
133+
isBuiltIn = true
134+
break
135+
}
136+
}
137+
138+
if isBuiltIn {
139+
result = append(result, fmt.Sprintf("-P %s ACCEPT", chain))
140+
} else {
141+
result = append(result, fmt.Sprintf("-N %s", chain))
142+
}
143+
144+
// iptables with -S always outputs the rules in -A format
145+
chainRules := c.state[table][chain]
146+
for _, rule := range chainRules {
147+
result = append(result, fmt.Sprintf("-A %s %s", chain, rule))
148+
}
149+
150+
return result, nil
151+
}
152+
153+
func (c *IPTablesMock) ClearChain(table, chain string) error {
154+
c.ensureTableExists(table)
155+
156+
chainExists, _ := c.ChainExists(table, chain)
157+
if !chainExists {
158+
return errChainNotFound
159+
}
160+
161+
c.state[table][chain] = []string{}
162+
return nil
163+
}
164+
165+
func (c *IPTablesMock) Delete(table, chain string, rulespec ...string) error {
166+
c.ensureTableExists(table)
167+
168+
chainExists, _ := c.ChainExists(table, chain)
169+
if !chainExists {
170+
return errChainNotFound
171+
}
172+
173+
targetRule := strings.Join(rulespec, " ")
174+
chainRules := c.state[table][chain]
175+
176+
// delete first match
177+
for i, rule := range chainRules {
178+
if rule == targetRule {
179+
c.state[table][chain] = append(chainRules[:i], chainRules[i+1:]...)
180+
return nil
181+
}
182+
}
183+
184+
return errRuleNotFound
103185
}

cns/restserver/internalapi_linux.go

Lines changed: 73 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -39,30 +39,59 @@ func (service *HTTPRestService) programSNATRules(req *cns.CreateNetworkContainer
3939

4040
chainExist, err := ipt.ChainExists(iptables.Nat, SWIFT)
4141
if err != nil {
42-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT chain: %v", err)
42+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT-POSTROUTING chain: %v", err)
4343
}
4444
if !chainExist { // create and append chain if it doesn't exist
4545
logger.Printf("[Azure CNS] Creating SWIFT Chain ...")
4646
err = ipt.NewChain(iptables.Nat, SWIFT)
4747
if err != nil {
48-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to create SWIFT chain : " + err.Error()
49-
}
50-
logger.Printf("[Azure CNS] Append SWIFT Chain to POSTROUTING ...")
51-
err = ipt.Append(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
52-
if err != nil {
53-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append SWIFT chain : " + err.Error()
48+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to create SWIFT-POSTROUTING chain : " + err.Error()
5449
}
5550
}
5651

57-
postroutingToSwiftJumpexist, err := ipt.Exists(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
52+
// reconcile jump to SWIFT-POSTROUTING chain
53+
rules, err := ipt.List(iptables.Nat, iptables.Postrouting)
5854
if err != nil {
59-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of POSTROUTING to SWIFT chain jump: %v", err)
55+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check rules in postrouting chain of nat table: %v", err)
56+
}
57+
swiftRuleIndex := len(rules) // append if neither jump rule from POSTROUTING is found
58+
// one time migration from old SWIFT chain
59+
// previously, CNI may have a jump to the SWIFT chain-- our jump to SWIFT-POSTROUTING needs to happen first
60+
for index, rule := range rules {
61+
if rule == "-A POSTROUTING -j SWIFT" {
62+
// jump to SWIFT comes before jump to SWIFT-POSTROUTING, so potential reordering required
63+
swiftRuleIndex = index
64+
break
65+
}
66+
if rule == "-A POSTROUTING -j SWIFT-POSTROUTING" {
67+
// jump to SWIFT-POSTROUTING comes before jump to SWIFT, which requires no further action
68+
swiftRuleIndex = -1
69+
break
70+
}
6071
}
61-
if !postroutingToSwiftJumpexist {
62-
logger.Printf("[Azure CNS] Append SWIFT Chain to POSTROUTING ...")
63-
err = ipt.Append(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
72+
if swiftRuleIndex != -1 {
73+
// jump SWIFT rule exists, insert SWIFT-POSTROUTING rule at the same position so it ends up running first
74+
// first, remove any existing SWIFT-POSTROUTING rules to avoid duplicates
75+
swiftPostroutingExists, err := ipt.Exists(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
76+
if err != nil {
77+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT-POSTROUTING rule: %v", err)
78+
}
79+
if swiftPostroutingExists {
80+
err = ipt.Delete(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
81+
if err != nil {
82+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to delete existing SWIFT-POSTROUTING rule : " + err.Error()
83+
}
84+
}
85+
86+
// slice index is 0-based, iptables insert is 1-based, but list also gives us the -P POSTROUTING ACCEPT
87+
// as the first rule so swiftRuleIndex gives us the correct 1-indexed iptables position.
88+
// Example:
89+
// -P POSTROUTING ACCEPT is at swiftRuleIndex 0
90+
// -A POSTROUTING -j SWIFT is at swiftRuleIndex 1, and iptables index 1
91+
logger.Printf("[Azure CNS] Inserting SWIFT-POSTROUTING Chain at iptables position %d", swiftRuleIndex)
92+
err = ipt.Insert(iptables.Nat, iptables.Postrouting, swiftRuleIndex, "-j", SWIFT)
6493
if err != nil {
65-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append SWIFT chain : " + err.Error()
94+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert SWIFT-POSTROUTING chain : " + err.Error()
6695
}
6796
}
6897

@@ -71,39 +100,47 @@ func (service *HTTPRestService) programSNATRules(req *cns.CreateNetworkContainer
71100
// put the ip address in standard cidr form (where we zero out the parts that are not relevant)
72101
_, podSubnet, _ := net.ParseCIDR(v.IPAddress + "/" + fmt.Sprintf("%d", req.IPConfiguration.IPSubnet.PrefixLength))
73102

74-
snatUDPRuleExists, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
75-
if err != nil {
76-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT UDP rule : %v", err)
103+
// define all rules we want in the chain
104+
rules := [][]string{
105+
{"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()},
106+
{"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()},
107+
{"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP},
77108
}
78-
if !snatUDPRuleExists {
79-
logger.Printf("[Azure CNS] Inserting pod SNAT UDP rule ...")
80-
err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
109+
110+
// check if all rules exist
111+
allRulesExist := true
112+
for _, rule := range rules {
113+
exists, err := ipt.Exists(iptables.Nat, SWIFT, rule...)
81114
if err != nil {
82-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT UDP rule : " + err.Error()
115+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of rule: %v", err)
116+
}
117+
if !exists {
118+
allRulesExist = false
119+
break
83120
}
84121
}
85122

86-
snatPodTCPRuleExists, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
123+
// get current rule count in SWIFT-POSTROUTING chain
124+
currentRules, err := ipt.List(iptables.Nat, SWIFT)
87125
if err != nil {
88-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT TCP rule : %v", err)
126+
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to list rules in SWIFT-POSTROUTING chain: %v", err)
89127
}
90-
if !snatPodTCPRuleExists {
91-
logger.Printf("[Azure CNS] Inserting pod SNAT TCP rule ...")
92-
err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
128+
129+
// if rule count doesn't match or not all rules exist, reconcile
130+
// add one because there is always a singular starting rule in the chain, in addition to the ones we add
131+
if len(currentRules) != len(rules)+1 || !allRulesExist {
132+
logger.Printf("[Azure CNS] Reconciling SWIFT-POSTROUTING chain rules")
133+
134+
err = ipt.ClearChain(iptables.Nat, SWIFT)
93135
if err != nil {
94-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT TCP rule : " + err.Error()
136+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to flush SWIFT-POSTROUTING chain : " + err.Error()
95137
}
96-
}
97138

98-
snatIMDSRuleexist, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP)
99-
if err != nil {
100-
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT IMDS rule : %v", err)
101-
}
102-
if !snatIMDSRuleexist {
103-
logger.Printf("[Azure CNS] Inserting pod SNAT IMDS rule ...")
104-
err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP)
105-
if err != nil {
106-
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT IMDS rule : " + err.Error()
139+
for _, rule := range rules {
140+
err = ipt.Append(iptables.Nat, SWIFT, rule...)
141+
if err != nil {
142+
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append rule to SWIFT-POSTROUTING chain : " + err.Error()
143+
}
107144
}
108145
}
109146

0 commit comments

Comments
 (0)