Skip to content

Commit e6c2e2a

Browse files
Add support in CNS NMAgent Client to fetch secondary IPs (#3017)
* feat(CNS): Changes for fetching secondary IPs from NMAgent in CNS * test(CNS): Add UT for case where no fetch happens * style: Cleanup some comments, address some linting issues * style: address some linting issues * style:Run gofumpt * test: Update test * refactor: Address comments to move business logic out of nmagent client * chore: Add missed files * style: Better naming, comments * style: Better naming * style: Better naming, comments * chore: undo accidental edit * style: comments * style: naming * style: linting issues * feat: Address comments. Add MacAddress and IPAddress as types, refactor ip_fetcher code * style: linting issues * style: linting issues * chore: remove accidental edits * style: lower case in error messages * chore: add missed file * style: Rename MACAddress * style: Address comments * refactor: Address comments * chore: ip_fetcher changes
1 parent 0218099 commit e6c2e2a

File tree

10 files changed

+470
-3
lines changed

10 files changed

+470
-3
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package nodesubnet
2+
3+
import "time"
4+
5+
// This method is in this file (_test.go) because it is a test helper method.
6+
// The following method is built during tests, and is not part of the main code.
7+
func (c *IPFetcher) SetSecondaryIPQueryInterval(interval time.Duration) {
8+
c.secondaryIPQueryInterval = interval
9+
}

cns/nodesubnet/ip_fetcher.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package nodesubnet
2+
3+
import (
4+
"context"
5+
"log"
6+
"net/netip"
7+
"time"
8+
9+
"github.com/Azure/azure-container-networking/nmagent"
10+
"github.com/pkg/errors"
11+
)
12+
13+
var ErrRefreshSkipped = errors.New("refresh skipped due to throttling")
14+
15+
// InterfaceRetriever is an interface is implemented by the NMAgent Client, and also a mock client for testing.
16+
type InterfaceRetriever interface {
17+
GetInterfaceIPInfo(ctx context.Context) (nmagent.Interfaces, error)
18+
}
19+
20+
type IPFetcher struct {
21+
// Node subnet state
22+
secondaryIPQueryInterval time.Duration // Minimum time between secondary IP fetches
23+
secondaryIPLastRefreshTime time.Time // Time of last secondary IP fetch
24+
25+
ipFectcherClient InterfaceRetriever
26+
}
27+
28+
func NewIPFetcher(nmaClient InterfaceRetriever, queryInterval time.Duration) *IPFetcher {
29+
return &IPFetcher{
30+
ipFectcherClient: nmaClient,
31+
secondaryIPQueryInterval: queryInterval,
32+
}
33+
}
34+
35+
func (c *IPFetcher) RefreshSecondaryIPsIfNeeded(ctx context.Context) (ips []netip.Addr, err error) {
36+
// If secondaryIPQueryInterval has elapsed since the last fetch, fetch secondary IPs
37+
if time.Since(c.secondaryIPLastRefreshTime) < c.secondaryIPQueryInterval {
38+
return nil, ErrRefreshSkipped
39+
}
40+
41+
c.secondaryIPLastRefreshTime = time.Now()
42+
response, err := c.ipFectcherClient.GetInterfaceIPInfo(ctx)
43+
if err != nil {
44+
return nil, errors.Wrap(err, "getting interface IPs")
45+
}
46+
47+
res := flattenIPListFromResponse(&response)
48+
return res, nil
49+
}
50+
51+
// Get the list of secondary IPs from fetched Interfaces
52+
func flattenIPListFromResponse(resp *nmagent.Interfaces) (res []netip.Addr) {
53+
// For each interface...
54+
for _, intf := range resp.Entries {
55+
if !intf.IsPrimary {
56+
continue
57+
}
58+
59+
// For each subnet on the interface...
60+
for _, s := range intf.InterfaceSubnets {
61+
addressCount := 0
62+
// For each address in the subnet...
63+
for _, a := range s.IPAddress {
64+
// Primary addresses are reserved for the host.
65+
if a.IsPrimary {
66+
continue
67+
}
68+
69+
res = append(res, netip.Addr(a.Address))
70+
addressCount++
71+
}
72+
log.Printf("Got %d addresses from subnet %s", addressCount, s.Prefix)
73+
}
74+
}
75+
76+
return res
77+
}

cns/nodesubnet/ip_fetcher_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package nodesubnet_test
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
9+
"github.com/Azure/azure-container-networking/cns/nodesubnet"
10+
"github.com/Azure/azure-container-networking/nmagent"
11+
)
12+
13+
// Mock client that simply tracks if refresh has been called
14+
type TestClient struct {
15+
fetchCalled bool
16+
}
17+
18+
// Mock refresh
19+
func (c *TestClient) GetInterfaceIPInfo(_ context.Context) (nmagent.Interfaces, error) {
20+
c.fetchCalled = true
21+
return nmagent.Interfaces{}, nil
22+
}
23+
24+
func TestRefreshSecondaryIPsIfNeeded(t *testing.T) {
25+
getTests := []struct {
26+
name string
27+
shouldCall bool
28+
interval time.Duration
29+
}{
30+
{
31+
"fetch called",
32+
true,
33+
-1 * time.Second, // Negative timeout to force refresh
34+
},
35+
{
36+
"no refresh needed",
37+
false,
38+
10 * time.Hour, // High timeout to avoid refresh
39+
},
40+
}
41+
42+
clientPtr := &TestClient{}
43+
fetcher := nodesubnet.NewIPFetcher(clientPtr, 0)
44+
45+
for _, test := range getTests {
46+
test := test
47+
t.Run(test.name, func(t *testing.T) { // Do not parallelize, as we are using a shared client
48+
fetcher.SetSecondaryIPQueryInterval(test.interval)
49+
ctx, cancel := testContext(t)
50+
defer cancel()
51+
clientPtr.fetchCalled = false
52+
_, err := fetcher.RefreshSecondaryIPsIfNeeded(ctx)
53+
54+
if test.shouldCall {
55+
if err != nil && errors.Is(err, nodesubnet.ErrRefreshSkipped) {
56+
t.Error("refresh expected, but didn't happen")
57+
}
58+
59+
checkErr(t, err, false)
60+
} else if err == nil || !errors.Is(err, nodesubnet.ErrRefreshSkipped) {
61+
t.Error("refresh not expected, but happened")
62+
}
63+
})
64+
}
65+
}
66+
67+
// testContext creates a context from the provided testing.T that will be
68+
// canceled if the test suite is terminated.
69+
func testContext(t *testing.T) (context.Context, context.CancelFunc) {
70+
if deadline, ok := t.Deadline(); ok {
71+
return context.WithDeadline(context.Background(), deadline)
72+
}
73+
return context.WithCancel(context.Background())
74+
}
75+
76+
// checkErr is an assertion of the presence or absence of an error
77+
func checkErr(t *testing.T, err error, shouldErr bool) {
78+
t.Helper()
79+
if err != nil && !shouldErr {
80+
t.Fatal("unexpected error: err:", err)
81+
}
82+
83+
if err == nil && shouldErr {
84+
t.Fatal("expected error but received none")
85+
}
86+
}

