diff --git a/conformance/tests/grpcroute-weight.go b/conformance/tests/grpcroute-weight.go new file mode 100644 index 0000000000..5086b5bb08 --- /dev/null +++ b/conformance/tests/grpcroute-weight.go @@ -0,0 +1,98 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tests + +import ( + "fmt" + "testing" + + "google.golang.org/grpc/codes" + "k8s.io/apimachinery/pkg/types" + + v1 "sigs.k8s.io/gateway-api/apis/v1" + pb "sigs.k8s.io/gateway-api/conformance/echo-basic/grpcechoserver" + "sigs.k8s.io/gateway-api/conformance/utils/grpc" + "sigs.k8s.io/gateway-api/conformance/utils/kubernetes" + "sigs.k8s.io/gateway-api/conformance/utils/suite" + "sigs.k8s.io/gateway-api/conformance/utils/weight" + "sigs.k8s.io/gateway-api/pkg/features" +) + +func init() { + ConformanceTests = append(ConformanceTests, GRPCRouteWeight) +} + +var GRPCRouteWeight = suite.ConformanceTest{ + ShortName: "GRPCRouteWeight", + Description: "An GRPCRoute with weighted backends", + Manifests: []string{"tests/grpcroute-weight.yaml"}, + Features: []features.FeatureName{ + features.SupportGateway, + features.SupportGRPCRoute, + }, + Test: func(t *testing.T, suite *suite.ConformanceTestSuite) { + var ( + ns = "gateway-conformance-infra" + routeNN = types.NamespacedName{Name: "weighted-backends", Namespace: ns} + gwNN = types.NamespacedName{Name: "same-namespace", Namespace: ns} + gwAddr = kubernetes.GatewayAndRoutesMustBeAccepted(t, suite.Client, suite.TimeoutConfig, suite.ControllerName, kubernetes.NewGatewayRef(gwNN), &v1.GRPCRoute{}, true, routeNN) + ) + + t.Run("Requests should have a distribution that matches the weight", func(t *testing.T) { + expected := grpc.ExpectedResponse{ + EchoRequest: &pb.EchoRequest{}, + Response: grpc.Response{Code: codes.OK}, + Namespace: "gateway-conformance-infra", + } + + // Assert request succeeds before doing our distribution check + grpc.MakeRequestAndExpectEventuallyConsistentResponse(t, suite.GRPCClient, suite.TimeoutConfig, gwAddr, expected) + + expectedWeights := map[string]float64{ + "grpc-infra-backend-v1": 0.7, + "grpc-infra-backend-v2": 0.3, + "grpc-infra-backend-v3": 0.0, + } + + sender := weight.NewFunctionBasedSender(func() (string, error) { + uniqueExpected := expected + if err := grpc.AddEntropy(&uniqueExpected); err != nil { + return "", fmt.Errorf("error adding entropy: %w", err) + } + client := &grpc.DefaultClient{} + defer client.Close() + resp, err := client.SendRPC(t, gwAddr, uniqueExpected, suite.TimeoutConfig.MaxTimeToConsistency) + if err != nil { + return "", fmt.Errorf("failed to send gRPC request: %w", err) + } + if resp.Code != codes.OK { + return "", fmt.Errorf("expected OK response, got %v", resp.Code) + } + return resp.Response.GetAssertions().GetContext().GetPod(), nil + }) + + for i := 0; i < 10; i++ { + if err := weight.TestWeightedDistribution(sender, expectedWeights); err != nil { + t.Logf("Traffic distribution test failed (%d/10): %s", i+1, err) + } else { + return + } + } + t.Fatal("Weighted distribution tests failed") + }) + }, +} diff --git a/conformance/tests/grpcroute-weight.yaml b/conformance/tests/grpcroute-weight.yaml new file mode 100644 index 0000000000..944d3b7486 --- /dev/null +++ b/conformance/tests/grpcroute-weight.yaml @@ -0,0 +1,19 @@ +apiVersion: gateway.networking.k8s.io/v1 +kind: GRPCRoute +metadata: + name: weighted-backends + namespace: gateway-conformance-infra +spec: + parentRefs: + - name: same-namespace + rules: + - backendRefs: + - name: grpc-infra-backend-v1 + port: 8080 + weight: 70 + - name: grpc-infra-backend-v2 + port: 8080 + weight: 30 + - name: grpc-infra-backend-v3 + port: 8080 + weight: 0 diff --git a/conformance/tests/httproute-weight.go b/conformance/tests/httproute-weight.go index fa85deb9bf..50ce9efd89 100644 --- a/conformance/tests/httproute-weight.go +++ b/conformance/tests/httproute-weight.go @@ -17,21 +17,15 @@ limitations under the License. package tests import ( - "cmp" - "errors" "fmt" - "math" - "slices" - "strings" - "sync" "testing" - "golang.org/x/sync/errgroup" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api/conformance/utils/http" "sigs.k8s.io/gateway-api/conformance/utils/kubernetes" "sigs.k8s.io/gateway-api/conformance/utils/suite" + "sigs.k8s.io/gateway-api/conformance/utils/weight" "sigs.k8s.io/gateway-api/pkg/features" ) @@ -67,8 +61,30 @@ var HTTPRouteWeight = suite.ConformanceTest{ // Assert request succeeds before doing our distribution check http.MakeRequestAndExpectEventuallyConsistentResponse(t, suite.RoundTripper, suite.TimeoutConfig, gwAddr, expected) + expectedWeights := map[string]float64{ + "infra-backend-v1": 0.7, + "infra-backend-v2": 0.3, + "infra-backend-v3": 0.0, + } + + sender := weight.NewFunctionBasedSender(func() (string, error) { + uniqueExpected := expected + if err := http.AddEntropy(&uniqueExpected); err != nil { + return "", fmt.Errorf("error adding entropy: %w", err) + } + req := http.MakeRequest(t, &uniqueExpected, gwAddr, "HTTP", "http") + cReq, cRes, err := suite.RoundTripper.CaptureRoundTrip(req) + if err != nil { + return "", fmt.Errorf("failed to roundtrip request: %w", err) + } + if err := http.CompareRoundTrip(t, &req, cReq, cRes, expected); err != nil { + return "", fmt.Errorf("response expectation failed for request: %w", err) + } + return cReq.Pod, nil + }) + for i := 0; i < 10; i++ { - if err := testDistribution(t, suite, gwAddr, expected); err != nil { + if err := weight.TestWeightedDistribution(sender, expectedWeights); err != nil { t.Logf("Traffic distribution test failed (%d/10): %s", i+1, err) } else { return @@ -78,85 +94,3 @@ var HTTPRouteWeight = suite.ConformanceTest{ }) }, } - -func testDistribution(t *testing.T, suite *suite.ConformanceTestSuite, gwAddr string, expected http.ExpectedResponse) error { - const ( - concurrentRequests = 10 - tolerancePercentage = 0.05 - totalRequests = 500.0 - ) - var ( - roundTripper = suite.RoundTripper - - g errgroup.Group - seenMutex sync.Mutex - seen = make(map[string]float64, 3 /* number of backends */) - expectedWeights = map[string]float64{ - "infra-backend-v1": 0.7, - "infra-backend-v2": 0.3, - "infra-backend-v3": 0.0, - } - ) - g.SetLimit(concurrentRequests) - for i := 0.0; i < totalRequests; i++ { - g.Go(func() error { - uniqueExpected := expected - if err := http.AddEntropy(&uniqueExpected); err != nil { - return fmt.Errorf("error adding entropy: %w", err) - } - req := http.MakeRequest(t, &uniqueExpected, gwAddr, "HTTP", "http") - cReq, cRes, err := roundTripper.CaptureRoundTrip(req) - if err != nil { - return fmt.Errorf("failed to roundtrip request: %w", err) - } - if err := http.CompareRoundTrip(t, &req, cReq, cRes, expected); err != nil { - return fmt.Errorf("response expectation failed for request: %w", err) - } - - seenMutex.Lock() - defer seenMutex.Unlock() - - for expectedBackend := range expectedWeights { - if strings.HasPrefix(cReq.Pod, expectedBackend) { - seen[expectedBackend]++ - return nil - } - } - - return fmt.Errorf("request was handled by an unexpected pod %q", cReq.Pod) - }) - } - - if err := g.Wait(); err != nil { - return fmt.Errorf("error while sending requests: %w", err) - } - - var errs []error - if len(seen) != 2 { - errs = append(errs, fmt.Errorf("expected only two backends to receive traffic")) - } - - for wantBackend, wantPercent := range expectedWeights { - gotCount, ok := seen[wantBackend] - - if !ok && wantPercent != 0.0 { - errs = append(errs, fmt.Errorf("expect traffic to hit backend %q - but none was received", wantBackend)) - continue - } - - gotPercent := gotCount / totalRequests - - if math.Abs(gotPercent-wantPercent) > tolerancePercentage { - errs = append(errs, fmt.Errorf("backend %q weighted traffic of %v not within tolerance %v (+/-%f)", - wantBackend, - gotPercent, - wantPercent, - tolerancePercentage, - )) - } - } - slices.SortFunc(errs, func(a, b error) int { - return cmp.Compare(a.Error(), b.Error()) - }) - return errors.Join(errs...) -} diff --git a/conformance/tests/mesh/grpcroute-weight.go b/conformance/tests/mesh/grpcroute-weight.go new file mode 100644 index 0000000000..37dff25c70 --- /dev/null +++ b/conformance/tests/mesh/grpcroute-weight.go @@ -0,0 +1,83 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package meshtests + +import ( + "fmt" + "testing" + + "sigs.k8s.io/gateway-api/conformance/utils/echo" + "sigs.k8s.io/gateway-api/conformance/utils/http" + "sigs.k8s.io/gateway-api/conformance/utils/suite" + "sigs.k8s.io/gateway-api/conformance/utils/weight" + "sigs.k8s.io/gateway-api/pkg/features" +) + +func init() { + MeshConformanceTests = append(MeshConformanceTests, MeshGRPCRouteWeight) +} + +var MeshGRPCRouteWeight = suite.ConformanceTest{ + ShortName: "MeshGRPCRouteWeight", + Description: "A GRPCRoute with weighted backends in mesh mode", + Manifests: []string{"tests/mesh/grpcroute-weight.yaml"}, + Features: []features.FeatureName{ + features.SupportMesh, + features.SupportGRPCRoute, + }, + Test: func(t *testing.T, s *suite.ConformanceTestSuite) { + client := echo.ConnectToApp(t, s, echo.MeshAppEchoV1) + + t.Run("Requests should have a distribution that matches the weight", func(t *testing.T) { + // Create a gRPC request using the mesh client framework + expected := http.ExpectedResponse{ + Request: http.Request{Protocol: "grpc", Path: "", Host: "echo:7070"}, + Response: http.Response{StatusCode: 200}, + Namespace: "gateway-conformance-mesh", + } + + // Assert request succeeds before doing our distribution check + client.MakeRequestAndExpectEventuallyConsistentResponse(t, expected, s.TimeoutConfig) + + expectedWeights := map[string]float64{ + "echo-v1": 0.7, + "echo-v2": 0.3, + } + + sender := weight.NewFunctionBasedSender(func() (string, error) { + uniqueExpected := expected + if err := http.AddEntropy(&uniqueExpected); err != nil { + return "", fmt.Errorf("error adding entropy: %w", err) + } + _, cRes, err := client.CaptureRequestResponseAndCompare(t, uniqueExpected) + if err != nil { + return "", fmt.Errorf("failed gRPC mesh request: %w", err) + } + return cRes.Hostname, nil + }) + + for i := 0; i < 10; i++ { + if err := weight.TestWeightedDistribution(sender, expectedWeights); err != nil { + t.Logf("Traffic distribution test failed (%d/10): %s", i+1, err) + } else { + return + } + } + t.Fatal("Weighted distribution tests failed") + }) + }, +} diff --git a/conformance/tests/mesh/grpcroute-weight.yaml b/conformance/tests/mesh/grpcroute-weight.yaml new file mode 100644 index 0000000000..ef0d424196 --- /dev/null +++ b/conformance/tests/mesh/grpcroute-weight.yaml @@ -0,0 +1,19 @@ +apiVersion: gateway.networking.k8s.io/v1 +kind: GRPCRoute +metadata: + name: mesh-grpc-weighted-backends + namespace: gateway-conformance-mesh +spec: + parentRefs: + - group: "" + kind: Service + name: echo + port: 7070 + rules: + - backendRefs: + - name: echo-v1 + port: 7070 + weight: 70 + - name: echo-v2 + port: 7070 + weight: 30 diff --git a/conformance/tests/mesh/httproute-weight.go b/conformance/tests/mesh/httproute-weight.go index 844d5471e7..e2d565006d 100644 --- a/conformance/tests/mesh/httproute-weight.go +++ b/conformance/tests/mesh/httproute-weight.go @@ -17,20 +17,13 @@ limitations under the License. package meshtests import ( - "cmp" - "errors" "fmt" - "math" - "slices" - "strings" - "sync" "testing" - "golang.org/x/sync/errgroup" - "sigs.k8s.io/gateway-api/conformance/utils/echo" "sigs.k8s.io/gateway-api/conformance/utils/http" "sigs.k8s.io/gateway-api/conformance/utils/suite" + "sigs.k8s.io/gateway-api/conformance/utils/weight" "sigs.k8s.io/gateway-api/pkg/features" ) @@ -59,8 +52,26 @@ var MeshHTTPRouteWeight = suite.ConformanceTest{ // Assert request succeeds before doing our distribution check client.MakeRequestAndExpectEventuallyConsistentResponse(t, expected, s.TimeoutConfig) + + expectedWeights := map[string]float64{ + "echo-v1": 0.7, + "echo-v2": 0.3, + } + + sender := weight.NewFunctionBasedSender(func() (string, error) { + uniqueExpected := expected + if err := http.AddEntropy(&uniqueExpected); err != nil { + return "", fmt.Errorf("error adding entropy: %w", err) + } + _, cRes, err := client.CaptureRequestResponseAndCompare(t, uniqueExpected) + if err != nil { + return "", fmt.Errorf("failed mesh request: %w", err) + } + return cRes.Hostname, nil + }) + for i := 0; i < 10; i++ { - if err := testDistribution(t, client, expected); err != nil { + if err := weight.TestWeightedDistribution(sender, expectedWeights); err != nil { t.Logf("Traffic distribution test failed (%d/10): %s", i+1, err) } else { return @@ -70,78 +81,3 @@ var MeshHTTPRouteWeight = suite.ConformanceTest{ }) }, } - -func testDistribution(t *testing.T, client echo.MeshPod, expected http.ExpectedResponse) error { - const ( - concurrentRequests = 10 - tolerancePercentage = 0.05 - totalRequests = 500.0 - ) - var ( - g errgroup.Group - seenMutex sync.Mutex - seen = make(map[string]float64, 2 /* number of backends */) - expectedWeights = map[string]float64{ - "echo-v1": 0.7, - "echo-v2": 0.3, - } - ) - g.SetLimit(concurrentRequests) - for i := 0.0; i < totalRequests; i++ { - g.Go(func() error { - uniqueExpected := expected - if err := http.AddEntropy(&uniqueExpected); err != nil { - return fmt.Errorf("error adding entropy: %w", err) - } - _, cRes, err := client.CaptureRequestResponseAndCompare(t, uniqueExpected) - if err != nil { - return fmt.Errorf("failed: %w", err) - } - - seenMutex.Lock() - defer seenMutex.Unlock() - - for expectedBackend := range expectedWeights { - if strings.HasPrefix(cRes.Hostname, expectedBackend) { - seen[expectedBackend]++ - return nil - } - } - - return fmt.Errorf("request was handled by an unexpected pod %q", cRes.Hostname) - }) - } - - if err := g.Wait(); err != nil { - return fmt.Errorf("error while sending requests: %w", err) - } - - var errs []error - if len(seen) != 2 { - errs = append(errs, fmt.Errorf("expected only two backends to receive traffic")) - } - - for wantBackend, wantPercent := range expectedWeights { - gotCount, ok := seen[wantBackend] - - if !ok && wantPercent != 0.0 { - errs = append(errs, fmt.Errorf("expect traffic to hit backend %q - but none was received", wantBackend)) - continue - } - - gotPercent := gotCount / totalRequests - - if math.Abs(gotPercent-wantPercent) > tolerancePercentage { - errs = append(errs, fmt.Errorf("backend %q weighted traffic of %v not within tolerance %v (+/-%f)", - wantBackend, - gotPercent, - wantPercent, - tolerancePercentage, - )) - } - } - slices.SortFunc(errs, func(a, b error) int { - return cmp.Compare(a.Error(), b.Error()) - }) - return errors.Join(errs...) -} diff --git a/conformance/utils/echo/pod.go b/conformance/utils/echo/pod.go index 0222ff12dd..4101053650 100644 --- a/conformance/utils/echo/pod.go +++ b/conformance/utils/echo/pod.go @@ -80,7 +80,15 @@ func makeRequest(t *testing.T, exp *http.ExpectedResponse) []string { if exp.Request.Host == "" { exp.Request.Host = "echo" } - if exp.Request.Method == "" { + + r := exp.Request + protocol := strings.ToLower(r.Protocol) + if protocol == "" { + protocol = "http" + } + + // Only set default method for HTTP protocols, not for gRPC + if exp.Request.Method == "" && protocol != "grpc" { exp.Request.Method = "GET" } @@ -88,14 +96,9 @@ func makeRequest(t *testing.T, exp *http.ExpectedResponse) []string { exp.Response.StatusCode = 200 } - r := exp.Request - protocol := strings.ToLower(r.Protocol) - if protocol == "" { - protocol = "http" - } host := http.CalculateHost(t, r.Host, protocol) args := []string{"client", fmt.Sprintf("%s://%s%s", protocol, host, r.Path)} - if r.Method != "" { + if r.Method != "" && protocol != "grpc" { args = append(args, "--method="+r.Method) } for k, v := range r.Headers { diff --git a/conformance/utils/grpc/grpc.go b/conformance/utils/grpc/grpc.go index 70adfe2eac..25b1d7abf8 100644 --- a/conformance/utils/grpc/grpc.go +++ b/conformance/utils/grpc/grpc.go @@ -35,6 +35,7 @@ import ( "sigs.k8s.io/gateway-api/conformance/utils/config" "sigs.k8s.io/gateway-api/conformance/utils/http" "sigs.k8s.io/gateway-api/conformance/utils/tlog" + "sigs.k8s.io/gateway-api/conformance/utils/weight" ) const ( @@ -289,3 +290,20 @@ func MakeRequestAndExpectEventuallyConsistentResponse(t *testing.T, c Client, ti http.AwaitConvergence(t, timeoutConfig.RequiredConsecutiveSuccesses, timeoutConfig.MaxTimeToConsistency, sendRPC) tlog.Logf(t, "Request passed") } + +// AddEntropy adds randomness to ExpectedResponse to avoid caching issues and ensure each request is unique. +// It randomly chooses to add a delay, random metadata, or both. +func AddEntropy(exp *ExpectedResponse) error { + addRandomMetadata := func(randomValue string) error { + if exp.RequestMetadata == nil { + exp.RequestMetadata = &RequestMetadata{} + } + if exp.RequestMetadata.Metadata == nil { + exp.RequestMetadata.Metadata = make(map[string]string) + } + exp.RequestMetadata.Metadata["x-jitter"] = randomValue + return nil + } + + return weight.AddRandomEntropy(addRandomMetadata) +} diff --git a/conformance/utils/http/http.go b/conformance/utils/http/http.go index da7036daf6..baff0a7b1f 100644 --- a/conformance/utils/http/http.go +++ b/conformance/utils/http/http.go @@ -17,9 +17,7 @@ limitations under the License. package http import ( - "crypto/rand" "fmt" - "math/big" "net" "net/url" "strings" @@ -29,6 +27,7 @@ import ( "sigs.k8s.io/gateway-api/conformance/utils/config" "sigs.k8s.io/gateway-api/conformance/utils/roundtripper" "sigs.k8s.io/gateway-api/conformance/utils/tlog" + "sigs.k8s.io/gateway-api/conformance/utils/weight" ) // ExpectedResponse defines the response expected for a given request. @@ -481,53 +480,13 @@ func setRedirectRequestDefaults(req *roundtripper.Request, cRes *roundtripper.Ca } } -// addEntropy adds jitter to the request by adding either a delay up to 1 second, or a random header value, or both. +// AddEntropy adds jitter to the request by adding either a delay up to 1 second, or a random header value, or both. func AddEntropy(exp *ExpectedResponse) error { - randomNumber := func(limit int64) (*int64, error) { - number, err := rand.Int(rand.Reader, big.NewInt(limit)) - if err != nil { - return nil, err - } - n := number.Int64() - return &n, nil - } - - // adds a delay - delay := func(limit int64) error { - randomSleepDuration, err := randomNumber(limit) - if err != nil { - return err - } - time.Sleep(time.Duration(*randomSleepDuration) * time.Millisecond) - return nil - } - // adds random header value - randomHeader := func(limit int64) error { - randomHeaderValue, err := randomNumber(limit) - if err != nil { - return err - } + addRandomHeader := func(randomValue string) error { exp.Request.Headers = make(map[string]string) - exp.Request.Headers["X-Jitter"] = fmt.Sprintf("%d", *randomHeaderValue) + exp.Request.Headers["X-Jitter"] = randomValue return nil } - random, err := randomNumber(3) - if err != nil { - return err - } - - switch *random { - case 0: - return delay(1000) - case 1: - return randomHeader(10000) - case 2: - if err := delay(1000); err != nil { - return err - } - return randomHeader(10000) - default: - return fmt.Errorf("invalid random value: %d", *random) - } + return weight.AddRandomEntropy(addRandomHeader) } diff --git a/conformance/utils/weight/senders.go b/conformance/utils/weight/senders.go new file mode 100644 index 0000000000..304ead0e9f --- /dev/null +++ b/conformance/utils/weight/senders.go @@ -0,0 +1,31 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package weight + +// FunctionBasedSender implements RequestSender using a function +type FunctionBasedSender struct { + sendFunc func() (string, error) +} + +func (s *FunctionBasedSender) SendRequest() (string, error) { + return s.sendFunc() +} + +// NewFunctionBasedSender creates a RequestSender from a function +func NewFunctionBasedSender(sendFunc func() (string, error)) RequestSender { + return &FunctionBasedSender{sendFunc: sendFunc} +} diff --git a/conformance/utils/weight/weight.go b/conformance/utils/weight/weight.go new file mode 100644 index 0000000000..5507b45e13 --- /dev/null +++ b/conformance/utils/weight/weight.go @@ -0,0 +1,179 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package weight + +import ( + "cmp" + "crypto/rand" + "errors" + "fmt" + "math" + "math/big" + "slices" + "strings" + "sync" + "time" + + "golang.org/x/sync/errgroup" +) + +// RequestSender defines an interface for sending requests (HTTP, gRPC, or mesh) +type RequestSender interface { + SendRequest() (podName string, err error) +} + +// TestWeightedDistribution tests that requests are distributed according to expected weights +func TestWeightedDistribution(sender RequestSender, expectedWeights map[string]float64) error { + const ( + concurrentRequests = 10 + tolerancePercentage = 0.05 + totalRequests = 500.0 + ) + + var ( + g errgroup.Group + seenMutex sync.Mutex + seen = make(map[string]float64, len(expectedWeights)) + ) + + g.SetLimit(concurrentRequests) + for i := 0.0; i < totalRequests; i++ { + g.Go(func() error { + podName, err := sender.SendRequest() + if err != nil { + return err + } + + seenMutex.Lock() + defer seenMutex.Unlock() + + for expectedBackend := range expectedWeights { + if strings.HasPrefix(podName, expectedBackend) { + seen[expectedBackend]++ + return nil + } + } + + return fmt.Errorf("request was handled by an unexpected pod %q", podName) + }) + } + + if err := g.Wait(); err != nil { + return fmt.Errorf("error while sending requests: %w", err) + } + + // Count how many backends should receive traffic (weight > 0) + expectedActiveBackends := 0 + for _, weight := range expectedWeights { + if weight > 0.0 { + expectedActiveBackends++ + } + } + + var errs []error + if len(seen) != expectedActiveBackends { + errs = append(errs, fmt.Errorf("expected %d backends to receive traffic, but got %d", expectedActiveBackends, len(seen))) + } + + for wantBackend, wantPercent := range expectedWeights { + gotCount, ok := seen[wantBackend] + + if !ok && wantPercent != 0.0 { + errs = append(errs, fmt.Errorf("expect traffic to hit backend %q - but none was received", wantBackend)) + continue + } + + gotPercent := gotCount / totalRequests + + if math.Abs(gotPercent-wantPercent) > tolerancePercentage { + errs = append(errs, fmt.Errorf("backend %q weighted traffic of %v not within tolerance %v (+/-%f)", + wantBackend, + gotPercent, + wantPercent, + tolerancePercentage, + )) + } + } + + slices.SortFunc(errs, func(a, b error) int { + return cmp.Compare(a.Error(), b.Error()) + }) + return errors.Join(errs...) +} + +// Entropy utilities + +// randomNumber generates a random number between 0 and limit-1 +func randomNumber(limit int64) (*int64, error) { + number, err := rand.Int(rand.Reader, big.NewInt(limit)) + if err != nil { + return nil, err + } + n := number.Int64() + return &n, nil +} + +// AddDelay adds a random delay up to the specified limit in milliseconds +func AddDelay(limit int64) error { + randomSleepDuration, err := randomNumber(limit) + if err != nil { + return err + } + time.Sleep(time.Duration(*randomSleepDuration) * time.Millisecond) + return nil +} + +// GenerateRandomValue generates a random value as a string for use in headers/metadata +func GenerateRandomValue(limit int64) (string, error) { + randomVal, err := randomNumber(limit) + if err != nil { + return "", err + } + return fmt.Sprintf("%d", *randomVal), nil +} + +// AddRandomEntropy randomly chooses to add delay, random value, or both +// The addRandomValue function should be provided by the caller to handle +// protocol-specific ways of adding the random value (HTTP headers, gRPC metadata, etc.) +func AddRandomEntropy(addRandomValue func(string) error) error { + random, err := randomNumber(3) + if err != nil { + return err + } + + switch *random { + case 0: + return AddDelay(1000) + case 1: + randomValue, err := GenerateRandomValue(10000) + if err != nil { + return err + } + return addRandomValue(randomValue) + case 2: + if err := AddDelay(1000); err != nil { + return err + } + randomValue, err := GenerateRandomValue(10000) + if err != nil { + return err + } + return addRandomValue(randomValue) + default: + return fmt.Errorf("invalid random value: %d", *random) + } +}