Skip to content
18 changes: 18 additions & 0 deletions cns/fakes/imdsclientfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,21 @@ func (m *MockIMDSClient) GetVMUniqueID(ctx context.Context) (string, error) {

return "55b8499d-9b42-4f85-843f-24ff69f4a643", nil
}

func (m *MockIMDSClient) GetNCVersions(ctx context.Context) ([]imds.NetworkInterface, error) {
if ctx.Value(SimulateError) != nil {
return nil, imds.ErrUnexpectedStatusCode
}

// Return some mock NC versions for testing
return []imds.NetworkInterface{
{
InterfaceCompartmentID: "nc1",
InterfaceCompartmentVersion: "1",
},
{
InterfaceCompartmentID: "nc2",
InterfaceCompartmentVersion: "2",
},
}, nil
}
46 changes: 43 additions & 3 deletions cns/imds/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func RetryAttempts(attempts uint) ClientOption {
const (
vmUniqueIDProperty = "vmId"
imdsComputePath = "/metadata/instance/compute"
imdsNetworkPath = "/metadata/instance/network"
imdsComputeAPIVersion = "api-version=2021-01-01"
imdsFormatJSON = "format=json"
metadataHeaderKey = "Metadata"
Expand Down Expand Up @@ -79,7 +80,7 @@ func NewClient(opts ...ClientOption) *Client {
func (c *Client) GetVMUniqueID(ctx context.Context) (string, error) {
var vmUniqueID string
err := retry.Do(func() error {
computeDoc, err := c.getInstanceComputeMetadata(ctx)
computeDoc, err := c.getInstanceMetadata(ctx, imdsComputePath)
if err != nil {
return errors.Wrap(err, "error getting IMDS compute metadata")
}
Expand All @@ -102,10 +103,37 @@ func (c *Client) GetVMUniqueID(ctx context.Context) (string, error) {
return vmUniqueID, nil
}

func (c *Client) getInstanceComputeMetadata(ctx context.Context) (map[string]any, error) {
func (c *Client) GetNCVersions(ctx context.Context) ([]NetworkInterface, error) {
var networkData NetworkMetadata
err := retry.Do(func() error {
networkMetadata, err := c.getInstanceMetadata(ctx, imdsNetworkPath)
if err != nil {
return errors.Wrap(err, "error getting IMDS network metadata")
}

// Try to parse the network metadata as the expected structure
// Convert the map to JSON and back to properly unmarshal into struct
jsonData, err := json.Marshal(networkMetadata)
if err != nil {
return errors.Wrap(err, "error marshaling network metadata")
}

if err := json.Unmarshal(jsonData, &networkData); err != nil {
return errors.Wrap(err, "error unmarshaling network metadata")
}
return nil
}, retry.Context(ctx), retry.Attempts(c.config.retryAttempts), retry.DelayType(retry.BackOffDelay))
if err != nil {
return nil, errors.Wrap(err, "external call failed")
}

return networkData.Interface, nil
}

func (c *Client) getInstanceMetadata(ctx context.Context, imdsComputePath string) (map[string]any, error) {
imdsComputeURL, err := url.JoinPath(c.config.endpoint, imdsComputePath)
if err != nil {
return nil, errors.Wrap(err, "unable to build path to IMDS compute metadata")
return nil, errors.Wrap(err, "unable to build path to IMDS metadata for path"+imdsComputePath)
}
imdsComputeURL = imdsComputeURL + "?" + imdsComputeAPIVersion + "&" + imdsFormatJSON

Expand Down Expand Up @@ -133,3 +161,15 @@ func (c *Client) getInstanceComputeMetadata(ctx context.Context) (map[string]any

return m, nil
}

// NetworkInterface represents a network interface from IMDS
type NetworkInterface struct {
// IMDS only returns compartment fields - these are mapped to NC ID and NC version concepts
InterfaceCompartmentID string `json:"interfaceCompartmentID,omitempty"`
InterfaceCompartmentVersion string `json:"interfaceCompartmentVersion,omitempty"`
}

// NetworkMetadata represents the network metadata from IMDS
type NetworkMetadata struct {
Interface []NetworkInterface `json:"interface"`
}
115 changes: 115 additions & 0 deletions cns/imds/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,118 @@ func TestInvalidVMUniqueID(t *testing.T) {
require.Error(t, err, "error querying testserver")
require.Equal(t, "", vmUniqueID)
}

func TestGetNCVersions(t *testing.T) {
networkMetadata := []byte(`{
"interface": [
{
"interfaceCompartmentVersion": "1",
"interfaceCompartmentID": "nc-12345-67890"
},
{
"interfaceCompartmentVersion": "1",
"interfaceCompartmentID": ""
}
]
}`)

mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// request header "Metadata: true" must be present
metadataHeader := r.Header.Get("Metadata")
assert.Equal(t, "true", metadataHeader)

// verify path is network metadata
assert.Contains(t, r.URL.Path, "/metadata/instance/network")

// query params should include apiversion and json format
apiVersion := r.URL.Query().Get("api-version")
assert.Equal(t, "2021-01-01", apiVersion)
format := r.URL.Query().Get("format")
assert.Equal(t, "json", format)

w.WriteHeader(http.StatusOK)
_, writeErr := w.Write(networkMetadata)
if writeErr != nil {
t.Errorf("error writing response: %v", writeErr)
return
}
}))
defer mockIMDSServer.Close()

imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL))
interfaces, err := imdsClient.GetNCVersions(context.Background())
require.NoError(t, err, "error querying testserver")

// Verify we got the expected interfaces
require.Len(t, interfaces, 2, "expected 2 interfaces")

// Check first interface
assert.Equal(t, "nc-12345-67890", interfaces[0].InterfaceCompartmentID)
assert.Equal(t, "1", interfaces[0].InterfaceCompartmentVersion)

// Check second interface
assert.Equal(t, "", interfaces[1].InterfaceCompartmentID)
assert.Equal(t, "1", interfaces[1].InterfaceCompartmentVersion)
}

func TestGetNCVersionsInvalidEndpoint(t *testing.T) {
imdsClient := imds.NewClient(imds.Endpoint(string([]byte{0x7f})), imds.RetryAttempts(1))
_, err := imdsClient.GetNCVersions(context.Background())
require.Error(t, err, "expected invalid path")
}

func TestGetNCVersionsInvalidJSON(t *testing.T) {
mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("not json"))
if err != nil {
t.Errorf("error writing response: %v", err)
return
}
}))
defer mockIMDSServer.Close()

imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL), imds.RetryAttempts(1))
_, err := imdsClient.GetNCVersions(context.Background())
require.Error(t, err, "expected json decoding error")
}

func TestGetNCVersionsNoNCIDs(t *testing.T) {
networkMetadataNoNC := []byte(`{
"interface": [
{
"ipv4": {
"ipAddress": [
{
"privateIpAddress": "10.0.0.4",
"publicIpAddress": ""
}
]
}
}
]
}`)

mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
metadataHeader := r.Header.Get("Metadata")
assert.Equal(t, "true", metadataHeader)

w.WriteHeader(http.StatusOK)
_, writeErr := w.Write(networkMetadataNoNC)
if writeErr != nil {
t.Errorf("error writing response: %v", writeErr)
return
}
}))
defer mockIMDSServer.Close()

imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL))
interfaces, err := imdsClient.GetNCVersions(context.Background())
require.NoError(t, err, "error querying testserver")

// Verify we got interfaces but they don't have compartment IDs
require.Len(t, interfaces, 1, "expected 1 interface")

// Check that interfaces don't have compartment IDs
assert.Equal(t, "", interfaces[0].InterfaceCompartmentID)
}
107 changes: 94 additions & 13 deletions cns/restserver/internalapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ import (
"github.com/pkg/errors"
)

const (
// Known API names we care about
nmAgentSwiftV2API = "SwiftV2DhcpRehydrationFromGoalState"
)