nmagent/client.go

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,8 @@ type Client struct {
4444
httpClient *http.Client
4545

4646
// config
47-
host string
48-
port uint16
49-
47+
host string
48+
port uint16
5049
enableTLS bool
5150

5251
retrier interface {
@@ -284,6 +283,37 @@ func (c *Client) GetHomeAz(ctx context.Context) (AzResponse, error) {
284283
return homeAzResponse, nil
285284
}
286285

286+
// GetInterfaceIPInfo fetches the node's interface IP information from nmagent
287+
func (c *Client) GetInterfaceIPInfo(ctx context.Context) (Interfaces, error) {
288+
req, err := c.buildRequest(ctx, &GetSecondaryIPsRequest{})
289+
var out Interfaces
290+
291+
if err != nil {
292+
return out, errors.Wrap(err, "building request")
293+
}
294+
295+
resp, err := c.httpClient.Do(req)
296+
if err != nil {
297+
return out, errors.Wrap(err, "submitting request")
298+
}
299+
defer resp.Body.Close()
300+
301+
if resp.StatusCode != http.StatusOK {
302+
return out, die(resp.StatusCode, resp.Header, resp.Body, req.URL.Path)
303+
}
304+
305+
if resp.StatusCode != http.StatusOK {
306+
return out, die(resp.StatusCode, resp.Header, resp.Body, req.URL.Path)
307+
}
308+
309+
err = xml.NewDecoder(resp.Body).Decode(&out)
310+
if err != nil {
311+
return out, errors.Wrap(err, "decoding response")
312+
}
313+
314+
return out, nil
315+
}
316+
287317
func die(code int, headers http.Header, body io.ReadCloser, path string) error {
288318
// nolint:errcheck // make a best effort to return whatever information we can
289319
// returning an error here without the code and source would

nmagent/client_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package nmagent_test
33
import (
44
"context"
55
"encoding/json"
6+
"encoding/xml"
67
"fmt"
78
"net/http"
89
"net/http/httptest"
10+
"net/netip"
911
"strings"
1012
"testing"
1113

@@ -809,3 +811,86 @@ func TestGetHomeAz(t *testing.T) {
809811
})
810812
}
811813
}
814+
815+
func TestGetInterfaceIPInfo(t *testing.T) {
816+
tests := []struct {
817+
name string
818+
expURL string
819+
response nmagent.Interfaces
820+
respStr string
821+
}{
822+
{
823+
"happy path",
824+
"/machine/plugins?comp=nmagent&type=getinterfaceinfov1",
825+
nmagent.Interfaces{
826+
Entries: []nmagent.Interface{
827+
{
828+
MacAddress: nmagent.MACAddress{0x00, 0x0D, 0x3A, 0xF9, 0xDC, 0xA6},
829+
IsPrimary: true,
830+
InterfaceSubnets: []nmagent.InterfaceSubnet{
831+
{
832+
Prefix: "10.240.0.0/16",
833+
IPAddress: []nmagent.NodeIP{
834+
{
835+
Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 5})),
836+
IsPrimary: true,
837+
},
838+
{
839+
Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 6})),
840+
IsPrimary: false,
841+
},
842+
},
843+
},
844+
},
845+
},
846+
},
847+
},
848+
"<Interfaces><Interface MacAddress=\"000D3AF9DCA6\" IsPrimary=\"true\"><IPSubnet Prefix=\"10.240.0.0/16\">" +
849+
"<IPAddress Address=\"10.240.0.5\" IsPrimary=\"true\"/><IPAddress Address=\"10.240.0.6\" IsPrimary=\"false\"/>" +
850+
"</IPSubnet></Interface></Interfaces>",
851+
},
852+
}
853+
854+
for _, test := range tests {
855+
test := test
856+
t.Run(test.name, func(t *testing.T) {
857+
t.Parallel()
858+
859+
var gotURL string
860+
client := nmagent.NewTestClient(&TestTripper{
861+
RoundTripF: func(req *http.Request) (*http.Response, error) {
862+
gotURL = req.URL.RequestURI()
863+
rr := httptest.NewRecorder()
864+
rr.WriteHeader(http.StatusOK)
865+
err := xml.NewEncoder(rr).Encode(test.response)
866+
if err != nil {
867+
t.Fatal("unexpected error encoding response: err:", err)
868+
}
869+
return rr.Result(), nil
870+
},
871+
})
872+
873+
ctx, cancel := testContext(t)
874+
defer cancel()
875+
876+
resp, err := client.GetInterfaceIPInfo(ctx)
877+
checkErr(t, err, false)
878+
879+
if gotURL != test.expURL {
880+
t.Error("received URL differs from expected: got:", gotURL, "exp:", test.expURL)
881+
}
882+
883+
if got := resp; !cmp.Equal(got, test.response) {
884+
t.Error("response differs from expectation: diff:", cmp.Diff(got, test.response))
885+
}
886+
887+
var unmarshaled nmagent.Interfaces
888+
err = xml.Unmarshal([]byte(test.respStr), &unmarshaled)
889+
checkErr(t, err, false)
890+
891+
if !cmp.Equal(resp, unmarshaled) {
892+
t.Error("response differs from expected decoded string: diff:", cmp.Diff(resp, unmarshaled))
893+
}
894+
})
895+
}
896+
}

