diff --git a/cns/service/main.go b/cns/service/main.go index 36f24dffa7..a3cd5f49ad 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -26,6 +26,7 @@ import ( "github.com/Azure/azure-container-networking/cns/cnireconciler" "github.com/Azure/azure-container-networking/cns/common" "github.com/Azure/azure-container-networking/cns/configuration" + "github.com/Azure/azure-container-networking/cns/deviceplugin" "github.com/Azure/azure-container-networking/cns/endpointmanager" "github.com/Azure/azure-container-networking/cns/fsnotify" "github.com/Azure/azure-container-networking/cns/grpc" @@ -65,6 +66,7 @@ import ( "github.com/Azure/azure-container-networking/store" "github.com/Azure/azure-container-networking/telemetry" "github.com/avast/retry-go/v4" + "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -105,9 +107,14 @@ const ( // envVarEnableCNIConflistGeneration enables cni conflist generation if set (value doesn't matter) envVarEnableCNIConflistGeneration = "CNS_ENABLE_CNI_CONFLIST_GENERATION" - cnsReqTimeout = 15 * time.Second - defaultLocalServerIP = "localhost" - defaultLocalServerPort = "10090" + cnsReqTimeout = 15 * time.Second + defaultLocalServerIP = "localhost" + defaultLocalServerPort = "10090" + defaultDevicePluginRetryInterval = 2 * time.Second + defaultNodeInfoCRDPollInterval = 5 * time.Second + defaultDevicePluginMaxRetryCount = 5 + initialVnetNICCount = 0 + initialIBNICCount = 0 ) type cniConflistScenario string @@ -910,6 +917,50 @@ func main() { } } + if cnsconfig.EnableSwiftV2 && cnsconfig.EnableK8sDevicePlugin { + // Create device plugin manager instance + pluginManager := deviceplugin.NewPluginManager(z) + pluginManager.AddPlugin(mtv1alpha1.DeviceTypeVnetNIC, initialVnetNICCount) + pluginManager.AddPlugin(mtv1alpha1.DeviceTypeInfiniBandNIC, initialIBNICCount) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start device plugin manager in a separate 
goroutine + go func() { + retryCount := 0 + ticker := time.NewTicker(defaultDevicePluginRetryInterval) + // Ensure the ticker is stopped on exit + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + z.Info("Context canceled, stopping plugin manager") + return + case <-ticker.C: + if pluginErr := pluginManager.Run(ctx); pluginErr != nil { + z.Error("plugin manager exited with error", zap.Error(pluginErr)) + retryCount++ + // Implementing a basic circuit breaker + if retryCount >= defaultDevicePluginMaxRetryCount { + z.Error("Max retries reached, stopping plugin manager") + return + } + } else { + return + } + } + } + }() + + // go routine to poll node info crd and update device counts + go func() { + if pollErr := pollNodeInfoCRDAndUpdatePlugin(ctx, z, pluginManager); pollErr != nil { + z.Error("Error in pollNodeInfoCRDAndUpdatePlugin", zap.Error(pollErr)) + } + }() + } + // Conditionally initialize and start the gRPC server if cnsconfig.GRPCSettings.Enable { // Define gRPC server settings @@ -1083,6 +1134,91 @@ func main() { logger.Close() } +// Poll CRD until it's set and update PluginManager +func pollNodeInfoCRDAndUpdatePlugin(ctx context.Context, zlog *zap.Logger, pluginManager *deviceplugin.PluginManager) error { + kubeConfig, err := ctrl.GetConfig() + if err != nil { + logger.Errorf("Failed to get kubeconfig for request controller: %v", err) + return errors.Wrap(err, "failed to get kubeconfig") + } + kubeConfig.UserAgent = "azure-cns-" + version + + clientset, err := kubernetes.NewForConfig(kubeConfig) + if err != nil { + return errors.Wrap(err, "failed to build clientset") + } + + nodeName, err := configuration.NodeName() + if err != nil { + return errors.Wrap(err, "failed to get NodeName") + } + + node, err := clientset.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{}) + if err != nil { + return errors.Wrapf(err, "failed to get node %s", nodeName) + } + + // check the Node labels for Swift V2 + if _, ok := 
node.Labels[configuration.LabelNodeSwiftV2]; !ok { + zlog.Info("Node is not labeled for Swift V2, skipping polling nodeinfo crd") + return nil + } + + directcli, err := client.New(kubeConfig, client.Options{Scheme: multitenancy.Scheme}) + if err != nil { + return errors.Wrap(err, "failed to create ctrl client") + } + + nodeInfoCli := multitenancy.NodeInfoClient{ + Cli: directcli, + } + + ticker := time.NewTicker(defaultNodeInfoCRDPollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + zlog.Info("Polling context canceled, exiting") + return nil + case <-ticker.C: + // Fetch the CRD status + nodeInfo, err := nodeInfoCli.Get(ctx, node.Name) + if err != nil { + zlog.Error("Error fetching nodeinfo CRD", zap.Error(err)) + return errors.Wrap(err, "failed to get nodeinfo crd") + } + + // Check if the status is set + if !cmp.Equal(nodeInfo.Status, mtv1alpha1.NodeInfoStatus{}) && len(nodeInfo.Status.DeviceInfos) > 0 { + // Create a map to count devices by type + deviceCounts := map[mtv1alpha1.DeviceType]int{ + mtv1alpha1.DeviceTypeVnetNIC: 0, + mtv1alpha1.DeviceTypeInfiniBandNIC: 0, + } + + // Aggregate device counts from the CRD + for _, deviceInfo := range nodeInfo.Status.DeviceInfos { + switch deviceInfo.DeviceType { + case mtv1alpha1.DeviceTypeVnetNIC, mtv1alpha1.DeviceTypeInfiniBandNIC: + deviceCounts[deviceInfo.DeviceType]++ + default: + zlog.Error("Unknown device type", zap.String("deviceType", string(deviceInfo.DeviceType))) + } + } + + // Update the plugin manager with device counts + for deviceType, count := range deviceCounts { + pluginManager.TrackDevices(deviceType, count) + } + + // Exit polling loop once the CRD status is successfully processed + return nil + } + } + } +} + func InitializeMultiTenantController(ctx context.Context, httpRestService cns.HTTPService, cnsconfig configuration.CNSConfig) error { var multiTenantController multitenantcontroller.RequestController kubeConfig, err := ctrl.GetConfig() diff --git 
a/crd/multitenancy/client.go b/crd/multitenancy/client.go
index bfd7a0061e..1c2065ad2d 100644
--- a/crd/multitenancy/client.go
+++ b/crd/multitenancy/client.go
@@ -216,3 +216,14 @@ func (n *NodeInfoClient) CreateOrUpdate(ctx context.Context, nodeInfo *v1alpha1.
 	}
 	return nil
 }
+
+// Get retrieves the NodeInfo CRD by name. The object key carries no
+// namespace (the lookup is by name only — NodeInfo appears cluster-scoped;
+// confirm). Any client error is returned wrapped with context.
+func (n *NodeInfoClient) Get(ctx context.Context, name string) (*v1alpha1.NodeInfo, error) {
+	var nodeInfo v1alpha1.NodeInfo
+	if err := n.Cli.Get(ctx, client.ObjectKey{Name: name}, &nodeInfo); err != nil {
+		return nil, errors.Wrap(err, "error getting nodeinfo crd")
+	}
+	return &nodeInfo, nil
+}