diff --git a/apis/v1alpha1/upstreamsettingspolicy_types.go b/apis/v1alpha1/upstreamsettingspolicy_types.go index a8f6c09a29..4b8131fad1 100644 --- a/apis/v1alpha1/upstreamsettingspolicy_types.go +++ b/apis/v1alpha1/upstreamsettingspolicy_types.go @@ -36,6 +36,9 @@ type UpstreamSettingsPolicyList struct { } // UpstreamSettingsPolicySpec defines the desired state of the UpstreamSettingsPolicy. +// +kubebuilder:validation:XValidation:rule="!(has(self.loadBalancingMethod) && (self.loadBalancingMethod == 'hash' || self.loadBalancingMethod == 'hash consistent')) || has(self.hashMethodKey)",message="hashMethodKey is required when loadBalancingMethod is 'hash' or 'hash consistent'" +// +//nolint:lll type UpstreamSettingsPolicySpec struct { // ZoneSize is the size of the shared memory zone used by the upstream. This memory zone is used to share // the upstream configuration between nginx worker processes. The more servers that an upstream has, @@ -58,6 +61,12 @@ type UpstreamSettingsPolicySpec struct { // +optional LoadBalancingMethod *LoadBalancingType `json:"loadBalancingMethod,omitempty"` + // HashMethodKey defines the key used for hash-based load balancing methods. + // This field is required when `LoadBalancingMethod` is set to `hash` or `hash consistent`. + // + // +optional + HashMethodKey *HashMethodKey `json:"hashMethodKey,omitempty"` + // TargetRefs identifies API object(s) to apply the policy to. // Objects must be in the same namespace as the policy. // Support: Service @@ -108,19 +117,96 @@ type UpstreamKeepAlive struct { // LoadBalancingType defines the supported load balancing methods. // -// +kubebuilder:validation:Enum=ip_hash;random two least_conn +// +kubebuilder:validation:Enum=round_robin;least_conn;ip_hash;hash;hash consistent;random;random two;random two least_conn;random two least_time=header;random two least_time=last_byte;least_time header;least_time last_byte;least_time header inflight;least_time last_byte inflight +// +//nolint:lll type LoadBalancingType string const ( + // Combination of NGINX directive + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#random + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#least_conn + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#least_time + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#upstream + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#ip_hash + // - https://nginx.org/en/docs/http/ngx_http_upstream_module.html#hash + + // LoadBalancingMethods supported by NGINX OSS and NGINX Plus. + + // LoadBalancingTypeRoundRobin enables round-robin load balancing, + // distributing requests evenly across all upstream servers. + LoadBalancingTypeRoundRobin LoadBalancingType = "round_robin" + + // LoadBalancingTypeLeastConnection enables least-connections load balancing, + // routing requests to the upstream server with the fewest active connections. + LoadBalancingTypeLeastConnection LoadBalancingType = "least_conn" + // LoadBalancingTypeIPHash enables IP hash-based load balancing, // ensuring requests from the same client IP are routed to the same upstream server. - // NGINX directive: https://nginx.org/en/docs/http/ngx_http_upstream_module.html#ip_hash LoadBalancingTypeIPHash LoadBalancingType = "ip_hash" + // LoadBalancingTypeHash enables generic hash-based load balancing, + // routing requests to upstream servers based on a hash of a specified key + // HashMethodKey field must be set when this method is selected. + // Example configuration: hash $binary_remote_addr;. + LoadBalancingTypeHash LoadBalancingType = "hash" + + // LoadBalancingTypeHashConsistent enables consistent hash-based load balancing, + // which minimizes the number of keys remapped when a server is added or removed. + // HashMethodKey field must be set when this method is selected. + // Example configuration: hash $binary_remote_addr consistent;. + LoadBalancingTypeHashConsistent LoadBalancingType = "hash consistent" + + // LoadBalancingTypeRandom enables random load balancing, + // routing requests to upstream servers in a random manner. + LoadBalancingTypeRandom LoadBalancingType = "random" + + // LoadBalancingTypeRandomTwo enables a variation of random load balancing + // that randomly selects two servers and forwards traffic to one of them. + // The default method is least_conn which passes a request to a server with the least number of active connections. + LoadBalancingTypeRandomTwo LoadBalancingType = "random two" + // LoadBalancingTypeRandomTwoLeastConnection enables a variation of least-connections // balancing that randomly selects two servers and forwards traffic to the one with // fewer active connections. - // NGINX directive least_conn: https://nginx.org/en/docs/http/ngx_http_upstream_module.html#least_conn - // NGINX directive random: https://nginx.org/en/docs/http/ngx_http_upstream_module.html#random LoadBalancingTypeRandomTwoLeastConnection LoadBalancingType = "random two least_conn" + + // LoadBalancingMethods supported by NGINX Plus. + + // LoadBalancingTypeRandomTwoLeastTimeHeader enables a variation of least-time load balancing + // that randomly selects two servers and forwards traffic to the one with the least + // time to receive the response header. + LoadBalancingTypeRandomTwoLeastTimeHeader LoadBalancingType = "random two least_time=header" + + // LoadBalancingTypeRandomTwoLeastTimeLastByte enables a variation of least-time load balancing + // that randomly selects two servers and forwards traffic to the one with the least time + // to receive the full response. + LoadBalancingTypeRandomTwoLeastTimeLastByte LoadBalancingType = "random two least_time=last_byte" + + // LoadBalancingTypeLeastTimeHeader enables least-time load balancing, + // routing requests to the upstream server with the least time to receive the response header. + LoadBalancingTypeLeastTimeHeader LoadBalancingType = "least_time header" + + // LoadBalancingTypeLeastTimeLastByte enables least-time load balancing, + // routing requests to the upstream server with the least time to receive the full response. + LoadBalancingTypeLeastTimeLastByte LoadBalancingType = "least_time last_byte" + + // LoadBalancingTypeLeastTimeHeaderInflight enables least-time load balancing, + // routing requests to the upstream server with the least time to receive the response header, + // considering the incomplete requests. + LoadBalancingTypeLeastTimeHeaderInflight LoadBalancingType = "least_time header inflight" + + // LoadBalancingTypeLeastTimeLastByteInflight enables least-time load balancing, + // routing requests to the upstream server with the least time to receive the full response, + // considering the incomplete requests. + LoadBalancingTypeLeastTimeLastByteInflight LoadBalancingType = "least_time last_byte inflight" ) + +// HashMethodKey defines the key used for hash-based load balancing methods. +// The key must be a valid NGINX variable name starting with '$' followed by lowercase +// letters and underscores only. +// For a full list of NGINX variables, +// refer to: https://nginx.org/en/docs/http/ngx_http_upstream_module.html#variables +// +// +kubebuilder:validation:Pattern=`^\$[a-z_]+$` +type HashMethodKey string diff --git a/apis/v1alpha1/zz_generated.deepcopy.go b/apis/v1alpha1/zz_generated.deepcopy.go index 4f0c2bebc4..164bef0cba 100644 --- a/apis/v1alpha1/zz_generated.deepcopy.go +++ b/apis/v1alpha1/zz_generated.deepcopy.go @@ -561,6 +561,11 @@ func (in *UpstreamSettingsPolicySpec) DeepCopyInto(out *UpstreamSettingsPolicySp *out = new(LoadBalancingType) **out = **in } + if in.HashMethodKey != nil { + in, out := &in.HashMethodKey, &out.HashMethodKey + *out = new(HashMethodKey) + **out = **in + } if in.TargetRefs != nil { in, out := &in.TargetRefs, &out.TargetRefs *out = make([]apisv1.LocalPolicyTargetReference, len(*in)) diff --git a/config/crd/bases/gateway.nginx.org_upstreamsettingspolicies.yaml b/config/crd/bases/gateway.nginx.org_upstreamsettingspolicies.yaml index 6fa9de4104..c8cda0c218 100644 --- a/config/crd/bases/gateway.nginx.org_upstreamsettingspolicies.yaml +++ b/config/crd/bases/gateway.nginx.org_upstreamsettingspolicies.yaml @@ -51,6 +51,12 @@ spec: spec: description: Spec defines the desired state of the UpstreamSettingsPolicy. properties: + hashMethodKey: + description: |- + HashMethodKey defines the key used for hash-based load balancing methods. + This field is required when `LoadBalancingMethod` is set to `hash` or `hash consistent`. + pattern: ^\$[a-z_]+$ + type: string keepAlive: description: KeepAlive defines the keep-alive settings. properties: @@ -91,8 +97,20 @@ spec: If not specified, NGINX Gateway Fabric defaults to `random two least_conn`, which differs from the standard NGINX default `round-robin`. enum: + - round_robin + - least_conn - ip_hash + - hash + - hash consistent + - random + - random two - random two least_conn + - random two least_time=header + - random two least_time=last_byte + - least_time header + - least_time last_byte + - least_time header inflight + - least_time last_byte inflight type: string targetRefs: description: |- @@ -152,6 +170,12 @@ spec: required: - targetRefs type: object + x-kubernetes-validations: + - message: hashMethodKey is required when loadBalancingMethod is 'hash' + or 'hash consistent' + rule: '!(has(self.loadBalancingMethod) && (self.loadBalancingMethod + == ''hash'' || self.loadBalancingMethod == ''hash consistent'')) || + has(self.hashMethodKey)' status: description: Status defines the state of the UpstreamSettingsPolicy. properties: diff --git a/deploy/crds.yaml b/deploy/crds.yaml index ce16d35823..6cf54dd1f5 100644 --- a/deploy/crds.yaml +++ b/deploy/crds.yaml @@ -9578,6 +9578,12 @@ spec: spec: description: Spec defines the desired state of the UpstreamSettingsPolicy. properties: + hashMethodKey: + description: |- + HashMethodKey defines the key used for hash-based load balancing methods. + This field is required when `LoadBalancingMethod` is set to `hash` or `hash consistent`. + pattern: ^\$[a-z_]+$ + type: string keepAlive: description: KeepAlive defines the keep-alive settings. properties: @@ -9618,8 +9624,20 @@ spec: If not specified, NGINX Gateway Fabric defaults to `random two least_conn`, which differs from the standard NGINX default `round-robin`. enum: + - round_robin + - least_conn - ip_hash + - hash + - hash consistent + - random + - random two - random two least_conn + - random two least_time=header + - random two least_time=last_byte + - least_time header + - least_time last_byte + - least_time header inflight + - least_time last_byte inflight type: string targetRefs: description: |- @@ -9679,6 +9697,12 @@ spec: required: - targetRefs type: object + x-kubernetes-validations: + - message: hashMethodKey is required when loadBalancingMethod is 'hash' + or 'hash consistent' + rule: '!(has(self.loadBalancingMethod) && (self.loadBalancingMethod + == ''hash'' || self.loadBalancingMethod == ''hash consistent'')) || + has(self.hashMethodKey)' status: description: Status defines the state of the UpstreamSettingsPolicy. properties: diff --git a/internal/controller/manager.go b/internal/controller/manager.go index d4e2114e8a..fe3c387d2e 100644 --- a/internal/controller/manager.go +++ b/internal/controller/manager.go @@ -124,7 +124,7 @@ func StartManager(cfg config.Config) error { mustExtractGVK := kinds.NewMustExtractGKV(scheme) genericValidator := ngxvalidation.GenericValidator{} - policyManager := createPolicyManager(mustExtractGVK, genericValidator) + policyManager := createPolicyManager(mustExtractGVK, genericValidator, cfg.Plus) plusSecrets, err := createPlusSecretMetadata(cfg, mgr.GetAPIReader()) if err != nil { @@ -323,6 +323,7 @@ func StartManager(cfg config.Config) error { func createPolicyManager( mustExtractGVK kinds.MustExtractGVK, validator validation.GenericValidator, + plusEnabled bool, ) *policies.CompositeValidator { cfgs := []policies.ManagerConfig{ { @@ -335,7 +336,7 @@ func createPolicyManager( }, { GVK: mustExtractGVK(&ngfAPIv1alpha1.UpstreamSettingsPolicy{}), - Validator: upstreamsettings.NewValidator(validator), + Validator: upstreamsettings.NewValidator(validator, plusEnabled), }, } diff --git a/internal/controller/nginx/config/http/config.go b/internal/controller/nginx/config/http/config.go index 14af2c8ca7..cd4d565965 100644 --- a/internal/controller/nginx/config/http/config.go +++ b/internal/controller/nginx/config/http/config.go @@ -1,6 +1,7 @@ package http import ( + ngfAPI "github.com/nginx/nginx-gateway-fabric/v2/apis/v1alpha1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared" ) @@ -123,6 +124,7 @@ type Upstream struct { ZoneSize string // format: 512k, 1m StateFile string LoadBalancingMethod string + HashMethodKey string KeepAlive UpstreamKeepAlive Servers []UpstreamServer } @@ -167,3 +169,33 @@ type ServerConfig struct { Plus bool DisableSNIHostValidation bool } + +var ( + OSSAllowedLBMethods = map[ngfAPI.LoadBalancingType]struct{}{ + ngfAPI.LoadBalancingTypeRoundRobin: {}, + ngfAPI.LoadBalancingTypeLeastConnection: {}, + ngfAPI.LoadBalancingTypeIPHash: {}, + ngfAPI.LoadBalancingTypeRandom: {}, + ngfAPI.LoadBalancingTypeHash: {}, + ngfAPI.LoadBalancingTypeHashConsistent: {}, + ngfAPI.LoadBalancingTypeRandomTwo: {}, + ngfAPI.LoadBalancingTypeRandomTwoLeastConnection: {}, + } + + PlusAllowedLBMethods = map[ngfAPI.LoadBalancingType]struct{}{ + ngfAPI.LoadBalancingTypeRoundRobin: {}, + ngfAPI.LoadBalancingTypeLeastConnection: {}, + ngfAPI.LoadBalancingTypeIPHash: {}, + ngfAPI.LoadBalancingTypeRandom: {}, + ngfAPI.LoadBalancingTypeHash: {}, + ngfAPI.LoadBalancingTypeHashConsistent: {}, + ngfAPI.LoadBalancingTypeRandomTwo: {}, + ngfAPI.LoadBalancingTypeRandomTwoLeastConnection: {}, + ngfAPI.LoadBalancingTypeLeastTimeHeader: {}, + ngfAPI.LoadBalancingTypeLeastTimeLastByte: {}, + ngfAPI.LoadBalancingTypeLeastTimeHeaderInflight: {}, + ngfAPI.LoadBalancingTypeLeastTimeLastByteInflight: {}, + ngfAPI.LoadBalancingTypeRandomTwoLeastTimeHeader: {}, + ngfAPI.LoadBalancingTypeRandomTwoLeastTimeLastByte: {}, + } +) diff --git a/internal/controller/nginx/config/policies/upstreamsettings/processor.go b/internal/controller/nginx/config/policies/upstreamsettings/processor.go index 7c29f807c4..9b4f23c7a7 100644 --- a/internal/controller/nginx/config/policies/upstreamsettings/processor.go +++ b/internal/controller/nginx/config/policies/upstreamsettings/processor.go @@ -15,6 +15,8 @@ type UpstreamSettings struct { ZoneSize string // LoadBalancingMethod is the load balancing method setting. LoadBalancingMethod string + // HashMethodKey is the key to be used for hash-based load balancing methods. + HashMethodKey string // KeepAlive contains the keepalive settings. KeepAlive http.UpstreamKeepAlive } @@ -67,6 +69,10 @@ func processPolicies(pols []policies.Policy) UpstreamSettings { if usp.Spec.LoadBalancingMethod != nil { upstreamSettings.LoadBalancingMethod = string(*usp.Spec.LoadBalancingMethod) } + + if usp.Spec.HashMethodKey != nil { + upstreamSettings.HashMethodKey = string(*usp.Spec.HashMethodKey) + } } return upstreamSettings diff --git a/internal/controller/nginx/config/policies/upstreamsettings/processor_test.go b/internal/controller/nginx/config/policies/upstreamsettings/processor_test.go index 4156781663..8473e59f40 100644 --- a/internal/controller/nginx/config/policies/upstreamsettings/processor_test.go +++ b/internal/controller/nginx/config/policies/upstreamsettings/processor_test.go @@ -38,6 +38,7 @@ func TestProcess(t *testing.T) { Timeout: helpers.GetPointer[ngfAPIv1alpha1.Duration]("10s"), }), LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeIPHash), + HashMethodKey: helpers.GetPointer[ngfAPIv1alpha1.HashMethodKey]("$upstream_addr"), }, }, }, @@ -50,6 +51,7 @@ func TestProcess(t *testing.T) { Timeout: "10s", }, LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeIPHash), + HashMethodKey: "$upstream_addr", }, }, { @@ -69,6 +71,25 @@ func TestProcess(t *testing.T) { LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeRandomTwoLeastConnection), }, }, + { + name: "load balancing method set with hash key", + policies: []policies.Policy{ + &ngfAPIv1alpha1.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp", + Namespace: "test", + }, + Spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: helpers.GetPointer[ngfAPIv1alpha1.HashMethodKey]("$request_time"), + }, + }, + }, + expUpstreamSettings: UpstreamSettings{ + LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: "$request_time", + }, + }, { name: "zone size set", policies: []policies.Policy{ @@ -245,7 +266,8 @@ func TestProcess(t *testing.T) { Namespace: "test", }, Spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ - LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeIPHash), + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: helpers.GetPointer[ngfAPIv1alpha1.HashMethodKey]("$upstream_addr"), }, }, }, @@ -257,7 +279,8 @@ func TestProcess(t *testing.T) { Time: "5s", Timeout: "10s", }, - LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeIPHash), + LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: "$upstream_addr", }, }, { @@ -345,7 +368,8 @@ func TestProcess(t *testing.T) { Namespace: "test", }, Spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ - LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeIPHash), + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHash), + HashMethodKey: helpers.GetPointer[ngfAPIv1alpha1.HashMethodKey]("$remote_addr"), }, }, }, @@ -357,7 +381,8 @@ func TestProcess(t *testing.T) { Time: "5s", Timeout: "10s", }, - LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeIPHash), + LoadBalancingMethod: string(ngfAPIv1alpha1.LoadBalancingTypeHash), + HashMethodKey: "$remote_addr", }, }, } diff --git a/internal/controller/nginx/config/policies/upstreamsettings/validator.go b/internal/controller/nginx/config/policies/upstreamsettings/validator.go index d3b235a3ba..e56857f752 100644 --- a/internal/controller/nginx/config/policies/upstreamsettings/validator.go +++ b/internal/controller/nginx/config/policies/upstreamsettings/validator.go @@ -1,10 +1,14 @@ package upstreamsettings import ( + "fmt" + "strings" + "k8s.io/apimachinery/pkg/util/validation/field" gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" ngfAPI "github.com/nginx/nginx-gateway-fabric/v2/apis/v1alpha1" + httpConfig "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/http" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/policies" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/conditions" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/validation" @@ -16,11 +20,15 @@ import ( // Implements policies.Validator interface. type Validator struct { genericValidator validation.GenericValidator + plusEnabled bool } // NewValidator returns a new Validator. -func NewValidator(genericValidator validation.GenericValidator) Validator { - return Validator{genericValidator: genericValidator} +func NewValidator(genericValidator validation.GenericValidator, plusEnabled bool) Validator { + return Validator{ + genericValidator: genericValidator, + plusEnabled: plusEnabled, + } } // Validate validates the spec of an UpstreamsSettingsPolicy. @@ -83,10 +91,22 @@ func conflicts(a, b ngfAPI.UpstreamSettingsPolicySpec) bool { } } + if checkConflictsForLoadBalancingFields(a, b) { + return true + } + + return false +} + +func checkConflictsForLoadBalancingFields(a, b ngfAPI.UpstreamSettingsPolicySpec) bool { if a.LoadBalancingMethod != nil && b.LoadBalancingMethod != nil { return true } + if a.HashMethodKey != nil && b.HashMethodKey != nil { + return true + } + return false } @@ -107,6 +127,8 @@ func (v Validator) validateSettings(spec ngfAPI.UpstreamSettingsPolicySpec) erro allErrs = append(allErrs, v.validateUpstreamKeepAlive(*spec.KeepAlive, fieldPath.Child("keepAlive"))...) } + allErrs = append(allErrs, v.validateLoadBalancingMethod(spec)...) + return allErrs.ToAggregate() } @@ -134,3 +156,51 @@ func (v Validator) validateUpstreamKeepAlive( return allErrs } + +// ValidateLoadBalancingMethod validates the load balancing method for upstream servers. +func (v Validator) validateLoadBalancingMethod(spec ngfAPI.UpstreamSettingsPolicySpec) field.ErrorList { + if spec.LoadBalancingMethod == nil { + return nil + } + + var allErrs field.ErrorList + path := field.NewPath("spec") + lbPath := path.Child("loadBalancingMethod") + + allowedMethods := httpConfig.OSSAllowedLBMethods + nginxType := "NGINX OSS" + if v.plusEnabled { + allowedMethods = httpConfig.PlusAllowedLBMethods + nginxType = "NGINX Plus" + } + + if _, ok := allowedMethods[*spec.LoadBalancingMethod]; !ok { + allErrs = append(allErrs, field.Invalid( + lbPath, + *spec.LoadBalancingMethod, + fmt.Sprintf( + "%s supports the following load balancing methods: %s", + nginxType, + getLoadBalancingMethodList(allowedMethods), + ), + )) + } + + if spec.HashMethodKey != nil { + hashMethodKey := *spec.HashMethodKey + if err := v.genericValidator.ValidateNginxVariableName(string(hashMethodKey)); err != nil { + path := path.Child("hashMethodKey") + allErrs = append(allErrs, field.Invalid(path, hashMethodKey, err.Error())) + } + } + + return allErrs +} + +func getLoadBalancingMethodList(lbMethods map[ngfAPI.LoadBalancingType]struct{}) string { + methods := make([]string, 0, len(lbMethods)) + for method := range lbMethods { + methods = append(methods, string(method)) + } + return strings.Join(methods, ", ") +} diff --git a/internal/controller/nginx/config/policies/upstreamsettings/validator_test.go b/internal/controller/nginx/config/policies/upstreamsettings/validator_test.go index 1bae51e5bb..43285bdb70 100644 --- a/internal/controller/nginx/config/policies/upstreamsettings/validator_test.go +++ b/internal/controller/nginx/config/policies/upstreamsettings/validator_test.go @@ -16,6 +16,8 @@ import ( "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/kinds" ) +const plusDisabled = false + type policyModFunc func(policy *ngfAPI.UpstreamSettingsPolicy) *ngfAPI.UpstreamSettingsPolicy func createValidPolicy() *ngfAPI.UpstreamSettingsPolicy { @@ -39,6 +41,7 @@ func createValidPolicy() *ngfAPI.UpstreamSettingsPolicy { Connections: helpers.GetPointer[int32](100), }, LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeRandomTwoLeastConnection), + HashMethodKey: helpers.GetPointer[ngfAPI.HashMethodKey]("$upstream_addr"), }, Status: v1.PolicyStatus{}, } @@ -125,7 +128,7 @@ func TestValidator_Validate(t *testing.T) { }, } - v := upstreamsettings.NewValidator(validation.GenericValidator{}) + v := upstreamsettings.NewValidator(validation.GenericValidator{}, plusDisabled) for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -140,7 +143,7 @@ func TestValidator_Validate(t *testing.T) { func TestValidator_ValidatePanics(t *testing.T) { t.Parallel() - v := upstreamsettings.NewValidator(nil) + v := upstreamsettings.NewValidator(nil, plusDisabled) validate := func() { _ = v.Validate(&policiesfakes.FakePolicy{}) @@ -155,7 +158,7 @@ func TestValidator_ValidateGlobalSettings(t *testing.T) { t.Parallel() g := NewWithT(t) - v := upstreamsettings.NewValidator(validation.GenericValidator{}) + v := upstreamsettings.NewValidator(validation.GenericValidator{}, plusDisabled) g.Expect(v.ValidateGlobalSettings(nil, nil)).To(BeNil()) } @@ -258,9 +261,20 @@ func TestValidator_Conflicts(t *testing.T) { }, conflicts: true, }, + { + name: "hash key conflicts", + polA: createValidPolicy(), + polB: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeHashConsistent), + HashMethodKey: helpers.GetPointer[ngfAPI.HashMethodKey]("$upstream_addr"), + }, + }, + conflicts: true, + }, } - v := upstreamsettings.NewValidator(nil) + v := upstreamsettings.NewValidator(nil, plusDisabled) for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -274,7 +288,7 @@ func TestValidator_Conflicts(t *testing.T) { func TestValidator_ConflictsPanics(t *testing.T) { t.Parallel() - v := upstreamsettings.NewValidator(nil) + v := upstreamsettings.NewValidator(nil, plusDisabled) conflicts := func() { _ = v.Conflicts(&policiesfakes.FakePolicy{}, &policiesfakes.FakePolicy{}) @@ -284,3 +298,95 @@ func TestValidator_ConflictsPanics(t *testing.T) { g.Expect(conflicts).To(Panic()) } + +func TestValidate_ValidateLoadBalancingMethod(t *testing.T) { + t.Parallel() + + tests := []struct { + policy *ngfAPI.UpstreamSettingsPolicy + name string + expConditions []conditions.Condition + plusEnabled bool + }{ + { + name: "oss method random with Plus disabled", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeRandom), + }, + }, + expConditions: nil, + }, + { + name: "oss method hash consistent with Plus disabled", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeHashConsistent), + }, + }, + expConditions: nil, + }, + { + name: "plus load balancing method least_time last_byte not allowed with Plus disabled", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeLeastTimeLastByte), + }, + }, + expConditions: []conditions.Condition{ + conditions.NewPolicyInvalid("spec.loadBalancingMethod: Invalid value: \"least_time last_byte\": " + + "NGINX OSS supports the following load balancing methods: "), + }, + }, + { + name: "plus load balancing method least_time header allowed with Plus enabled", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingTypeLeastTimeHeader), + }, + }, + plusEnabled: true, + expConditions: nil, + }, + { + name: "invalid load balancing method for NGINX OSS", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingType("invalid-method")), + }, + }, + expConditions: []conditions.Condition{ + conditions.NewPolicyInvalid("spec.loadBalancingMethod: Invalid value: \"invalid-method\": " + + "NGINX OSS supports the following load balancing methods: "), + }, + }, + { + name: "invalid load balancing method for NGINX Plus", + policy: &ngfAPI.UpstreamSettingsPolicy{ + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(ngfAPI.LoadBalancingType("invalid-method")), + }, + }, + expConditions: []conditions.Condition{ + conditions.NewPolicyInvalid("spec.loadBalancingMethod: Invalid value: \"invalid-method\": " + + "NGINX Plus supports the following load balancing methods: "), + }, + plusEnabled: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + v := upstreamsettings.NewValidator(validation.GenericValidator{}, test.plusEnabled) + conds := v.Validate(test.policy) + + if test.expConditions != nil { + g.Expect(conds).To(HaveLen(1)) + g.Expect(conds[0].Message).To(ContainSubstring(test.expConditions[0].Message)) + } + }) + } +} diff --git a/internal/controller/nginx/config/upstreams.go b/internal/controller/nginx/config/upstreams.go index bf56ec052d..7708ba1d97 100644 --- a/internal/controller/nginx/config/upstreams.go +++ b/internal/controller/nginx/config/upstreams.go @@ -4,6 +4,7 @@ import ( "fmt" gotemplate "text/template" + ngfAPI "github.com/nginx/nginx-gateway-fabric/v2/apis/v1alpha1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/http" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/policies/upstreamsettings" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/stream" @@ -162,6 +163,22 @@ func (g GeneratorImpl) createUpstream( zoneSize = upstreamPolicySettings.ZoneSize } + chosenLBMethod := defaultLBMethod + if upstreamPolicySettings.LoadBalancingMethod != "" { + lbMethod := upstreamPolicySettings.LoadBalancingMethod + + if lbMethod == string(ngfAPI.LoadBalancingTypeHash) { + lbMethod = fmt.Sprintf("hash %s", upstreamPolicySettings.HashMethodKey) + } + if lbMethod == string(ngfAPI.LoadBalancingTypeHashConsistent) { + lbMethod = fmt.Sprintf("hash %s consistent", upstreamPolicySettings.HashMethodKey) + } + if lbMethod == string(ngfAPI.LoadBalancingTypeRoundRobin) { + lbMethod = "" + } + chosenLBMethod = lbMethod + } + if len(up.Endpoints) == 0 { return http.Upstream{ Name: up.Name, @@ -172,6 +189,7 @@ func (g GeneratorImpl) createUpstream( Address: types.Nginx503Server, }, }, + LoadBalancingMethod: chosenLBMethod, } } @@ -187,11 +205,6 @@ func (g GeneratorImpl) createUpstream( } } - chosenLBMethod := defaultLBMethod - if upstreamPolicySettings.LoadBalancingMethod != "" { - chosenLBMethod = upstreamPolicySettings.LoadBalancingMethod - } - return http.Upstream{ Name: up.Name, ZoneSize: zoneSize, diff --git a/internal/controller/nginx/config/upstreams_test.go b/internal/controller/nginx/config/upstreams_test.go index c87761b35f..f387fa4ed9 100644 --- a/internal/controller/nginx/config/upstreams_test.go +++ b/internal/controller/nginx/config/upstreams_test.go @@ -105,7 +105,7 @@ func TestExecuteUpstreams(t *testing.T) { "zone up5-usp 2m;": 1, "ip_hash;": 1, - "random two least_conn;": 3, + "random two least_conn;": 4, } upstreams := gen.createUpstreams(stateUpstreams, upstreamsettings.NewProcessor()) @@ -233,6 +233,7 @@ func TestCreateUpstreams(t *testing.T) { Address: types.Nginx503Server, }, }, + LoadBalancingMethod: defaultLBMethod, }, { Name: "up4-ipv6", @@ -296,6 +297,7 @@ func TestCreateUpstream(t *testing.T) { Address: types.Nginx503Server, }, }, + LoadBalancingMethod: defaultLBMethod, }, msg: "nil endpoints", }, @@ -312,6 +314,7 @@ func TestCreateUpstream(t *testing.T) { Address: types.Nginx503Server, }, }, + LoadBalancingMethod: defaultLBMethod, }, msg: "no endpoints", }, @@ -705,6 +708,7 @@ func TestCreateUpstreamPlus(t *testing.T) { Address: types.Nginx503Server, }, }, + LoadBalancingMethod: defaultLBMethod, }, }, } @@ -1198,3 +1202,206 @@ func TestKeepAliveChecker(t *testing.T) { }) } } + +func TestExecuteUpstreams_LoadBalancingMethod(t *testing.T) { + t.Parallel() + + tests := []struct { + expectedSubStrings map[string]int + name string + lbType ngfAPI.LoadBalancingType + HashMethodKey ngfAPI.HashMethodKey + }{ + { + name: "default load balancing method", + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random two least_conn;": 2, + }, + }, + { + name: "round_robin load balancing method", + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + }, + }, + { + name: "least_conn load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastConnection, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_conn;": 2, + }, + }, + { + name: "ip_hash load balancing method", + lbType: ngfAPI.LoadBalancingTypeIPHash, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "ip_hash;": 2, + }, + }, + { + name: "hash load balancing method with specific hash key", + lbType: ngfAPI.LoadBalancingTypeHash, + HashMethodKey: ngfAPI.HashMethodKey("$request_uri"), + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "hash $request_uri;": 2, + }, + }, + { + name: "hash consistent load balancing method with specific hash key", + lbType: ngfAPI.LoadBalancingTypeHashConsistent, + HashMethodKey: ngfAPI.HashMethodKey("$remote_addr"), + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "hash $remote_addr consistent;": 2, + }, + }, + { + name: "random load balancing method", + lbType: ngfAPI.LoadBalancingTypeRandom, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random;": 2, + }, + }, + { + name: "random two load balancing method", + lbType: ngfAPI.LoadBalancingTypeRandomTwo, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random two;": 2, + }, + }, + { + name: "random two least_time=header load balancing method", + lbType: ngfAPI.LoadBalancingTypeRandomTwoLeastTimeHeader, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random two least_time=header;": 2, + }, + }, + { + name: "random two least_time=last_byte load balancing method", + lbType: ngfAPI.LoadBalancingTypeRandomTwoLeastTimeLastByte, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "random two least_time=last_byte;": 2, + }, + }, + { + name: "least_time header load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastTimeHeader, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_time header;": 2, + }, + }, + { + name: "least_time last_byte load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastTimeLastByte, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_time last_byte;": 2, + }, + }, + { + name: "least_time header inflight load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastTimeHeaderInflight, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_time header inflight;": 2, + }, + }, + { + name: "least_time last_byte inflight load balancing method", + lbType: ngfAPI.LoadBalancingTypeLeastTimeLastByteInflight, + expectedSubStrings: map[string]int{ + "upstream up1-usp-ipv4": 1, + "upstream up2-usp-ipv6": 1, + "least_time last_byte inflight;": 2, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + gen := GeneratorImpl{} + stateUpstreams := []dataplane.Upstream{ + { + Name: "up1-usp-ipv4", + Endpoints: []resolver.Endpoint{ + { + Address: "12.0.0.0", + Port: 80, + }, + }, + Policies: []policies.Policy{ + &ngfAPI.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp-ipv4", + Namespace: "test", + }, + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(tt.lbType), + HashMethodKey: helpers.GetPointer(tt.HashMethodKey), + }, + }, + }, + }, + { + Name: "up2-usp-ipv6", + Endpoints: []resolver.Endpoint{ + { + Address: "2001:db8::1", + Port: 80, + }, + }, + Policies: []policies.Policy{ + &ngfAPI.UpstreamSettingsPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "usp-ipv6", + Namespace: "test", + }, + Spec: ngfAPI.UpstreamSettingsPolicySpec{ + LoadBalancingMethod: helpers.GetPointer(tt.lbType), + HashMethodKey: helpers.GetPointer(tt.HashMethodKey), + }, + }, + }, + }, + } + + upstreams := gen.createUpstreams(stateUpstreams, upstreamsettings.NewProcessor()) + upstreamResults := executeUpstreams(upstreams) + + g.Expect(upstreamResults).To(HaveLen(1)) + nginxUpstreams := string(upstreamResults[0].data) + + for expSubString, expectedCount := range tt.expectedSubStrings { + actualCount := strings.Count(nginxUpstreams, expSubString) + g.Expect(actualCount).To( + Equal(expectedCount), + fmt.Sprintf("substring %q expected %d occurrence(s), got %d", expSubString, expectedCount, actualCount), + ) + } + }) + } +} diff --git a/internal/controller/nginx/config/validation/generic.go b/internal/controller/nginx/config/validation/generic.go index 8342ab4134..f63073955b 100644 --- a/internal/controller/nginx/config/validation/generic.go +++ b/internal/controller/nginx/config/validation/generic.go @@ -106,3 +106,24 @@ func (GenericValidator) ValidateEndpoint(endpoint string) error { return nil } + +const ( + variableNameFmt = `\$[a-z_]+` + variableNameErrMsg = "must start with '$' followed by lowercase letters and underscores only" +) + +var variableNameRegexp = regexp.MustCompile("^" + variableNameFmt + "$") + +// ValidateNginxVariableName validates an nginx variable name. +func (GenericValidator) ValidateNginxVariableName(name string) error { + if !variableNameRegexp.MatchString(name) { + examples := []string{ + "$upstream_addr", + "$remote_addr", + } + + return errors.New(k8svalidation.RegexError(variableNameFmt, variableNameErrMsg, examples...)) + } + + return nil +} diff --git a/internal/controller/nginx/config/validation/generic_test.go b/internal/controller/nginx/config/validation/generic_test.go index 5f57b51c56..73be3f10cb 100644 --- a/internal/controller/nginx/config/validation/generic_test.go +++ b/internal/controller/nginx/config/validation/generic_test.go @@ -112,3 +112,25 @@ func TestValidateEndpoint(t *testing.T) { `my$endpoint`, ) } + +func TestValidateNginxVariableName(t *testing.T) { + t.Parallel() + validator := GenericValidator{} + + testValidValuesForSimpleValidator( + t, + validator.ValidateNginxVariableName, + `$upstream_bytes_sent`, + `$upstream_last_server_name`, + `$remote_addr`, + ) + + testInvalidValuesForSimpleValidator( + t, + validator.ValidateNginxVariableName, + `1varname`, + `var-name`, + `var name`, + `var$name`, + ) +} diff --git a/internal/controller/state/graph/policies_test.go b/internal/controller/state/graph/policies_test.go index 251ab5aeb8..e4532f5a54 100644 --- a/internal/controller/state/graph/policies_test.go +++ b/internal/controller/state/graph/policies_test.go @@ -242,7 +242,7 @@ func TestAttachPolicies(t *testing.T) { NGFPolicies: test.ngfPolicies, } - graph.attachPolicies(nil, "nginx-gateway", logr.Discard()) + graph.attachPolicies(&policiesfakes.FakeValidator{}, "nginx-gateway", logr.Discard()) for _, expect := range test.expects { expect(g, graph) } diff --git a/internal/controller/state/validation/validationfakes/fake_generic_validator.go b/internal/controller/state/validation/validationfakes/fake_generic_validator.go index 8c83a4ff9a..cd162c7359 100644 --- a/internal/controller/state/validation/validationfakes/fake_generic_validator.go +++ b/internal/controller/state/validation/validationfakes/fake_generic_validator.go @@ -52,6 +52,17 @@ type FakeGenericValidator struct { validateNginxSizeReturnsOnCall map[int]struct { result1 error } + ValidateNginxVariableNameStub func(string) error + validateNginxVariableNameMutex sync.RWMutex + validateNginxVariableNameArgsForCall []struct { + arg1 string + } + validateNginxVariableNameReturns struct { + result1 error + } + validateNginxVariableNameReturnsOnCall map[int]struct { + result1 error + } ValidateServiceNameStub func(string) error validateServiceNameMutex sync.RWMutex validateServiceNameArgsForCall []struct { @@ -311,6 +322,67 @@ func (fake *FakeGenericValidator) ValidateNginxSizeReturnsOnCall(i int, result1 }{result1} } +func (fake *FakeGenericValidator) ValidateNginxVariableName(arg1 string) error { + fake.validateNginxVariableNameMutex.Lock() + ret, specificReturn := fake.validateNginxVariableNameReturnsOnCall[len(fake.validateNginxVariableNameArgsForCall)] + fake.validateNginxVariableNameArgsForCall = append(fake.validateNginxVariableNameArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.ValidateNginxVariableNameStub + fakeReturns := fake.validateNginxVariableNameReturns + fake.recordInvocation("ValidateNginxVariableName", []interface{}{arg1}) + fake.validateNginxVariableNameMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameCallCount() int { + fake.validateNginxVariableNameMutex.RLock() + defer fake.validateNginxVariableNameMutex.RUnlock() + return len(fake.validateNginxVariableNameArgsForCall) +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameCalls(stub func(string) error) { + fake.validateNginxVariableNameMutex.Lock() + defer fake.validateNginxVariableNameMutex.Unlock() + fake.ValidateNginxVariableNameStub = stub +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameArgsForCall(i int) string { + fake.validateNginxVariableNameMutex.RLock() + defer fake.validateNginxVariableNameMutex.RUnlock() + argsForCall := fake.validateNginxVariableNameArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameReturns(result1 error) { + fake.validateNginxVariableNameMutex.Lock() + defer fake.validateNginxVariableNameMutex.Unlock() + fake.ValidateNginxVariableNameStub = nil + fake.validateNginxVariableNameReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeGenericValidator) ValidateNginxVariableNameReturnsOnCall(i int, result1 error) { + fake.validateNginxVariableNameMutex.Lock() + defer fake.validateNginxVariableNameMutex.Unlock() + fake.ValidateNginxVariableNameStub = nil + if fake.validateNginxVariableNameReturnsOnCall == nil { + fake.validateNginxVariableNameReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.validateNginxVariableNameReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeGenericValidator) ValidateServiceName(arg1 string) error { fake.validateServiceNameMutex.Lock() ret, specificReturn := fake.validateServiceNameReturnsOnCall[len(fake.validateServiceNameArgsForCall)] diff --git a/internal/controller/state/validation/validator.go b/internal/controller/state/validation/validator.go index 10dc1fe8c3..5c1bd54e4e 100644 --- a/internal/controller/state/validation/validator.go +++ b/internal/controller/state/validation/validator.go @@ -49,6 +49,7 @@ type GenericValidator interface { ValidateNginxDuration(duration string) error ValidateNginxSize(size string) error ValidateEndpoint(endpoint string) error + ValidateNginxVariableName(name string) error } // PolicyValidator validates an NGF Policy. diff --git a/tests/cel/common.go b/tests/cel/common.go index 6208c1a75f..067888a2d3 100644 --- a/tests/cel/common.go +++ b/tests/cel/common.go @@ -56,9 +56,11 @@ const ( // UpstreamSettingsPolicy validation errors. const ( - expectedTargetRefKindServiceError = `TargetRefs Kind must be: Service` - expectedTargetRefGroupCoreError = `TargetRefs Group must be core` - expectedTargetRefNameUniqueError = `TargetRef Name must be unique` + expectedTargetRefKindServiceError = `TargetRefs Kind must be: Service` + expectedTargetRefGroupCoreError = `TargetRefs Group must be core` + expectedTargetRefNameUniqueError = `TargetRef Name must be unique` + expectedHashKeyLoadBalancingTypeError = `hashMethodKey is required when loadBalancingMethod ` + + `is 'hash' or 'hash consistent'` ) // SnippetsFilter validation errors. diff --git a/tests/cel/upstreamsettingspolicy_test.go b/tests/cel/upstreamsettingspolicy_test.go index 35a4fbd364..5b0267350f 100644 --- a/tests/cel/upstreamsettingspolicy_test.go +++ b/tests/cel/upstreamsettingspolicy_test.go @@ -7,6 +7,7 @@ import ( gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" ngfAPIv1alpha1 "github.com/nginx/nginx-gateway-fabric/v2/apis/v1alpha1" + "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers" ) func TestUpstreamSettingsPolicyTargetRefKind(t *testing.T) { @@ -372,3 +373,86 @@ func TestUpstreamSettingsPolicyTargetRefNameUniqueness(t *testing.T) { }) } } + +func TestUpstreamSettingsPolicy_LoadBalancing(t *testing.T) { + t.Parallel() + k8sClient := getKubernetesClient(t) + + tests := []struct { + spec ngfAPIv1alpha1.UpstreamSettingsPolicySpec + name string + wantErrors []string + }{ + { + name: "when load balancing method is hash, hash key is required, error expected", + spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + TargetRefs: []gatewayv1.LocalPolicyTargetReference{ + { + Kind: serviceKind, + Group: coreGroup, + }, + }, + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHash), + }, + wantErrors: []string{expectedHashKeyLoadBalancingTypeError}, + }, + { + name: "when load balancing method is hash consistent, hash key is required, error expected", + spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + TargetRefs: []gatewayv1.LocalPolicyTargetReference{ + { + Kind: serviceKind, + Group: coreGroup, + }, + }, + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + }, + wantErrors: []string{expectedHashKeyLoadBalancingTypeError}, + }, + { + name: "specify load balancing method as hash and set the hash key, no error expected", + spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + TargetRefs: []gatewayv1.LocalPolicyTargetReference{ + { + Kind: serviceKind, + Group: coreGroup, + }, + }, + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHash), + HashMethodKey: helpers.GetPointer(ngfAPIv1alpha1.HashMethodKey("$upstream_connect_time")), + }, + }, + { + name: "specify load balancing method as hash consistent and set the hash key, no error expected", + spec: ngfAPIv1alpha1.UpstreamSettingsPolicySpec{ + TargetRefs: []gatewayv1.LocalPolicyTargetReference{ + { + Kind: serviceKind, + Group: coreGroup, + }, + }, + LoadBalancingMethod: helpers.GetPointer(ngfAPIv1alpha1.LoadBalancingTypeHashConsistent), + HashMethodKey: helpers.GetPointer(ngfAPIv1alpha1.HashMethodKey("$upstream_bytes_sent")), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + for i := range tt.spec.TargetRefs { + tt.spec.TargetRefs[i].Name = gatewayv1.ObjectName(uniqueResourceName(testTargetRefName)) + } + + upstreamSettingsPolicy := &ngfAPIv1alpha1.UpstreamSettingsPolicy{ + ObjectMeta: controllerruntime.ObjectMeta{ + Name: uniqueResourceName(testResourceName), + Namespace: defaultNamespace, + }, + Spec: tt.spec, + } + validateCrd(t, tt.wantErrors, upstreamSettingsPolicy, k8sClient) + }) + } +}