nmagent/ipaddress.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package nmagent
2+
3+
import (
4+
"encoding/xml"
5+
"net/netip"
6+
7+
"github.com/pkg/errors"
8+
)
9+
10+
type IPAddress netip.Addr
11+
12+
func (h IPAddress) Equal(other IPAddress) bool {
13+
return netip.Addr(h).Compare(netip.Addr(other)) == 0
14+
}
15+
16+
func (h *IPAddress) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
17+
var ipStr string
18+
if err := d.DecodeElement(&ipStr, &start); err != nil {
19+
return errors.Wrap(err, "decoding IP address")
20+
}
21+
22+
ip, err := netip.ParseAddr(ipStr)
23+
if err != nil {
24+
return errors.Wrap(err, "parsing IP address")
25+
}
26+
27+
*h = IPAddress(ip)
28+
return nil
29+
}
30+
31+
func (h *IPAddress) UnmarshalXMLAttr(attr xml.Attr) error {
32+
ipStr := attr.Value
33+
ip, err := netip.ParseAddr(ipStr)
34+
if err != nil {
35+
return errors.Wrap(err, "parsing IP address")
36+
}
37+
38+
*h = IPAddress(ip)
39+
return nil
40+
}
41+
42+
func (h IPAddress) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
43+
err := e.EncodeElement(netip.Addr(h).String(), start)
44+
return errors.Wrap(err, "encoding IP address")
45+
}
46+
47+
func (h IPAddress) MarshalXMLAttr(name xml.Name) (xml.Attr, error) {
48+
return xml.Attr{
49+
Name: name,
50+
Value: netip.Addr(h).String(),
51+
}, nil
52+
}

0 commit comments

Comments
 (0)