diff --git a/internal/testutil/fixture/kubernetes.go b/internal/testutil/fixture/kubernetes.go index 1a446c04b6..78b7a4f0df 100644 --- a/internal/testutil/fixture/kubernetes.go +++ b/internal/testutil/fixture/kubernetes.go @@ -159,6 +159,42 @@ func (f *KubernetesServiceFixture) WithIngressIPs(ips []string) *KubernetesServi return f } +func (f *KubernetesServiceFixture) WithOnlyTCPPorts() *KubernetesServiceFixture { + f.svc.Spec.Ports = []v1.ServicePort{ + { + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 80, + NodePort: 50080, + }, + { + Name: "dns-tcp", + Protocol: v1.ProtocolTCP, + Port: 53, + NodePort: 50053, + }, + { + Name: "https", + Protocol: v1.ProtocolTCP, + Port: 443, + NodePort: 50443, + }, + } + return f +} + +func (f *KubernetesServiceFixture) WithOnlyUDPPorts() *KubernetesServiceFixture { + f.svc.Spec.Ports = []v1.ServicePort{ + { + Name: "dns-udp", + Protocol: v1.ProtocolUDP, + Port: 53, + NodePort: 50053, + }, + } + return f +} + func (f *KubernetesServiceFixture) Build() v1.Service { return f.svc } diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index 658b7fd1fc..ee96c14cb8 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -236,6 +236,11 @@ const ( LoadBalancerSKUBasic = "basic" // LoadBalancerSKUStandard is the load balancer standard SKU LoadBalancerSKUStandard = "standard" + // LoadBalancerSKUService is the load balancer service SKU + LoadBalancerSKUService = "service" + + // PodLabelServiceEgressGateway is the label used on the pod + PodLabelServiceEgressGateway = "kubernetes.azure.com/service-egress-gateway" // ServiceAnnotationLoadBalancerInternal is the annotation used on the service ServiceAnnotationLoadBalancerInternal = "service.beta.kubernetes.io/azure-load-balancer-internal" @@ -364,6 +369,8 @@ const ( BackendPoolIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/backendAddressPools/%s" // LoadBalancerProbeIDTemplate is the template of the load balancer probe LoadBalancerProbeIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/probes/%s" + // NatGatewayIDTemplate is the template of the nat gateway + NatGatewayIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/natGateways/%s" // InternalLoadBalancerNameSuffix is load balancer suffix InternalLoadBalancerNameSuffix = "-internal" @@ -381,9 +388,8 @@ const ( LoadBalancerBackendPoolConfigurationTypeNodeIPConfiguration = "nodeIPConfiguration" // LoadBalancerBackendPoolConfigurationTypeNodeIP is the lb backend pool config type node ip LoadBalancerBackendPoolConfigurationTypeNodeIP = "nodeIP" - // LoadBalancerBackendPoolConfigurationTypePODIP is the lb backend pool config type pod ip - // TODO (nilo19): support pod IP in the future - LoadBalancerBackendPoolConfigurationTypePODIP = "podIP" + // LoadBalancerBackendPoolConfigurationTypePodIP is the lb backend pool config type pod ip + LoadBalancerBackendPoolConfigurationTypePodIP = "podIP" ) // error messages diff --git a/pkg/provider/azure.go b/pkg/provider/azure.go index 6dec5d5822..fde3607c37 100644 --- a/pkg/provider/azure.go +++ b/pkg/provider/azure.go @@ -260,6 +260,18 @@ func (az *Cloud) InitializeCloudFromConfig(ctx context.Context, config *config.C return fmt.Errorf("InitializeCloudFromConfig: cannot initialize from nil config") } + // Use a single flag to determine if the service gateway is enabled. + // All 3 conditions must be true: + // 1. ServiceGatewayEnabled is true + // 2. lb sku is service + // 3. backendPoolType is PodIP + if az.ServiceGatewayEnabled && az.IsLBBackendPoolTypePodIPAndUseServiceLoadBalancer() { + klog.V(2).Info("InitializeCloudFromConfig: Service Gateway is enabled, using PodIP backend pool type with Service Load Balancer") + az.ServiceGatewayEnabled = true + } else { + az.ServiceGatewayEnabled = false + } + if config.RouteTableResourceGroup == "" { config.RouteTableResourceGroup = config.ResourceGroup } @@ -298,15 +310,13 @@ func (az *Cloud) InitializeCloudFromConfig(ctx context.Context, config *config.C } } - if config.LoadBalancerBackendPoolConfigurationType == "" || - // TODO(nilo19): support pod IP mode in the future - strings.EqualFold(config.LoadBalancerBackendPoolConfigurationType, consts.LoadBalancerBackendPoolConfigurationTypePODIP) { + if config.LoadBalancerBackendPoolConfigurationType == "" { config.LoadBalancerBackendPoolConfigurationType = consts.LoadBalancerBackendPoolConfigurationTypeNodeIPConfiguration } else { supportedLoadBalancerBackendPoolConfigurationTypes := utilsets.NewString( strings.ToLower(consts.LoadBalancerBackendPoolConfigurationTypeNodeIPConfiguration), strings.ToLower(consts.LoadBalancerBackendPoolConfigurationTypeNodeIP), - strings.ToLower(consts.LoadBalancerBackendPoolConfigurationTypePODIP)) + strings.ToLower(consts.LoadBalancerBackendPoolConfigurationTypePodIP)) if !supportedLoadBalancerBackendPoolConfigurationTypes.Has(strings.ToLower(config.LoadBalancerBackendPoolConfigurationType)) { return fmt.Errorf("loadBalancerBackendPoolConfigurationType %s is not supported, supported values are %v", config.LoadBalancerBackendPoolConfigurationType, supportedLoadBalancerBackendPoolConfigurationTypes.UnsortedList()) } diff --git a/pkg/provider/azure_fakes.go b/pkg/provider/azure_fakes.go index 9fdcff1300..8e818added 100644 --- a/pkg/provider/azure_fakes.go +++ b/pkg/provider/azure_fakes.go @@ -17,6 +17,8 @@ limitations under the License. package provider import ( + "net/netip" + "go.uber.org/mock/gomock" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/informers" @@ -29,6 +31,7 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/azclient/diskclient/mock_diskclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/managedclusterclient/mock_managedclusterclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/privateendpointclient/mock_privateendpointclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/privatelinkserviceclient/mock_privatelinkserviceclient" @@ -138,6 +141,9 @@ func GetTestCloud(ctrl *gomock.Controller) (az *Cloud) { virtualMachinesClient := mock_virtualmachineclient.NewMockInterface(ctrl) clientFactory.EXPECT().GetVirtualMachineClient().Return(virtualMachinesClient).AnyTimes() + managedClusterClient := mock_managedclusterclient.NewMockInterface(ctrl) + clientFactory.EXPECT().GetManagedClusterClient().Return(managedClusterClient).AnyTimes() + securtyGrouptrack2Client := mock_securitygroupclient.NewMockInterface(ctrl) clientFactory.EXPECT().GetSecurityGroupClient().Return(securtyGrouptrack2Client).AnyTimes() mockPrivateDNSClient := mock_privatezoneclient.NewMockInterface(ctrl) @@ -186,3 +192,24 @@ func GetTestCloudWithExtendedLocation(ctrl *gomock.Controller) (az *Cloud) { az.Config.ExtendedLocationType = "EdgeZone" return az } + +// GetTestCloudWithContainerLoadBalancer returns a fake azure cloud for unit tests in Azure supporting container load balancer. +func GetTestCloudWithContainerLoadBalancer(ctrl *gomock.Controller) (az *Cloud) { + az = GetTestCloud(ctrl) + az.LoadBalancerBackendPoolConfigurationType = consts.LoadBalancerBackendPoolConfigurationTypePodIP + az.LoadBalancerSKU = consts.LoadBalancerSKUService + az.ServiceGatewayEnabled = true + return az +} + +func GetTestCloudWithContainerLoadBalancerAndPrefixCidr(ctrl *gomock.Controller, isIPv6 bool) (az *Cloud) { + az = GetTestCloudWithContainerLoadBalancer(ctrl) + if !isIPv6 { + prefix, _ := netip.ParsePrefix("10.0.0.1/32") + az.PodCidrsIPv4 = []netip.Prefix{prefix} + } else { + prefix, _ := netip.ParsePrefix("2001:db8::/64") + az.PodCidrsIPv6 = []netip.Prefix{prefix} + } + return az +} diff --git a/pkg/provider/azure_natgateway_repo.go b/pkg/provider/azure_natgateway_repo.go new file mode 100644 index 0000000000..902c005837 --- /dev/null +++ b/pkg/provider/azure_natgateway_repo.go @@ -0,0 +1,101 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package provider + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" +) + +func (az *Cloud) getNatGateway(ctx context.Context, natGatewayResourceGroup string, natGatewayName string) (*armnetwork.NatGateway, error) { + klog.Infof("NatGatewayClient.Get(%s) in resource group %s: start", natGatewayName, natGatewayResourceGroup) + result, err := az.NetworkClientFactory.GetNatGatewayClient().Get(ctx, natGatewayResourceGroup, natGatewayName, nil) + if err != nil { + klog.Errorf("NatGatewayClient.Get(%s) in resource group %s failed: %v", natGatewayName, natGatewayResourceGroup, err) + return nil, err + } + klog.V(10).Infof("NatGatewayClient.Get(%s) in resource group %s: success", natGatewayName, natGatewayResourceGroup) + klog.Infof("NatGatewayClient.Get(%s) in resource group %s: end, error: nil", natGatewayName, natGatewayResourceGroup) + return result, nil +} + +// CreateOrUpdateLB invokes az.NetworkClientFactory.GetLoadBalancerClient().CreateOrUpdate with exponential backoff retry +func (az *Cloud) createOrUpdateNatGateway(ctx context.Context, natGatewayResourceGroup string, natGateway armnetwork.NatGateway) error { + natGatewayName := ptr.Deref(natGateway.Name, "") + klog.Infof("NatGatewayClient.CreateOrUpdate(%s): start", natGatewayName) + + // Endless retry loop with 5-second intervals + for { + _, err := az.NetworkClientFactory.GetNatGatewayClient().CreateOrUpdate(ctx, natGatewayResourceGroup, natGatewayName, natGateway) + if err == nil { + klog.V(10).Infof("NatGatewayClient.CreateOrUpdate(%s): success", natGatewayName) + klog.Infof("NatGatewayClient.CreateOrUpdate(%s): end, error: nil", natGatewayName) + return nil + } + + natGatewayJSON, _ := json.Marshal(natGateway) + klog.Warningf("NatGatewayClient.CreateOrUpdate(%s) failed: %v, NatGateway request: %s", natGatewayName, err, string(natGatewayJSON)) + + // Check if context is canceled + select { + case <-ctx.Done(): + klog.V(3).Infof("createOrUpdateNatGateway: context canceled, stopping retry") + return fmt.Errorf("context canceled: %w", ctx.Err()) + default: + // Continue with retry + } + + // Wait 5 seconds before retrying + klog.V(3).Infof("createOrUpdateNatGateway: retrying in 5 seconds for NAT Gateway %s", natGatewayName) + time.Sleep(5 * time.Second) + } +} + +func (az *Cloud) deleteNatGateway(ctx context.Context, natGatewayResourceGroup string, natGatewayName string) error { + klog.Infof("NatGatewayClient.Delete(%s) in resource group %s: start", natGatewayName, natGatewayResourceGroup) + + // Endless retry loop with 5-second intervals + for { + err := az.NetworkClientFactory.GetNatGatewayClient().Delete(ctx, natGatewayResourceGroup, natGatewayName) + if err == nil { + klog.V(10).Infof("NatGatewayClient.Delete(%s) in resource group %s: success", natGatewayName, natGatewayResourceGroup) + klog.Infof("NatGatewayClient.Delete(%s) in resource group %s: end, error: nil", natGatewayName, natGatewayResourceGroup) + return nil + } + + klog.Errorf("NatGatewayClient.Delete(%s) in resource group %s failed: %v", natGatewayName, natGatewayResourceGroup, err) + + // Check if context is canceled + select { + case <-ctx.Done(): + klog.V(3).Infof("deleteNatGateway: context canceled, stopping retry") + return fmt.Errorf("context canceled: %w", ctx.Err()) + default: + // Continue with retry + } + + // Wait 5 seconds before retrying + klog.V(3).Infof("deleteNatGateway: retrying in 5 seconds for NAT Gateway %s", natGatewayName) + time.Sleep(5 * time.Second) + } +} diff --git a/pkg/provider/azure_publicip_repo.go b/pkg/provider/azure_publicip_repo.go index c5500005cf..a14f744149 100644 --- a/pkg/provider/azure_publicip_repo.go +++ b/pkg/provider/azure_publicip_repo.go @@ -38,6 +38,32 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/util/deepcopy" ) +func (az *Cloud) CreateOrUpdatePIPOutbound(ctx context.Context, pipResourceGroup string, pip *armnetwork.PublicIPAddress) error { + klog.Infof("CreateOrUpdatePIPOutbound(%s): start", ptr.Deref(pip.Name, "")) + + // Endless retry loop with 5-second intervals + for { + // Call the existing CreateOrUpdatePIP function + err := az.CreateOrUpdatePIP(nil, pipResourceGroup, pip) + if err == nil { + return nil + } + + // Check if context is canceled + select { + case <-ctx.Done(): + klog.V(3).Infof("CreateOrUpdatePIPOutbound: context canceled, stopping retry") + return fmt.Errorf("context canceled: %w", ctx.Err()) + default: + // Continue with retry + } + + // Wait 5 seconds before retrying + klog.V(3).Infof("CreateOrUpdatePIPOutbound: retrying in 5 seconds for PIP %s", ptr.Deref(pip.Name, "")) + time.Sleep(5 * time.Second) + } +} + // CreateOrUpdatePIP invokes az.NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate with exponential backoff retry func (az *Cloud) CreateOrUpdatePIP(service *v1.Service, pipResourceGroup string, pip *armnetwork.PublicIPAddress) error { ctx, cancel := getContextWithCancel() @@ -74,6 +100,32 @@ func (az *Cloud) CreateOrUpdatePIP(service *v1.Service, pipResourceGroup string, return rerr } +func (az *Cloud) DeletePublicIPOutbound(ctx context.Context, pipResourceGroup string, pipName string) error { + klog.Infof("DeletePublicIPOutbound(%s): start", pipName) + + // Endless retry loop with 5-second intervals + for { + // Call the existing DeletePublicIP function + err := az.DeletePublicIP(nil, pipResourceGroup, pipName) + if err == nil { + return nil + } + + // Check if context is canceled + select { + case <-ctx.Done(): + klog.V(3).Infof("DeletePublicIPOutbound: context canceled, stopping retry") + return fmt.Errorf("context canceled: %w", ctx.Err()) + default: + // Continue with retry + } + + // Wait 5 seconds before retrying + klog.V(3).Infof("DeletePublicIPOutbound: retrying in 5 seconds for PIP %s", pipName) + time.Sleep(5 * time.Second) + } +} + // DeletePublicIP invokes az.NetworkClientFactory.GetPublicIPAddressClient().Delete with exponential backoff retry func (az *Cloud) DeletePublicIP(service *v1.Service, pipResourceGroup string, pipName string) error { ctx, cancel := getContextWithCancel() diff --git a/pkg/provider/azure_standard.go b/pkg/provider/azure_standard.go index 40f32222db..b2bc6f32f5 100644 --- a/pkg/provider/azure_standard.go +++ b/pkg/provider/azure_standard.go @@ -325,6 +325,11 @@ func getServiceName(service *v1.Service) string { return fmt.Sprintf("%s/%s", service.Namespace, service.Name) } +// This returns a unique identifier for the Service used to tag some resources. +func getServiceUID(service *v1.Service) string { + return strings.ToLower(string(service.UID)) +} + // This returns a prefix for loadbalancer/security rules. func (az *Cloud) getRulePrefix(service *v1.Service) string { return az.GetLoadBalancerName(context.TODO(), "", service) @@ -332,7 +337,16 @@ func (az *Cloud) getRulePrefix(service *v1.Service) string { func (az *Cloud) getPublicIPName(clusterName string, service *v1.Service, isIPv6 bool) (string, error) { isDualStack := isServiceDualStack(service) - pipName := fmt.Sprintf("%s-%s", clusterName, az.GetLoadBalancerName(context.TODO(), clusterName, service)) + + var pipName string + if az.ServiceGatewayEnabled { + // Base name: -pip + pipName = fmt.Sprintf("%s-pip", getServiceUID(service)) + } else { + // Legacy scheme: tied to clusterName — per-cluster naming. + pipName = fmt.Sprintf("%s-%s", clusterName, az.GetLoadBalancerName(context.TODO(), clusterName, service)) + } + if id := getServicePIPPrefixID(service, isIPv6); id != "" { id, err := getLastSegment(id, "/") if err == nil { diff --git a/pkg/provider/azure_test.go b/pkg/provider/azure_test.go index d321dc4e8a..cf23ed9053 100644 --- a/pkg/provider/azure_test.go +++ b/pkg/provider/azure_test.go @@ -38,6 +38,7 @@ import ( v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes/fake" cloudprovider "k8s.io/cloud-provider" @@ -1252,6 +1253,67 @@ func getTestServiceCommon(identifier string, proto v1.Protocol, annotations map[ return svc } +func getTestServiceWithNamedTargetPorts(identifier string, proto v1.Protocol, annotations map[string]string, isIPv6 bool, servicePort int32, namedTargetPorts ...string) v1.Service { + targetPorts := []intstr.IntOrString{} + for _, port := range namedTargetPorts { + targetPorts = append(targetPorts, intstr.FromString(port)) + } + svc := getTestServiceWithTargetPortsCommon(identifier, proto, annotations, servicePort, targetPorts...) + svc.Spec.ClusterIP = "10.0.0.2" + svc.Spec.IPFamilies = []v1.IPFamily{v1.IPv4Protocol} + if isIPv6 { + svc.Spec.ClusterIP = "fd00::1907" + svc.Spec.IPFamilies = []v1.IPFamily{v1.IPv6Protocol} + } + + return svc +} + +func getTestServiceWithIntTargetPorts(identifier string, proto v1.Protocol, annotations map[string]string, isIPv6 bool, servicePort int32, intTargetPorts ...int32) v1.Service { + targetPorts := []intstr.IntOrString{} + for _, port := range intTargetPorts { + targetPorts = append(targetPorts, intstr.FromInt(int(port))) + } + svc := getTestServiceWithTargetPortsCommon(identifier, proto, annotations, servicePort, targetPorts...) + svc.Spec.ClusterIP = "10.0.0.2" + svc.Spec.IPFamilies = []v1.IPFamily{v1.IPv4Protocol} + if isIPv6 { + svc.Spec.ClusterIP = "fd00::1907" + svc.Spec.IPFamilies = []v1.IPFamily{v1.IPv6Protocol} + } + + return svc +} + +func getTestServiceWithTargetPortsCommon(identifier string, proto v1.Protocol, annotations map[string]string, servicePort int32, targetPorts ...intstr.IntOrString) v1.Service { + ports := []v1.ServicePort{} + for _, port := range targetPorts { + ports = append(ports, v1.ServicePort{ + Name: "target-port", + Protocol: proto, + TargetPort: port, + Port: servicePort, + }) + } + + svc := v1.Service{ + Spec: v1.ServiceSpec{ + Type: v1.ServiceTypeLoadBalancer, + Ports: ports, + }, + } + svc.Name = identifier + svc.Namespace = "default" + svc.UID = types.UID(identifier) + if annotations == nil { + svc.Annotations = make(map[string]string) + } else { + svc.Annotations = annotations + } + + return svc +} + func getInternalTestService(identifier string, requestedPorts ...int32) v1.Service { return getTestServiceWithAnnotation(identifier, map[string]string{consts.ServiceAnnotationLoadBalancerInternal: consts.TrueAnnotationValue}, false, requestedPorts...) } diff --git a/pkg/provider/config/azure.go b/pkg/provider/config/azure.go index 9d8801b9c3..0f1da6edf6 100644 --- a/pkg/provider/config/azure.go +++ b/pkg/provider/config/azure.go @@ -17,6 +17,7 @@ limitations under the License. package config import ( + "net/netip" "strings" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/configloader" @@ -165,6 +166,16 @@ type Config struct { ClusterServiceSharedLoadBalancerHealthProbePort int32 `json:"clusterServiceSharedLoadBalancerHealthProbePort,omitempty" yaml:"clusterServiceSharedLoadBalancerHealthProbePort,omitempty"` // ClusterServiceSharedLoadBalancerHealthProbePath defines the target path of the shared health probe. Default to `/healthz`. ClusterServiceSharedLoadBalancerHealthProbePath string `json:"clusterServiceSharedLoadBalancerHealthProbePath,omitempty" yaml:"clusterServiceSharedLoadBalancerHealthProbePath,omitempty"` + + // PodCidrsIPv4 is a slice of IPv4 pod subnet prefixes for the cluster. + // PodCidrsIPv6 is a slice of IPv6 pod subnet prefixes for the cluster. + // The pod subnet prefix is used to configure the NSG for the pod subnet. + // Pod CIDR would be opened to internet by default + PodCidrsIPv4 []netip.Prefix `json:"podCidrIPv4" yaml:"podCidrIPv4"` + PodCidrsIPv6 []netip.Prefix `json:"podCidrIPv6" yaml:"podCidrIPv6"` + + // ServiceGatewayEnabled indicates whether the service gateway is enabled for the cluster. + ServiceGatewayEnabled bool `json:"serviceGatewayEnabled,omitempty" yaml:"serviceGatewayEnabled,omitempty"` } // HasExtendedLocation returns true if extendedlocation prop are specified. @@ -180,6 +191,18 @@ func (az *Config) IsLBBackendPoolTypeNodeIP() bool { return strings.EqualFold(az.LoadBalancerBackendPoolConfigurationType, consts.LoadBalancerBackendPoolConfigurationTypeNodeIP) } +func (az *Config) UseServiceLoadBalancer() bool { + return strings.EqualFold(az.LoadBalancerSKU, consts.LoadBalancerSKUService) +} + +func (az *Config) IsLBBackendPoolTypePodIP() bool { + return strings.EqualFold(az.LoadBalancerBackendPoolConfigurationType, consts.LoadBalancerBackendPoolConfigurationTypePodIP) +} + +func (az *Config) IsLBBackendPoolTypePodIPAndUseServiceLoadBalancer() bool { + return az.IsLBBackendPoolTypePodIP() && az.UseServiceLoadBalancer() +} + func (az *Config) GetPutVMSSVMBatchSize() int { return az.PutVMSSVMBatchSize } diff --git a/pkg/util/sets/string.go b/pkg/util/sets/string.go index 2562fcf6ea..87275bcd69 100644 --- a/pkg/util/sets/string.go +++ b/pkg/util/sets/string.go @@ -17,6 +17,7 @@ limitations under the License. package sets import ( + "encoding/json" "strings" "k8s.io/apimachinery/pkg/util/sets" @@ -27,6 +28,37 @@ type IgnoreCaseSet struct { set sets.Set[string] } +func (s *IgnoreCaseSet) MarshalJSON() ([]byte, error) { + if s == nil { + return []byte("null"), nil + } + if s.Len() == 0 { + return []byte("[]"), nil + } + return json.Marshal(s.UnsortedList()) +} + +// Equals returns true if the two sets are equal. +func (s1 *IgnoreCaseSet) Equals(s2 *IgnoreCaseSet) bool { + // Early exit if sizes are different + if len(s1.UnsortedList()) != len(s2.UnsortedList()) { + return false + } + // Check if all items in s1 are in s2 + for _, item := range s1.UnsortedList() { + if !s2.Has(item) { + return false + } + } + // Check if all items in s2 are in s1 + for _, item := range s2.UnsortedList() { + if !s1.Has(item) { + return false + } + } + return true +} + // NewString creates a new IgnoreCaseSet with the given items. func NewString(items ...string) *IgnoreCaseSet { var lowerItems []string diff --git a/pkg/util/sets/string_test.go b/pkg/util/sets/string_test.go index b0286b7d3d..8db95dd2d7 100644 --- a/pkg/util/sets/string_test.go +++ b/pkg/util/sets/string_test.go @@ -367,3 +367,82 @@ func TestLen(t *testing.T) { }) } } + +func TestEquals(t *testing.T) { + tests := []struct { + name string + s1 *IgnoreCaseSet + s2 *IgnoreCaseSet + want bool + }{ + { + name: "both nil", + s1: nil, + s2: nil, + want: true, + }, + { + name: "first nil", + s1: nil, + s2: NewString("foo"), + want: false, + }, + { + name: "second nil", + s1: NewString("foo"), + s2: nil, + want: false, + }, + { + name: "empty sets", + s1: NewString(), + s2: NewString(), + want: true, + }, + { + name: "same elements", + s1: NewString("foo", "bar"), + s2: NewString("foo", "bar"), + want: true, + }, + { + name: "same elements with different case", + s1: NewString("foo", "bar"), + s2: NewString("FOO", "BAR"), + want: true, + }, + { + name: "different sizes", + s1: NewString("foo", "bar"), + s2: NewString("foo"), + want: false, + }, + { + name: "same size but different elements", + s1: NewString("foo", "bar"), + s2: NewString("foo", "baz"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.s1 == nil && tt.s2 == nil { + // Special case for nil sets + if !tt.want { + t.Errorf("Equals() = true, want %v", tt.want) + } + return + } + if tt.s1 == nil || tt.s2 == nil { + // One set is nil, they can't be equal + if tt.want { + t.Errorf("Equals() = false, want %v", tt.want) + } + return + } + if got := tt.s1.Equals(tt.s2); got != tt.want { + t.Errorf("Equals() = %v, want %v", got, tt.want) + } + }) + } +}