diff --git a/cns/fakes/imdsclientfake.go b/cns/fakes/imdsclientfake.go index dd69123e90..cc0075ae56 100644 --- a/cns/fakes/imdsclientfake.go +++ b/cns/fakes/imdsclientfake.go @@ -8,6 +8,7 @@ package fakes import ( "context" + "net" "github.com/Azure/azure-container-networking/cns/imds" "github.com/Azure/azure-container-networking/cns/wireserver" @@ -57,3 +58,40 @@ func (m *MockIMDSClient) GetVMUniqueID(ctx context.Context) (string, error) { return "55b8499d-9b42-4f85-843f-24ff69f4a643", nil } + +func (m *MockIMDSClient) GetNetworkInterfaces(ctx context.Context) ([]imds.NetworkInterface, error) { + if ctx.Value(SimulateError) != nil { + return nil, imds.ErrUnexpectedStatusCode + } + + // Parse MAC addresses for testing + macAddr1, _ := net.ParseMAC("00:15:5d:01:02:01") + macAddr2, _ := net.ParseMAC("00:15:5d:01:02:02") + + // Return some mock network interfaces for testing + return []imds.NetworkInterface{ + { + InterfaceCompartmentID: "nc1", + MacAddress: imds.HardwareAddr(macAddr1), + }, + { + InterfaceCompartmentID: "nc2", + MacAddress: imds.HardwareAddr(macAddr2), + }, + }, nil +} + +func (m *MockIMDSClient) GetIMDSVersions(ctx context.Context) (*imds.APIVersionsResponse, error) { + if ctx.Value(SimulateError) != nil { + return nil, imds.ErrUnexpectedStatusCode + } + + // Return supported API versions including the expected one + return &imds.APIVersionsResponse{ + APIVersions: []string{ + "2017-03-01", + "2021-01-01", + "2025-07-24", + }, + }, nil +} diff --git a/cns/imds/client.go b/cns/imds/client.go index 6e210e4705..ac06e6d8a3 100644 --- a/cns/imds/client.go +++ b/cns/imds/client.go @@ -6,6 +6,7 @@ package imds import ( "context" "encoding/json" + "net" "net/http" "net/url" @@ -46,7 +47,10 @@ func RetryAttempts(attempts uint) ClientOption { const ( vmUniqueIDProperty = "vmId" imdsComputePath = "/metadata/instance/compute" - imdsComputeAPIVersion = "api-version=2021-01-01" + imdsNetworkPath = "/metadata/instance/network" + imdsVersionsPath = "/metadata/versions" + imdsDefaultAPIVersion = "api-version=2021-01-01" + imdsNCDetailsVersion = "api-version=2025-07-24" imdsFormatJSON = "format=json" metadataHeaderKey = "Metadata" metadataHeaderValue = "true" @@ -79,7 +83,7 @@ func NewClient(opts ...ClientOption) *Client { func (c *Client) GetVMUniqueID(ctx context.Context) (string, error) { var vmUniqueID string err := retry.Do(func() error { - computeDoc, err := c.getInstanceComputeMetadata(ctx) + computeDoc, err := c.getInstanceMetadata(ctx, imdsComputePath, imdsDefaultAPIVersion) if err != nil { return errors.Wrap(err, "error getting IMDS compute metadata") } @@ -102,14 +106,40 @@ func (c *Client) GetVMUniqueID(ctx context.Context) (string, error) { return vmUniqueID, nil } -func (c *Client) getInstanceComputeMetadata(ctx context.Context) (map[string]any, error) { - imdsComputeURL, err := url.JoinPath(c.config.endpoint, imdsComputePath) +func (c *Client) GetNetworkInterfaces(ctx context.Context) ([]NetworkInterface, error) { + var networkData NetworkInterfaces + err := retry.Do(func() error { + networkInterfaces, err := c.getInstanceMetadata(ctx, imdsNetworkPath, imdsNCDetailsVersion) + if err != nil { + return errors.Wrap(err, "error getting IMDS network metadata") + } + + // Parse the network metadata to the expected structure + jsonData, err := json.Marshal(networkInterfaces) + if err != nil { + return errors.Wrap(err, "error marshaling network metadata") + } + + if err := json.Unmarshal(jsonData, &networkData); err != nil { + return errors.Wrap(err, "error unmarshaling network metadata") + } + return nil + }, retry.Context(ctx), retry.Attempts(c.config.retryAttempts), retry.DelayType(retry.BackOffDelay)) if err != nil { - return nil, errors.Wrap(err, "unable to build path to IMDS compute metadata") + return nil, errors.Wrap(err, "external call failed") } - imdsComputeURL = imdsComputeURL + "?" + imdsComputeAPIVersion + "&" + imdsFormatJSON - req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsComputeURL, http.NoBody) + return networkData.Interface, nil +} + +func (c *Client) getInstanceMetadata(ctx context.Context, imdsMetadataPath, imdsAPIVersion string) (map[string]any, error) { + imdsRequestURL, err := url.JoinPath(c.config.endpoint, imdsMetadataPath) + if err != nil { + return nil, errors.Wrap(err, "unable to build path to IMDS metadata for path"+imdsMetadataPath) + } + imdsRequestURL = imdsRequestURL + "?" + imdsAPIVersion + "&" + imdsFormatJSON + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsRequestURL, http.NoBody) if err != nil { return nil, errors.Wrap(err, "error building IMDS http request") } @@ -133,3 +163,86 @@ func (c *Client) getInstanceComputeMetadata(ctx context.Context) (map[string]any return m, nil } + +func (c *Client) GetIMDSVersions(ctx context.Context) (*APIVersionsResponse, error) { + var versionsResp APIVersionsResponse + err := retry.Do(func() error { + // Build the URL for the versions endpoint + imdsRequestURL, err := url.JoinPath(c.config.endpoint, imdsVersionsPath) + if err != nil { + return errors.Wrap(err, "unable to build path to IMDS versions endpoint") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsRequestURL, http.NoBody) + if err != nil { + return errors.Wrap(err, "error building IMDS versions http request") + } + + req.Header.Add(metadataHeaderKey, metadataHeaderValue) + resp, err := c.cli.Do(req) + if err != nil { + return errors.Wrap(err, "error querying IMDS versions API") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.Wrapf(ErrUnexpectedStatusCode, "unexpected status code %d", resp.StatusCode) + } + + if err := json.NewDecoder(resp.Body).Decode(&versionsResp); err != nil { + return errors.Wrap(err, "error decoding IMDS versions response as json") + } + + return nil + }, retry.Context(ctx), retry.Attempts(c.config.retryAttempts), retry.DelayType(retry.BackOffDelay)) + if err != nil { + return nil, errors.Wrap(err, "exhausted retries querying IMDS versions") + } + + return &versionsResp, nil +} + +// Required for marshaling/unmarshaling of mac address +type HardwareAddr net.HardwareAddr + +func (h *HardwareAddr) MarshalJSON() ([]byte, error) { + data, err := json.Marshal(net.HardwareAddr(*h).String()) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal hardware address") + } + return data, nil +} + +func (h *HardwareAddr) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return errors.Wrap(err, "failed to unmarshal JSON data") + } + mac, err := net.ParseMAC(s) + if err != nil { + return errors.Wrap(err, "failed to parse MAC address") + } + *h = HardwareAddr(mac) + return nil +} + +func (h *HardwareAddr) String() string { + return net.HardwareAddr(*h).String() +} + +// NetworkInterface represents a network interface from IMDS +type NetworkInterface struct { + // IMDS returns compartment fields - these are mapped to NC ID and NC version + MacAddress HardwareAddr `json:"macAddress"` + InterfaceCompartmentID string `json:"interfaceCompartmentID,omitempty"` +} + +// NetworkInterfaces represents the network interfaces from IMDS +type NetworkInterfaces struct { + Interface []NetworkInterface `json:"interface"` +} + +// APIVersionsResponse represents versions form IMDS +type APIVersionsResponse struct { + APIVersions []string `json:"apiVersions"` +} diff --git a/cns/imds/client_test.go b/cns/imds/client_test.go index 6debda59c3..ac97ba5251 100644 --- a/cns/imds/client_test.go +++ b/cns/imds/client_test.go @@ -5,6 +5,7 @@ package imds_test import ( "context" + "net" "net/http" "net/http/httptest" "os" @@ -100,3 +101,207 @@ func TestInvalidVMUniqueID(t *testing.T) { require.Error(t, err, "error querying testserver") require.Equal(t, "", vmUniqueID) } + +func TestGetNetworkInterfaces(t *testing.T) { + networkInterfaces := []byte(`{ + "interface": [ + { + "interfaceCompartmentID": "nc-12345-67890", + "macAddress": "00:00:5e:00:53:01" + }, + { + "interfaceCompartmentID": "", + "macAddress": "00:00:5e:00:53:02" + } + ] + }`) + + mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // request header "Metadata: true" must be present + metadataHeader := r.Header.Get("Metadata") + assert.Equal(t, "true", metadataHeader) + + // verify path is network metadata + assert.Contains(t, r.URL.Path, "/metadata/instance/network") + + w.WriteHeader(http.StatusOK) + _, writeErr := w.Write(networkInterfaces) + if writeErr != nil { + t.Errorf("error writing response: %v", writeErr) + return + } + })) + defer mockIMDSServer.Close() + + imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL)) + interfaces, err := imdsClient.GetNetworkInterfaces(context.Background()) + require.NoError(t, err, "error querying testserver") + + // Verify we got the expected interfaces + assert.Len(t, interfaces, 2, "expected 2 interfaces") + + // Check first interface + assert.Equal(t, "nc-12345-67890", interfaces[0].InterfaceCompartmentID) + assert.Equal(t, "00:00:5e:00:53:01", interfaces[0].MacAddress.String(), "first interface MAC address should match") + + // Check second interface + assert.Equal(t, "", interfaces[1].InterfaceCompartmentID) + assert.Equal(t, "00:00:5e:00:53:02", interfaces[1].MacAddress.String(), "second interface MAC address should match") + + // Test that MAC addresses can be converted to net.HardwareAddr + firstMAC := net.HardwareAddr(interfaces[0].MacAddress) + secondMAC := net.HardwareAddr(interfaces[1].MacAddress) + + // Verify the underlying types work correctly + assert.Len(t, firstMAC, 6, "MAC address should be 6 bytes") + assert.Len(t, secondMAC, 6, "MAC address should be 6 bytes") + + // Test that they're different MAC addresses + assert.NotEqual(t, firstMAC.String(), secondMAC.String(), "MAC addresses should be different") +} + +func TestGetNetworkInterfacesInvalidEndpoint(t *testing.T) { + imdsClient := imds.NewClient(imds.Endpoint(string([]byte{0x7f})), imds.RetryAttempts(1)) + _, err := imdsClient.GetNetworkInterfaces(context.Background()) + require.Error(t, err, "expected invalid path") +} + +func TestGetNetworkInterfacesInvalidJSON(t *testing.T) { + mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("not json")) + if err != nil { + t.Errorf("error writing response: %v", err) + return + } + })) + defer mockIMDSServer.Close() + + imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL), imds.RetryAttempts(1)) + _, err := imdsClient.GetNetworkInterfaces(context.Background()) + require.Error(t, err, "expected json decoding error") +} + +func TestGetNetworkInterfacesNoNCIDs(t *testing.T) { + networkInterfacesNoNC := []byte(`{ + "interface": [ + { + "ipv4": { + "ipAddress": [ + { + "privateIpAddress": "10.0.0.4", + "publicIpAddress": "" + } + ] + }, + "macAddress": "00:00:5e:00:53:01" + } + ] + }`) + + mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + metadataHeader := r.Header.Get("Metadata") + assert.Equal(t, "true", metadataHeader) + + w.WriteHeader(http.StatusOK) + _, writeErr := w.Write(networkInterfacesNoNC) + if writeErr != nil { + t.Errorf("error writing response: %v", writeErr) + return + } + })) + defer mockIMDSServer.Close() + + imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL)) + interfaces, err := imdsClient.GetNetworkInterfaces(context.Background()) + require.NoError(t, err, "error querying testserver") + + // Verify we got interfaces but they don't have compartment IDs + assert.Len(t, interfaces, 1, "expected 1 interface") + + // Check that interfaces don't have compartment IDs + assert.Equal(t, "", interfaces[0].InterfaceCompartmentID) + assert.Equal(t, "00:00:5e:00:53:01", interfaces[0].MacAddress.String(), "MAC address should match") +} + +func TestGetIMDSVersions(t *testing.T) { + mockResponseBody := `{"apiVersions": ["2017-03-01", "2021-01-01", "2025-07-24"]}` + expectedVersions := []string{"2017-03-01", "2021-01-01", "2025-07-24"} + + mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request headers + metadataHeader := r.Header.Get("Metadata") + assert.Equal(t, "true", metadataHeader) + + w.WriteHeader(http.StatusOK) + _, writeErr := w.Write([]byte(mockResponseBody)) + if writeErr != nil { + t.Errorf("error writing response: %v", writeErr) + return + } + })) + defer mockIMDSServer.Close() + + imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL)) + versionsResp, err := imdsClient.GetIMDSVersions(context.Background()) + + require.NoError(t, err, "unexpected error") + assert.Equal(t, expectedVersions, versionsResp.APIVersions, "API versions should match expected") +} + +func TestGetIMDSVersionsInvalidJSON(t *testing.T) { + mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, writeErr := w.Write([]byte(`{"invalid": json}`)) + if writeErr != nil { + t.Errorf("error writing response: %v", writeErr) + return + } + })) + defer mockIMDSServer.Close() + + imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL), imds.RetryAttempts(1)) + versionsResp, err := imdsClient.GetIMDSVersions(context.Background()) + + require.Error(t, err, "expected error for invalid JSON") + assert.Nil(t, versionsResp, "response should be nil on error") +} + +func TestGetIMDSVersionsInternalServerError(t *testing.T) { + mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer mockIMDSServer.Close() + + imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL), imds.RetryAttempts(1)) + versionsResp, err := imdsClient.GetIMDSVersions(context.Background()) + + require.Error(t, err, "expected error for 500") + assert.Nil(t, versionsResp, "response should be nil or error") +} + +func TestGetIMDSVersionsMissingAPIVersionsField(t *testing.T) { + mockResponseBody := `{"otherField": "value"}` + + mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, writeErr := w.Write([]byte(mockResponseBody)) + if writeErr != nil { + t.Errorf("error writing response: %v", writeErr) + return + } + })) + defer mockIMDSServer.Close() + + imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL)) + versionsResp, err := imdsClient.GetIMDSVersions(context.Background()) + + require.NoError(t, err, "unexpected error") + assert.Nil(t, versionsResp.APIVersions, "API versions should be nil when field is missing") +} + +func TestGetIMDSVersionsInvalidEndpoint(t *testing.T) { + imdsClient := imds.NewClient(imds.Endpoint(string([]byte{0x7f})), imds.RetryAttempts(1)) + _, err := imdsClient.GetIMDSVersions(context.Background()) + require.Error(t, err, "expected error for invalid endpoint") +} diff --git a/cns/restserver/internalapi.go b/cns/restserver/internalapi.go index 8d477cbab8..9792aed5a3 100644 --- a/cns/restserver/internalapi.go +++ b/cns/restserver/internalapi.go @@ -25,6 +25,12 @@ import ( "github.com/pkg/errors" ) +const ( + // Known API names we care about + expectedIMDSAPIVersion = "2025-07-24" + PrefixOnNicNCVersion = "1" +) + // This file contains the internal functions called by either HTTP APIs (api.go) or // internal APIs (definde in internalapi.go). // This will be used internally (say by RequestController in case of AKS) @@ -221,18 +227,39 @@ func (service *HTTPRestService) syncHostNCVersion(ctx context.Context, channelMo return len(programmedNCs), errors.Wrap(err, "failed to get nc version list from nmagent") } + // Get IMDS NC versions for delegated NIC scenarios + imdsNCVersions, err := service.GetIMDSNCs(ctx) + if err != nil { + // If any of the NMA API check calls, imds calls fails assume that nma build doesn't have the latest changes and create empty map + imdsNCVersions = make(map[string]string) + } + nmaNCs := map[string]string{} for _, nc := range ncVersionListResp.Containers { nmaNCs[strings.ToLower(nc.NetworkContainerID)] = nc.Version } - hasNC.Set(float64(len(nmaNCs))) + + // Consolidate both nc's from NMA and IMDS calls + nmaProgrammedNCs := make(map[string]string) + for ncID, version := range nmaNCs { + nmaProgrammedNCs[ncID] = version + } + for ncID, version := range imdsNCVersions { + if _, exists := nmaProgrammedNCs[ncID]; !exists { + nmaProgrammedNCs[strings.ToLower(ncID)] = version + } else { + //nolint:staticcheck // SA1019: suppress deprecated logger.Warnf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Warnf("NC %s exists in both NMA and IMDS responses, which is not expected", ncID) + } + } + hasNC.Set(float64(len(nmaProgrammedNCs))) for ncID := range outdatedNCs { - nmaNCVersionStr, ok := nmaNCs[ncID] + nmaProgrammedNCVersionStr, ok := nmaProgrammedNCs[ncID] if !ok { - // NMA doesn't have this NC that we need programmed yet, bail out + // Neither NMA nor IMDS has this NC that we need programmed yet, bail out continue } - nmaNCVersion, err := strconv.Atoi(nmaNCVersionStr) + nmaProgrammedNCVersion, err := strconv.Atoi(nmaProgrammedNCVersionStr) if err != nil { logger.Errorf("failed to parse container version of %s: %s", ncID, err) continue @@ -245,7 +272,7 @@ func (service *HTTPRestService) syncHostNCVersion(ctx context.Context, channelMo return len(programmedNCs), errors.Wrapf(errNonExistentContainerStatus, "can't find NC with ID %s in service state, stop updating this host NC version", ncID) } // if the NC still exists in state and is programmed to some version (doesn't have to be latest), add it to our set of NCs that have been programmed - if nmaNCVersion > -1 { + if nmaProgrammedNCVersion > -1 { programmedNCs[ncID] = struct{}{} } @@ -254,15 +281,17 @@ func (service *HTTPRestService) syncHostNCVersion(ctx context.Context, channelMo logger.Errorf("failed to parse host nc version string %s: %s", ncInfo.HostVersion, err) continue } - if localNCVersion > nmaNCVersion { - logger.Errorf("NC version from NMA is decreasing: have %d, got %d", localNCVersion, nmaNCVersion) + if localNCVersion > nmaProgrammedNCVersion { + //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Errorf("NC version from consolidated sources is decreasing: have %d, got %d", localNCVersion, nmaProgrammedNCVersion) continue } if channelMode == cns.CRD { - service.MarkIpsAsAvailableUntransacted(ncInfo.ID, nmaNCVersion) + service.MarkIpsAsAvailableUntransacted(ncInfo.ID, nmaProgrammedNCVersion) } - logger.Printf("Updating NC %s host version from %s to %s", ncID, ncInfo.HostVersion, nmaNCVersionStr) - ncInfo.HostVersion = nmaNCVersionStr + //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Printf("Updating NC %s host version from %s to %s", ncID, ncInfo.HostVersion, nmaProgrammedNCVersionStr) + ncInfo.HostVersion = nmaProgrammedNCVersionStr logger.Printf("Updated NC %s host version to %s", ncID, ncInfo.HostVersion) service.state.ContainerStatus[ncID] = ncInfo // if we successfully updated the NC, pop it from the needs update set. @@ -271,7 +300,7 @@ func (service *HTTPRestService) syncHostNCVersion(ctx context.Context, channelMo // if we didn't empty out the needs update set, NMA has not programmed all the NCs we are expecting, and we // need to return an error indicating that if len(outdatedNCs) > 0 { - return len(programmedNCs), errors.Errorf("unabled to update some NCs: %v, missing or bad response from NMA", outdatedNCs) + return len(programmedNCs), errors.Errorf("unable to update some NCs: %v, missing or bad response from NMA or IMDS", outdatedNCs) } return len(programmedNCs), nil @@ -332,7 +361,6 @@ func (service *HTTPRestService) ReconcileIPAssignment(podInfoByIP map[string]cns for _, ip := range ips { if ncReq, ok := allSecIPsIdx[ip.String()]; ok { - logger.Printf("secondary ip %s is assigned to pod %+v, ncId: %s ncVersion: %s", ip, podIPs, ncReq.NetworkContainerid, ncReq.Version) desiredIPs = append(desiredIPs, ip.String()) ncIDs = append(ncIDs, ncReq.NetworkContainerid) } else { @@ -361,7 +389,8 @@ func (service *HTTPRestService) ReconcileIPAssignment(podInfoByIP map[string]cns } if _, err := requestIPConfigsHelper(service, ipconfigsRequest); err != nil { - logger.Errorf("requestIPConfigsHelper failed for pod key %s, podInfo %+v, ncIds %v, error: %v", podKey, podIPs, ncIDs, err) + //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Errorf("requestIPConfigsHelper failed for pod key %s, podInfo %+v, ncIDs %v, error: %v", podKey, podIPs, ncIDs, err) return types.FailedToAllocateIPConfig } } @@ -634,3 +663,61 @@ func (service *HTTPRestService) CreateOrUpdateNetworkContainerInternal(req *cns. func (service *HTTPRestService) SetVFForAccelnetNICs() error { return service.setVFForAccelnetNICs() } + +func (service *HTTPRestService) isNCDetailsAPIExists(ctx context.Context) bool { + versionsResp, err := service.imdsClient.GetIMDSVersions(ctx) + if err != nil { + //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Errorf("Failed to get IMDS versions: %v", err) + return false + } + + // Check if the expected API version exists in the response + for _, version := range versionsResp.APIVersions { + if version == expectedIMDSAPIVersion { + //nolint:staticcheck // SA1019: suppress deprecated logger.Debugf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Debugf("IMDS has expected API version") + return true + } + } + return false +} + +// GetIMDSNCs gets NC versions from IMDS and returns them as a map +func (service *HTTPRestService) GetIMDSNCs(ctx context.Context) (map[string]string, error) { + imdsClient := service.imdsClient + if imdsClient == nil { + //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Errorf("IMDS client is not available") + return make(map[string]string), nil + } + // Check NC version support + if !service.isNCDetailsAPIExists(ctx) { + //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Errorf("IMDS does not support NC details API") + return make(map[string]string), nil + } + + // Get all network interfaces from IMDS + networkInterfaces, err := imdsClient.GetNetworkInterfaces(ctx) + if err != nil { + //nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Errorf("Failed to get network interfaces from IMDS: %v", err) + return make(map[string]string), nil + } + + // Build ncs map from the network interfaces + ncs := make(map[string]string) + for _, iface := range networkInterfaces { + //nolint:staticcheck // SA1019: suppress deprecated logger.Debugf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated + logger.Debugf("Nc id: %s and mac address: %s from IMDS call", iface.InterfaceCompartmentID, iface.MacAddress.String()) + // IMDS returns interfaceCompartmentID, as nc id guid has different context on nma. We map these to NC ID + ncID := iface.InterfaceCompartmentID + + if ncID != "" { + ncs[ncID] = PrefixOnNicNCVersion // for prefix on nic version scenario nc version is 1 + } + } + + return ncs, nil +} diff --git a/cns/restserver/internalapi_test.go b/cns/restserver/internalapi_test.go index 10b1852c63..53c26fe0ad 100644 --- a/cns/restserver/internalapi_test.go +++ b/cns/restserver/internalapi_test.go @@ -11,6 +11,7 @@ import ( "os" "reflect" "strconv" + "strings" "sync" "testing" "time" @@ -19,6 +20,7 @@ import ( "github.com/Azure/azure-container-networking/cns/common" "github.com/Azure/azure-container-networking/cns/configuration" "github.com/Azure/azure-container-networking/cns/fakes" + "github.com/Azure/azure-container-networking/cns/imds" "github.com/Azure/azure-container-networking/cns/types" "github.com/Azure/azure-container-networking/crd/nodenetworkconfig/api/v1alpha" nma "github.com/Azure/azure-container-networking/nmagent" @@ -42,6 +44,7 @@ const ( batchSize = 10 initPoolSize = 10 ncID = "6a07155a-32d7-49af-872f-1e70ee366dc0" + imdsNCID = "6a07155a-32d7-49af-872f-1e70ee36imds" ) var dnsservers = []string{"8.8.8.8", "8.8.4.4"} @@ -227,7 +230,6 @@ func TestSyncHostNCVersion(t *testing.T) { // cns.KubernetesCRD has one more logic compared to other orchestrator type, so test both of them orchestratorTypes := []string{cns.Kubernetes, cns.KubernetesCRD} for _, orchestratorType := range orchestratorTypes { - orchestratorType := orchestratorType t.Run(orchestratorType, func(t *testing.T) { req := createNCReqeustForSyncHostNCVersion(t) containerStatus := svc.state.ContainerStatus[req.NetworkContainerid] @@ -253,8 +255,22 @@ func TestSyncHostNCVersion(t *testing.T) { cleanup := setMockNMAgent(svc, mnma) defer cleanup() - // When syncing the host NC version, it will use the orchestratorType passed - // in. + // Add the IMDS NC to the CNS state + svc.state.ContainerStatus[imdsNCID] = containerstatus{ + ID: imdsNCID, + VMVersion: "0", + HostVersion: "-1", + CreateNetworkContainerRequest: cns.CreateNetworkContainerRequest{ + NetworkContainerid: imdsNCID, + Version: "1", + }, + } + + // Setup IMDS mock with version support + cleanupIMDS := setupIMDSMockAPIsWithCustomIDs(svc, []string{imdsNCID, "nc2"}) + defer cleanupIMDS() + + // When syncing the host NC version, it will use the orchestratorType passed in. svc.SyncHostNCVersion(context.Background(), orchestratorType) containerStatus = svc.state.ContainerStatus[req.NetworkContainerid] if containerStatus.HostVersion != "0" { @@ -263,6 +279,15 @@ func TestSyncHostNCVersion(t *testing.T) { if containerStatus.CreateNetworkContainerRequest.Version != "0" { t.Errorf("Unexpected nc version in containerStatus as %s, expected VM version should be 0 in string", containerStatus.CreateNetworkContainerRequest.Version) } + + // Validate the second NC from IMDS - it should now use the default version since IMDS doesn't returns version + imdsContainerStatus := svc.state.ContainerStatus[imdsNCID] + if imdsContainerStatus.HostVersion != "1" { // Changed from "0" to "1" since we use default version for IMDS NCs + t.Errorf("Unexpected imdsContainerStatus.HostVersion %s, expected host version should be 1 in string", imdsContainerStatus.HostVersion) + } + if imdsContainerStatus.CreateNetworkContainerRequest.Version != "1" { + t.Errorf("Unexpected version %s, expected VM version should remain 1 in string", imdsContainerStatus.CreateNetworkContainerRequest.Version) + } }) } } @@ -312,6 +337,212 @@ func TestPendingIPsGotUpdatedWhenSyncHostNCVersion(t *testing.T) { } } +func TestSyncHostNCVersionErrorMissingNC(t *testing.T) { + req := createNCReqeustForSyncHostNCVersion(t) + + svc.Lock() + ncStatus := svc.state.ContainerStatus[req.NetworkContainerid] + ncStatus.CreateNetworkContainerRequest.Version = "2" + ncStatus.HostVersion = "0" + svc.state.ContainerStatus[req.NetworkContainerid] = ncStatus + svc.Unlock() + + // Setup IMDS mock with different interface IDs (not matching the outdated NC) + cleanupIMDS := setupIMDSMockAPIsWithCustomIDs(svc, []string{"different-nc-id-1", "different-nc-id-2"}) + defer cleanupIMDS() + + // NMAgent returns empty + mnma := &fakes.NMAgentClientFake{ + GetNCVersionListF: func(_ context.Context) (nma.NCVersionList, error) { + return nma.NCVersionList{ + Containers: []nma.NCVersion{}, + }, nil + }, + } + cleanup := setMockNMAgent(svc, mnma) + defer cleanup() + + _, err := svc.syncHostNCVersion(context.Background(), cns.KubernetesCRD) + if err == nil { + t.Errorf("Expected error when NC is missing from both NMAgent and IMDS, but got nil") + } + + // Check that the error message contains the expected text + expectedErrorText := "unable to update some NCs" + if !strings.Contains(err.Error(), expectedErrorText) { + t.Errorf("Expected error to contain '%s', but got: %v", expectedErrorText, err) + } + + // Verify that the NC HostVersion was not updated, should still be 0 + containerStatus := svc.state.ContainerStatus[req.NetworkContainerid] + if containerStatus.HostVersion != "0" { + t.Errorf("Expected HostVersion to remain 0, but got %s", containerStatus.HostVersion) + } +} + +func TestSyncHostNCVersionLocalVersionHigher(t *testing.T) { + // Test scenario where local NC version is higher than consolidated NC version from IMDS + // This should trigger the "NC version from consolidated sources is decreasing" error + req := createNCReqeustForSyncHostNCVersion(t) + + svc.Lock() + ncStatus := svc.state.ContainerStatus[req.NetworkContainerid] + ncStatus.CreateNetworkContainerRequest.Version = "1" // DNC version is 1 + ncStatus.HostVersion = "3" // But local host version is 3 + svc.state.ContainerStatus[req.NetworkContainerid] = ncStatus + svc.Unlock() + + // Create IMDS mock that returns lower version(1) than local host version(3) + // Setup IMDS mock with version support + cleanupIMDS := setupIMDSMockAPIsWithCustomIDs(svc, []string{imdsNCID, "nc2"}) + defer cleanupIMDS() + + mnma := &fakes.NMAgentClientFake{ + GetNCVersionListF: func(_ context.Context) (nma.NCVersionList, error) { + return nma.NCVersionList{ + Containers: []nma.NCVersion{}, + }, nil + }, + } + cleanup := setMockNMAgent(svc, mnma) + defer cleanup() + + _, err := svc.syncHostNCVersion(context.Background(), cns.KubernetesCRD) + if err != nil { + t.Errorf("Expected sync to succeed, but got error: %v", err) + } + + // Verify that the NC HostVersion was NOT updated (should remain "3") + containerStatus := svc.state.ContainerStatus[req.NetworkContainerid] + if containerStatus.HostVersion != "3" { + t.Errorf("Expected HostVersion to remain 3 (unchanged due to decreasing version), got %s", containerStatus.HostVersion) + } + + t.Logf("Successfully handled decreasing version scenario: local=%s, consolidated=%s", + containerStatus.HostVersion, "2") +} + +func TestSyncHostNCVersionLocalHigherThanDNC(t *testing.T) { + // Test scenario where localNCVersion > dncNCVersion + // This should trigger an error log: "NC version from NMAgent is larger than DNC" + req := createNCReqeustForSyncHostNCVersion(t) + + // Set up the NC state where HostVersion (localNCVersion) > DNC NC Version + svc.Lock() + ncStatus := svc.state.ContainerStatus[req.NetworkContainerid] + ncStatus.CreateNetworkContainerRequest.Version = "1" + ncStatus.HostVersion = "3" + svc.state.ContainerStatus[req.NetworkContainerid] = ncStatus + svc.Unlock() + + cleanupIMDS := setupIMDSMockAPIsWithCustomIDs(svc, []string{imdsNCID, "nc2"}) + defer cleanupIMDS() + + mnma := &fakes.NMAgentClientFake{ + GetNCVersionListF: func(_ context.Context) (nma.NCVersionList, error) { + return nma.NCVersionList{ + Containers: []nma.NCVersion{}, // Empty + }, nil + }, + } + cleanup := setMockNMAgent(svc, mnma) + defer cleanup() + + // This should detect that localNCVersion (3) > dncNCVersion (1) and log error + // but since there are no outdated NCs, it should return successfully + _, err := svc.syncHostNCVersion(context.Background(), cns.KubernetesCRD) + if err != nil { + t.Errorf("Expected no error when localNCVersion > dncNCVersion (no outdated NCs), but got: %v", err) + } + + containerStatus := svc.state.ContainerStatus[req.NetworkContainerid] + if containerStatus.HostVersion != "3" { + t.Errorf("Expected HostVersion to remain 3 (unchanged), got %s", containerStatus.HostVersion) + } + + // Verify that the DNC version remains unchanged, should still be 1 + if containerStatus.CreateNetworkContainerRequest.Version != "1" { + t.Errorf("Expected DNC version to remain 1 (unchanged), got %s", containerStatus.CreateNetworkContainerRequest.Version) + } + + t.Logf("Successfully handled localNCVersion > dncNCVersion scenario: local=%s, dnc=%s", + containerStatus.HostVersion, containerStatus.CreateNetworkContainerRequest.Version) +} + +func TestSyncHostNCVersionIMDSAPIVersionNotSupported(t *testing.T) { + orchestratorTypes := []string{cns.Kubernetes, cns.KubernetesCRD} + for _, orchestratorType := range orchestratorTypes { + t.Run(orchestratorType, func(t *testing.T) { + req := createNCReqeustForSyncHostNCVersion(t) + + // Make the NMA NC up-to-date so it doesn't interfere with the test + svc.Lock() + nmaNCStatus := svc.state.ContainerStatus[req.NetworkContainerid] + nmaNCStatus.HostVersion = "0" // Same as Version "0", so not outdated + svc.state.ContainerStatus[req.NetworkContainerid] = nmaNCStatus + svc.Unlock() + + // NMAgent mock - not important for this test, just needs to not interfere + mnma := &fakes.NMAgentClientFake{ + GetNCVersionListF: func(_ context.Context) (nma.NCVersionList, error) { + return nma.NCVersionList{Containers: []nma.NCVersion{}}, nil + }, + } + cleanup := setMockNMAgent(svc, mnma) + defer cleanup() + + // Add IMDS NC that is outdated - this is the focus of the test + svc.state.ContainerStatus[imdsNCID] = containerstatus{ + ID: imdsNCID, + VMVersion: "0", + HostVersion: "-1", // Outdated + CreateNetworkContainerRequest: cns.CreateNetworkContainerRequest{ + NetworkContainerid: imdsNCID, + Version: "1", // Higher than HostVersion + }, + } + + // Setup IMDS mock that returns API versions WITHOUT the expected version "2025-07-24" + mockIMDS := &struct { + networkInterfaces func(_ context.Context) ([]imds.NetworkInterface, error) + imdsVersions func(_ context.Context) (*imds.APIVersionsResponse, error) + }{ + networkInterfaces: func(_ context.Context) ([]imds.NetworkInterface, error) { + return []imds.NetworkInterface{ + {InterfaceCompartmentID: imdsNCID}, + }, nil + }, + imdsVersions: func(_ context.Context) (*imds.APIVersionsResponse, error) { + return &imds.APIVersionsResponse{ + APIVersions: []string{"2017-03-01", "2021-01-01"}, // Missing "2025-07-24" + }, nil + }, + } + originalIMDS := svc.imdsClient + svc.imdsClient = &mockIMDSAdapter{mockIMDS} + defer func() { svc.imdsClient = originalIMDS }() + + // Test should fail because of outdated IMDS NC that can't be updated + _, err := svc.syncHostNCVersion(context.Background(), orchestratorType) + if err == nil { + t.Errorf("Expected error when there are outdated IMDS NCs but API version is not supported, but got nil") + } + + // Verify the error is about being unable to update NCs + expectedErrorText := "unable to update some NCs" + if !strings.Contains(err.Error(), expectedErrorText) { + t.Errorf("Expected error to contain '%s', but got: %v", expectedErrorText, err) + } + + // Only verify IMDS NC state - this is the focus of the test + imdsContainerStatus := svc.state.ContainerStatus[imdsNCID] + if imdsContainerStatus.HostVersion != "-1" { + t.Errorf("Expected IMDS NC HostVersion to remain -1, got %s", imdsContainerStatus.HostVersion) + } + }) + } +} + func createNCReqeustForSyncHostNCVersion(t *testing.T) cns.CreateNetworkContainerRequest { restartService() setEnv(t) @@ -1111,7 +1342,11 @@ func TestCNIConflistGenerationNewNC(t *testing.T) { }, }, nil }, + SupportedAPIsF: func(_ context.Context) ([]string, error) { + return []string{"EnableSwiftV2NCGoalStateSupport", "OtherAPI"}, nil + }, }, + imdsClient: fakes.NewMockIMDSClient(), } service.SyncHostNCVersion(context.Background(), cns.CRD) @@ -1137,6 +1372,22 @@ func TestCNIConflistGenerationExistingNC(t *testing.T) { }, }, }, + nma: &fakes.NMAgentClientFake{ + GetNCVersionListF: func(_ context.Context) (nma.NCVersionList, error) { + return nma.NCVersionList{ + Containers: []nma.NCVersion{ + { + NetworkContainerID: ncID, + Version: "0", + }, + }, + }, nil + }, + SupportedAPIsF: func(_ context.Context) ([]string, error) { + return []string{}, nil + }, + }, + imdsClient: fakes.NewMockIMDSClient(), } service.SyncHostNCVersion(context.Background(), cns.CRD) @@ -1174,7 +1425,11 @@ func TestCNIConflistGenerationNewNCTwice(t *testing.T) { }, }, nil }, + SupportedAPIsF: func(_ context.Context) ([]string, error) { + return []string{}, nil + }, }, + imdsClient: fakes.NewMockIMDSClient(), } service.SyncHostNCVersion(context.Background(), cns.CRD) @@ -1209,7 +1464,11 @@ func TestCNIConflistNotGenerated(t *testing.T) { GetNCVersionListF: func(_ context.Context) (nma.NCVersionList, error) { return nma.NCVersionList{}, nil }, + SupportedAPIsF: func(_ context.Context) ([]string, error) { + return []string{"EnableSwiftV2NCGoalStateSupport", "OtherAPI"}, nil + }, }, + imdsClient: fakes.NewMockIMDSClient(), } service.SyncHostNCVersion(context.Background(), cns.CRD) @@ -1248,7 +1507,11 @@ func TestCNIConflistGenerationOnNMAError(t *testing.T) { GetNCVersionListF: func(_ context.Context) (nma.NCVersionList, error) { return nma.NCVersionList{}, errors.New("some nma error") }, + SupportedAPIsF: func(_ context.Context) ([]string, error) { + return []string{"EnableSwiftV2NCGoalStateSupport", "OtherAPI"}, nil + }, }, + imdsClient: fakes.NewMockIMDSClient(), } service.SyncHostNCVersion(context.Background(), cns.CRD) @@ -1406,3 +1669,61 @@ func TestMustEnsureNoStaleNCs_PanicsWhenIPsFromStaleNCAreAssigned(t *testing.T) svc.MustEnsureNoStaleNCs([]string{"nc3", "nc4"}) }) } + +type mockIMDSAdapter struct { + mock *struct { + networkInterfaces func(_ context.Context) ([]imds.NetworkInterface, error) + imdsVersions func(_ context.Context) (*imds.APIVersionsResponse, error) + } +} + +func (m *mockIMDSAdapter) GetVMUniqueID(_ context.Context) (string, error) { + panic("GetVMUniqueID should not be called in syncHostNCVersion tests, adding mockIMDSAdapter implements the full IMDS interface") +} + +func (m *mockIMDSAdapter) GetNetworkInterfaces(ctx context.Context) ([]imds.NetworkInterface, error) { + return m.mock.networkInterfaces(ctx) +} + +func (m *mockIMDSAdapter) GetIMDSVersions(ctx context.Context) (*imds.APIVersionsResponse, error) { + if m.mock.imdsVersions != nil { + return m.mock.imdsVersions(ctx) + } + // Default implementation that returns supported API versions + return &imds.APIVersionsResponse{ + APIVersions: []string{"2017-03-01", "2021-01-01", "2025-07-24"}, + }, nil +} + +func setupIMDSMockAPIsWithCustomIDs(svc *HTTPRestService, interfaceIDs []string) func() { + mockIMDS := &struct { + networkInterfaces func(_ context.Context) ([]imds.NetworkInterface, error) + imdsVersions func(_ context.Context) (*imds.APIVersionsResponse, error) + }{ + networkInterfaces: func(_ context.Context) ([]imds.NetworkInterface, error) { + var interfaces []imds.NetworkInterface + for _, id := range interfaceIDs { + interfaces = append(interfaces, imds.NetworkInterface{ + InterfaceCompartmentID: id, + }) + } + return interfaces, nil + }, + imdsVersions: func(_ context.Context) (*imds.APIVersionsResponse, error) { + return &imds.APIVersionsResponse{ + APIVersions: []string{ + "2017-03-01", + "2021-01-01", + "2025-07-24", + }, + }, nil + }, + } + + // Set up the mock + originalIMDS := svc.imdsClient + svc.imdsClient = &mockIMDSAdapter{mockIMDS} + + // Return cleanup function + return func() { svc.imdsClient = originalIMDS } +} diff --git a/cns/restserver/restserver.go b/cns/restserver/restserver.go index c467ab04e2..23bbffa95c 100644 --- a/cns/restserver/restserver.go +++ b/cns/restserver/restserver.go @@ -11,6 +11,7 @@ import ( "github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/cns/common" "github.com/Azure/azure-container-networking/cns/dockerclient" + "github.com/Azure/azure-container-networking/cns/imds" "github.com/Azure/azure-container-networking/cns/logger" "github.com/Azure/azure-container-networking/cns/networkcontainers" "github.com/Azure/azure-container-networking/cns/nodesubnet" @@ -52,6 +53,8 @@ type wireserverProxy interface { type imdsClient interface { GetVMUniqueID(ctx context.Context) (string, error) + GetNetworkInterfaces(ctx context.Context) ([]imds.NetworkInterface, error) + GetIMDSVersions(ctx context.Context) (*imds.APIVersionsResponse, error) } type iptablesClient interface {