// This file contains the internal functions called by either HTTP APIs (api.go) or
// internal APIs (definde in internalapi.go).
// This will be used internally (say by RequestController in case of AKS)
Expand Down Expand Up @@ -221,18 +226,36 @@ func (service *HTTPRestService) syncHostNCVersion(ctx context.Context, channelMo
return len(programmedNCs), errors.Wrap(err, "failed to get nc version list from nmagent")
}

// Get IMDS NC versions for delegated NIC scenarios
imdsNCVersions, err := service.GetIMDSNCVersions(ctx)
if err != nil {
// If any of the NMA API check calls, imds calls fails assume that nma build doesn't have the latest changes and create empty map
imdsNCVersions = make(map[string]string)
}

nmaNCs := map[string]string{}
for _, nc := range ncVersionListResp.Containers {
nmaNCs[strings.ToLower(nc.NetworkContainerID)] = nc.Version
}
hasNC.Set(float64(len(nmaNCs)))

// Consolidate both maps - NMA takes precedence, IMDS as fallback
consolidatedNCs := make(map[string]string)
for ncID, version := range nmaNCs {
consolidatedNCs[ncID] = version
}
for ncID, version := range imdsNCVersions {
if _, exists := consolidatedNCs[ncID]; !exists {
consolidatedNCs[strings.ToLower(ncID)] = version
}
}
hasNC.Set(float64(len(consolidatedNCs)))
for ncID := range outdatedNCs {
nmaNCVersionStr, ok := nmaNCs[ncID]
consolidatedNCVersionStr, ok := consolidatedNCs[ncID]
if !ok {
// NMA doesn't have this NC that we need programmed yet, bail out
// Neither NMA nor IMDS has this NC that we need programmed yet, bail out
continue
}
nmaNCVersion, err := strconv.Atoi(nmaNCVersionStr)
consolidatedNCVersion, err := strconv.Atoi(consolidatedNCVersionStr)
if err != nil {
logger.Errorf("failed to parse container version of %s: %s", ncID, err)
continue
Expand All @@ -245,7 +268,7 @@ func (service *HTTPRestService) syncHostNCVersion(ctx context.Context, channelMo
return len(programmedNCs), errors.Wrapf(errNonExistentContainerStatus, "can't find NC with ID %s in service state, stop updating this host NC version", ncID)
}
// if the NC still exists in state and is programmed to some version (doesn't have to be latest), add it to our set of NCs that have been programmed
if nmaNCVersion > -1 {
if consolidatedNCVersion > -1 {
programmedNCs[ncID] = struct{}{}
}

Expand All @@ -254,15 +277,17 @@ func (service *HTTPRestService) syncHostNCVersion(ctx context.Context, channelMo
logger.Errorf("failed to parse host nc version string %s: %s", ncInfo.HostVersion, err)
continue
}
if localNCVersion > nmaNCVersion {
logger.Errorf("NC version from NMA is decreasing: have %d, got %d", localNCVersion, nmaNCVersion)
if localNCVersion > consolidatedNCVersion {
//nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated
logger.Errorf("NC version from consolidated sources is decreasing: have %d, got %d", localNCVersion, consolidatedNCVersion)
continue
}
if channelMode == cns.CRD {
service.MarkIpsAsAvailableUntransacted(ncInfo.ID, nmaNCVersion)
service.MarkIpsAsAvailableUntransacted(ncInfo.ID, consolidatedNCVersion)
}
logger.Printf("Updating NC %s host version from %s to %s", ncID, ncInfo.HostVersion, nmaNCVersionStr)
ncInfo.HostVersion = nmaNCVersionStr
//nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated
logger.Printf("Updating NC %s host version from %s to %s", ncID, ncInfo.HostVersion, consolidatedNCVersionStr)
ncInfo.HostVersion = consolidatedNCVersionStr
logger.Printf("Updated NC %s host version to %s", ncID, ncInfo.HostVersion)
service.state.ContainerStatus[ncID] = ncInfo
// if we successfully updated the NC, pop it from the needs update set.
Expand All @@ -271,7 +296,7 @@ func (service *HTTPRestService) syncHostNCVersion(ctx context.Context, channelMo
// if we didn't empty out the needs update set, NMA has not programmed all the NCs we are expecting, and we
// need to return an error indicating that
if len(outdatedNCs) > 0 {
return len(programmedNCs), errors.Errorf("unabled to update some NCs: %v, missing or bad response from NMA", outdatedNCs)
return len(programmedNCs), errors.Errorf("unable to update some NCs: %v, missing or bad response from NMA or IMDS", outdatedNCs)
}

return len(programmedNCs), nil
Expand Down Expand Up @@ -332,7 +357,6 @@ func (service *HTTPRestService) ReconcileIPAssignment(podInfoByIP map[string]cns

for _, ip := range ips {
if ncReq, ok := allSecIPsIdx[ip.String()]; ok {
logger.Printf("secondary ip %s is assigned to pod %+v, ncId: %s ncVersion: %s", ip, podIPs, ncReq.NetworkContainerid, ncReq.Version)
desiredIPs = append(desiredIPs, ip.String())
ncIDs = append(ncIDs, ncReq.NetworkContainerid)
} else {
Expand Down Expand Up @@ -361,7 +385,8 @@ func (service *HTTPRestService) ReconcileIPAssignment(podInfoByIP map[string]cns
}

if _, err := requestIPConfigsHelper(service, ipconfigsRequest); err != nil {
logger.Errorf("requestIPConfigsHelper failed for pod key %s, podInfo %+v, ncIds %v, error: %v", podKey, podIPs, ncIDs, err)
//nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated
logger.Errorf("requestIPConfigsHelper failed for pod key %s, podInfo %+v, ncIDs %v, error: %v", podKey, podIPs, ncIDs, err)
return types.FailedToAllocateIPConfig
}
}
Expand Down Expand Up @@ -634,3 +659,59 @@ func (service *HTTPRestService) CreateOrUpdateNetworkContainerInternal(req *cns.
func (service *HTTPRestService) SetVFForAccelnetNICs() error {
return service.setVFForAccelnetNICs()
}

// checkNMAgentAPISupport checks if specific APIs are supported by NMAgent using the existing client
func (service *HTTPRestService) checkNMAgentAPISupport(ctx context.Context) (swiftV2Support bool, err error) {
// Use the existing NMAgent client instead of direct HTTP calls
if service.nma == nil {
return false, errors.New("NMAgent client is not available")
}

apis, err := service.nma.SupportedAPIs(ctx)
if err != nil {
return false, errors.New("failed to get supported APIs from NMAgent client")
}

for _, api := range apis {
if strings.Contains(api, nmAgentSwiftV2API) {
swiftV2Support = true
}
}

return swiftV2Support, nil
}

// GetIMDSNCVersions gets NC versions from IMDS and returns them as a map
func (service *HTTPRestService) GetIMDSNCVersions(ctx context.Context) (map[string]string, error) {
// Check NMAgent API support for SwiftV2, if it fails return empty map assuming support might not be available in that nma build
swiftV2Support, err := service.checkNMAgentAPISupport(ctx)
if err != nil || !swiftV2Support {
//nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated
logger.Errorf("NMAgent does not support SwiftV2 API or encountered an error: %v", err)
return make(map[string]string), nil
}

imdsClient := service.imdsClient

// Get all NC versions from IMDS
networkInterfaces, err := imdsClient.GetNCVersions(ctx)
if err != nil {
//nolint:staticcheck // SA1019: suppress deprecated logger.Printf usage. Todo: legacy logger usage is consistent in cns repo. Migrates when all logger usage is migrated
logger.Errorf("Failed to get NC versions from IMDS: %v", err)
return make(map[string]string), nil
}

// Build ncVersions map from the network interfaces
ncVersions := make(map[string]string)
for _, iface := range networkInterfaces {
// IMDS returns interfaceCompartmentID, interfaceCompartmentVersion fields, as nc id guid has different context on nma. We map these to NC ID and NC version
ncID := iface.InterfaceCompartmentID
ncVersion := iface.InterfaceCompartmentVersion

if ncID != "" {
ncVersions[ncID] = ncVersion
}
}

return ncVersions, nil
}
Loading
Loading