Skip to content

Commit 943c1ae

Browse files
authored
fix: parse subnet cidr and calculate gateway (#1064)
1 parent e692542 commit 943c1ae

File tree

3 files changed

+136
-42
lines changed

3 files changed

+136
-42
lines changed

cns/imdsclient/imdsclient.go

Lines changed: 73 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1-
// Copyright 2017 Microsoft. All rights reserved.
2-
// MIT License
3-
41
package imdsclient
52

63
import (
4+
"bytes"
75
"encoding/json"
86
"encoding/xml"
97
"fmt"
8+
"io"
109
"math"
10+
"net"
1111
"net/http"
12-
"strings"
1312

1413
"github.com/Azure/azure-container-networking/cns/logger"
14+
"github.com/pkg/errors"
15+
)
16+
17+
var (
18+
// ErrNoPrimaryInterface indicates the imds respnose does not have a primary interface indicated.
19+
ErrNoPrimaryInterface = errors.New("no primary interface found")
20+
// ErrInsufficientAddressSpace indicates that the CIDR space is too small to include a gateway IP; it is 1 IP.
21+
ErrInsufficientAddressSpace = errors.New("insufficient address space to generate gateway IP")
1522
)
1623

1724
// GetNetworkContainerInfoFromHost retrieves the programmed version of network container from Host.
@@ -45,6 +52,7 @@ func (imdsClient *ImdsClient) GetNetworkContainerInfoFromHost(networkContainerID
4552
}
4653

4754
// GetPrimaryInterfaceInfoFromHost retrieves subnet and gateway of primary NIC from Host.
55+
// TODO(rbtr): this is not a good client contract, we should return the resp.
4856
func (imdsClient *ImdsClient) GetPrimaryInterfaceInfoFromHost() (*InterfaceInfo, error) {
4957
logger.Printf("[Azure CNS] GetPrimaryInterfaceInfoFromHost")
5058

@@ -53,64 +61,87 @@ func (imdsClient *ImdsClient) GetPrimaryInterfaceInfoFromHost() (*InterfaceInfo,
5361
if err != nil {
5462
return nil, err
5563
}
56-
5764
defer resp.Body.Close()
65+
b, err := io.ReadAll(resp.Body)
66+
if err != nil {
67+
return nil, errors.Wrap(err, "failed to read response body")
68+
}
5869

59-
logger.Printf("[Azure CNS] Response received from NMAgent for get interface details: %v", resp.Body)
70+
logger.Printf("[Azure CNS] Response received from NMAgent for get interface details: %s", string(b))
6071

6172
var doc xmlDocument
62-
decoder := xml.NewDecoder(resp.Body)
63-
err = decoder.Decode(&doc)
64-
if err != nil {
65-
return nil, err
73+
if err := xml.NewDecoder(bytes.NewReader(b)).Decode(&doc); err != nil {
74+
return nil, errors.Wrap(err, "failed to decode response body")
6675
}
6776

6877
foundPrimaryInterface := false
6978

7079
// For each interface.
7180
for _, i := range doc.Interface {
72-
// Find primary Interface.
73-
if i.IsPrimary {
74-
interfaceInfo.IsPrimary = true
75-
76-
// Get the first subnet.
77-
for _, s := range i.IPSubnet {
78-
interfaceInfo.Subnet = s.Prefix
79-
malformedSubnetError := fmt.Errorf("Malformed subnet received from host %s", s.Prefix)
80-
81-
st := strings.Split(s.Prefix, "/")
82-
if len(st) != 2 {
83-
return nil, malformedSubnetError
84-
}
85-
86-
ip := strings.Split(st[0], ".")
87-
if len(ip) != 4 {
88-
return nil, malformedSubnetError
89-
}
90-
91-
interfaceInfo.Gateway = fmt.Sprintf("%s.%s.%s.1", ip[0], ip[1], ip[2])
92-
for _, ip := range s.IPAddress {
93-
if ip.IsPrimary {
94-
interfaceInfo.PrimaryIP = ip.Address
95-
}
81+
// skip if not primary
82+
if !i.IsPrimary {
83+
continue
84+
}
85+
interfaceInfo.IsPrimary = true
86+
87+
// Get the first subnet.
88+
for _, s := range i.IPSubnet {
89+
interfaceInfo.Subnet = s.Prefix
90+
gw, err := calculateGatewayIP(s.Prefix)
91+
if err != nil {
92+
return nil, err
93+
}
94+
interfaceInfo.Gateway = gw.String()
95+
for _, ip := range s.IPAddress {
96+
if ip.IsPrimary {
97+
interfaceInfo.PrimaryIP = ip.Address
9698
}
97-
98-
imdsClient.primaryInterface = interfaceInfo
99-
break
10099
}
101100

102-
foundPrimaryInterface = true
101+
imdsClient.primaryInterface = interfaceInfo
103102
break
104103
}
104+
105+
foundPrimaryInterface = true
106+
break
105107
}
106108

107-
var er error
108-
er = nil
109109
if !foundPrimaryInterface {
110-
er = fmt.Errorf("Unable to find primary NIC")
110+
return nil, ErrNoPrimaryInterface
111111
}
112112

113-
return interfaceInfo, er
113+
return interfaceInfo, nil
114+
}
115+
116+
// calculateGatewayIP parses the passed CIDR string and returns the first IP in the range.
117+
func calculateGatewayIP(cidr string) (net.IP, error) {
118+
_, subnet, err := net.ParseCIDR(cidr)
119+
if err != nil {
120+
return nil, errors.Wrap(err, "received malformed subnet from host")
121+
}
122+
123+
// check if we have enough address space to calculate a gateway IP
124+
// we need at least 2 IPs (eg the IPv4 mask cannot be greater than 31)
125+
// since the zeroth is reserved and the gateway is the first.
126+
mask, bits := subnet.Mask.Size()
127+
if mask == bits {
128+
return nil, ErrInsufficientAddressSpace
129+
}
130+
131+
// the subnet IP is the zero base address, so we need to increment it by one to get the gateway.
132+
gw := make([]byte, len(subnet.IP))
133+
copy(gw, subnet.IP)
134+
for idx := len(gw) - 1; idx >= 0; idx-- {
135+
gw[idx]++
136+
// net.IP is a binary byte array, check if we have overflowed and need to continue incrementing to the left
137+
// along the arary or if we're done.
138+
// it's like if we have a 9 in base 10, and add 1, it rolls over to 0 so we're not done - we need to move
139+
// left and increment that digit also.
140+
if gw[idx] != 0 {
141+
break
142+
}
143+
}
144+
return gw, nil
114145
}
115146

116147
// GetPrimaryInterfaceInfoFromMemory retrieves subnet and gateway of primary NIC that is saved in memory.

cns/imdsclient/imdsclient_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package imdsclient
2+
3+
import (
4+
"net"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestGenerateGatewayIP(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
cidr string
15+
want net.IP
16+
wantErr bool
17+
}{
18+
{
19+
name: "base case",
20+
cidr: "10.0.0.0/8",
21+
want: net.IPv4(10, 0, 0, 1),
22+
},
23+
{
24+
name: "nonzero start",
25+
cidr: "10.177.233.128/27",
26+
want: net.IPv4(10, 177, 233, 129),
27+
},
28+
{
29+
name: "invalid",
30+
cidr: "test",
31+
wantErr: true,
32+
},
33+
{
34+
name: "no available",
35+
cidr: "255.255.255.255/32",
36+
wantErr: true,
37+
},
38+
{
39+
name: "max IPv4",
40+
cidr: "255.255.255.255/31",
41+
want: net.IPv4(255, 255, 255, 255),
42+
},
43+
}
44+
for _, tt := range tests {
45+
tt := tt
46+
t.Run(tt.name, func(t *testing.T) {
47+
got, err := calculateGatewayIP(tt.cidr)
48+
if tt.wantErr {
49+
require.Error(t, err)
50+
} else {
51+
require.NoError(t, err)
52+
}
53+
assert.Truef(t, tt.want.Equal(got), "want %s, got %s", tt.want.String(), got.String())
54+
})
55+
}
56+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<Interfaces>
2+
<Interface MacAddress="002248263DBD" IsPrimary="true">
3+
<IPSubnet Prefix="10.240.0.0/16">
4+
<IPAddress Address="10.240.0.4" IsPrimary="true"/>
5+
</IPSubnet>
6+
</Interface>
7+
</Interfaces>

0 commit comments

Comments
 (0)