Skip to content
38 changes: 38 additions & 0 deletions cns/fakes/imdsclientfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package fakes

import (
"context"
"net"

"github.com/Azure/azure-container-networking/cns/imds"
"github.com/Azure/azure-container-networking/cns/wireserver"
Expand Down Expand Up @@ -57,3 +58,40 @@ func (m *MockIMDSClient) GetVMUniqueID(ctx context.Context) (string, error) {

return "55b8499d-9b42-4f85-843f-24ff69f4a643", nil
}

func (m *MockIMDSClient) GetNetworkInterfaces(ctx context.Context) ([]imds.NetworkInterface, error) {
if ctx.Value(SimulateError) != nil {
return nil, imds.ErrUnexpectedStatusCode
}

// Parse MAC addresses for testing
macAddr1, _ := net.ParseMAC("00:15:5d:01:02:01")
macAddr2, _ := net.ParseMAC("00:15:5d:01:02:02")

// Return some mock network interfaces for testing
return []imds.NetworkInterface{
{
InterfaceCompartmentID: "nc1",
MacAddress: imds.HardwareAddr(macAddr1),
},
{
InterfaceCompartmentID: "nc2",
MacAddress: imds.HardwareAddr(macAddr2),
},
}, nil
}

func (m *MockIMDSClient) GetIMDSVersions(ctx context.Context) (*imds.APIVersionsResponse, error) {
if ctx.Value(SimulateError) != nil {
return nil, imds.ErrUnexpectedStatusCode
}

// Return supported API versions including the expected one
return &imds.APIVersionsResponse{
APIVersions: []string{
"2017-03-01",
"2021-01-01",
"2025-07-24",
},
}, nil
}
127 changes: 120 additions & 7 deletions cns/imds/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package imds
import (
"context"
"encoding/json"
"net"
"net/http"
"net/url"

Expand Down Expand Up @@ -46,7 +47,10 @@ func RetryAttempts(attempts uint) ClientOption {
const (
vmUniqueIDProperty = "vmId"
imdsComputePath = "/metadata/instance/compute"
imdsComputeAPIVersion = "api-version=2021-01-01"
imdsNetworkPath = "/metadata/instance/network"
imdsVersionsPath = "/metadata/versions"
imdsDefaultAPIVersion = "api-version=2021-01-01"
imdsNCDetailsVersion = "api-version=2025-07-24"
imdsFormatJSON = "format=json"
metadataHeaderKey = "Metadata"
metadataHeaderValue = "true"
Expand Down Expand Up @@ -79,7 +83,7 @@ func NewClient(opts ...ClientOption) *Client {
func (c *Client) GetVMUniqueID(ctx context.Context) (string, error) {
var vmUniqueID string
err := retry.Do(func() error {
computeDoc, err := c.getInstanceComputeMetadata(ctx)
computeDoc, err := c.getInstanceMetadata(ctx, imdsComputePath, imdsDefaultAPIVersion)
if err != nil {
return errors.Wrap(err, "error getting IMDS compute metadata")
}
Expand All @@ -102,14 +106,40 @@ func (c *Client) GetVMUniqueID(ctx context.Context) (string, error) {
return vmUniqueID, nil
}

func (c *Client) getInstanceComputeMetadata(ctx context.Context) (map[string]any, error) {
imdsComputeURL, err := url.JoinPath(c.config.endpoint, imdsComputePath)
func (c *Client) GetNetworkInterfaces(ctx context.Context) ([]NetworkInterface, error) {
var networkData NetworkInterfaces
err := retry.Do(func() error {
networkInterfaces, err := c.getInstanceMetadata(ctx, imdsNetworkPath, imdsNCDetailsVersion)
if err != nil {
return errors.Wrap(err, "error getting IMDS network metadata")
}

// Parse the network metadata to the expected structure
jsonData, err := json.Marshal(networkInterfaces)
if err != nil {
return errors.Wrap(err, "error marshaling network metadata")
}

if err := json.Unmarshal(jsonData, &networkData); err != nil {
return errors.Wrap(err, "error unmarshaling network metadata")
}
return nil
}, retry.Context(ctx), retry.Attempts(c.config.retryAttempts), retry.DelayType(retry.BackOffDelay))
if err != nil {
return nil, errors.Wrap(err, "unable to build path to IMDS compute metadata")
return nil, errors.Wrap(err, "external call failed")
}
imdsComputeURL = imdsComputeURL + "?" + imdsComputeAPIVersion + "&" + imdsFormatJSON

req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsComputeURL, http.NoBody)
return networkData.Interface, nil
}

func (c *Client) getInstanceMetadata(ctx context.Context, imdsMetadataPath, imdsAPIVersion string) (map[string]any, error) {
imdsRequestURL, err := url.JoinPath(c.config.endpoint, imdsMetadataPath)
if err != nil {
return nil, errors.Wrap(err, "unable to build path to IMDS metadata for path"+imdsMetadataPath)
}
imdsRequestURL = imdsRequestURL + "?" + imdsAPIVersion + "&" + imdsFormatJSON

req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsRequestURL, http.NoBody)
if err != nil {
return nil, errors.Wrap(err, "error building IMDS http request")
}
Expand All @@ -133,3 +163,86 @@ func (c *Client) getInstanceComputeMetadata(ctx context.Context) (map[string]any

return m, nil
}

func (c *Client) GetIMDSVersions(ctx context.Context) (*APIVersionsResponse, error) {
var versionsResp APIVersionsResponse
err := retry.Do(func() error {
// Build the URL for the versions endpoint
imdsRequestURL, err := url.JoinPath(c.config.endpoint, imdsVersionsPath)
if err != nil {
return errors.Wrap(err, "unable to build path to IMDS versions endpoint")
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsRequestURL, http.NoBody)
if err != nil {
return errors.Wrap(err, "error building IMDS versions http request")
}

req.Header.Add(metadataHeaderKey, metadataHeaderValue)
resp, err := c.cli.Do(req)
if err != nil {
return errors.Wrap(err, "error querying IMDS versions API")
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return errors.Wrapf(ErrUnexpectedStatusCode, "unexpected status code %d", resp.StatusCode)
}

if err := json.NewDecoder(resp.Body).Decode(&versionsResp); err != nil {
return errors.Wrap(err, "error decoding IMDS versions response as json")
}

return nil
}, retry.Context(ctx), retry.Attempts(c.config.retryAttempts), retry.DelayType(retry.BackOffDelay))
if err != nil {
return nil, errors.Wrap(err, "exhausted retries querying IMDS versions")
}

return &versionsResp, nil
}

// Required for marshaling/unmarshaling of mac address
type HardwareAddr net.HardwareAddr

func (h *HardwareAddr) MarshalJSON() ([]byte, error) {
data, err := json.Marshal(net.HardwareAddr(*h).String())
if err != nil {
return nil, errors.Wrap(err, "failed to marshal hardware address")
}
return data, nil
}

func (h *HardwareAddr) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return errors.Wrap(err, "failed to unmarshal JSON data")
}
mac, err := net.ParseMAC(s)
if err != nil {
return errors.Wrap(err, "failed to parse MAC address")
}
*h = HardwareAddr(mac)
return nil
}

func (h *HardwareAddr) String() string {
return net.HardwareAddr(*h).String()
}

// NetworkInterface represents a network interface from IMDS
type NetworkInterface struct {
// IMDS returns compartment fields - these are mapped to NC ID and NC version
MacAddress HardwareAddr `json:"macAddress"`
InterfaceCompartmentID string `json:"interfaceCompartmentID,omitempty"`
}

// NetworkInterfaces represents the network interfaces from IMDS
type NetworkInterfaces struct {
Interface []NetworkInterface `json:"interface"`
}

// APIVersionsResponse represents versions form IMDS
type APIVersionsResponse struct {
APIVersions []string `json:"apiVersions"`
}
Loading
Loading