@@ -22,6 +22,7 @@ import (
22
22
"math/big"
23
23
"net"
24
24
"net/url"
25
+ "slices"
25
26
"strings"
26
27
"testing"
27
28
"time"
@@ -91,8 +92,9 @@ type ExpectedRequest struct {
91
92
92
93
// Response defines expected properties of a response from a backend.
93
94
type Response struct {
95
+ // Deprecated: Use StatusCodes instead, which supports matching against multiple status codes.
94
96
StatusCode int
95
- StatusCodes []int // alternative to StatusCode, allows multiple acceptable codes
97
+ StatusCodes []int
96
98
Headers map [string ]string
97
99
AbsentHeaders []string
98
100
Protocol string
@@ -139,8 +141,13 @@ func MakeRequest(t *testing.T, expected *ExpectedResponse, gwAddr, protocol, sch
139
141
expected .Request .Method = "GET"
140
142
}
141
143
142
- if expected .Response .StatusCode == 0 {
143
- expected .Response .StatusCode = 200
144
+ // if the deprecated field StatusCode is set, append it to StatusCodes for backwards compatibility
145
+ if expected .Response .StatusCode != 0 {
146
+ expected .Response .StatusCodes = append (expected .Response .StatusCodes , expected .Response .StatusCode )
147
+ }
148
+
149
+ if len (expected .Response .StatusCodes ) == 0 {
150
+ expected .Response .StatusCodes = []int {200 }
144
151
}
145
152
146
153
if expected .Request .Protocol == "" {
@@ -301,23 +308,14 @@ func WaitForConsistentFailureResponse(t *testing.T, r roundtripper.RoundTripper,
301
308
302
309
func CompareRoundTrip (t * testing.T , req * roundtripper.Request , cReq * roundtripper.CapturedRequest , cRes * roundtripper.CapturedResponse , expected ExpectedResponse ) error {
303
310
if roundtripper .IsTimeoutError (cRes .StatusCode ) {
304
- if roundtripper .IsTimeoutError (expected .Response .StatusCode ) {
305
- return nil
306
- }
307
- }
308
- if len (expected .Response .StatusCodes ) > 0 {
309
- matched := false
310
- for _ , code := range expected .Response .StatusCodes {
311
- if code == cRes .StatusCode {
312
- matched = true
313
- break
311
+ for _ , statusCode := range expected .Response .StatusCodes {
312
+ if roundtripper .IsTimeoutError (statusCode ) {
313
+ return nil
314
314
}
315
315
}
316
- if ! matched {
317
- return fmt .Errorf ("expected status code to be one of %v, got %d" , expected .Response .StatusCodes , cRes .StatusCode )
318
- }
319
- } else if expected .Response .StatusCode != cRes .StatusCode {
320
- return fmt .Errorf ("expected status code to be %d, got %d. CRes: %v" , expected .Response .StatusCode , cRes .StatusCode , cRes )
316
+ }
317
+ if ! slices .Contains (expected .Response .StatusCodes , cRes .StatusCode ) {
318
+ return fmt .Errorf ("expected status code to be one of %v, got %d. CRes: %v" , expected .Response .StatusCodes , cRes .StatusCode , cRes )
321
319
}
322
320
if expected .Response .Protocol != "" && expected .Response .Protocol != cRes .Protocol {
323
321
return fmt .Errorf ("expected protocol to be %s, got %s" , expected .Response .Protocol , cRes .Protocol )
@@ -479,10 +477,8 @@ func (er *ExpectedResponse) GetTestCaseName(i int) string {
479
477
if er .Backend != "" {
480
478
return fmt .Sprintf ("%s should go to %s" , reqStr , er .Backend )
481
479
}
482
- if len (er .Response .StatusCodes ) > 0 {
483
- return fmt .Sprintf ("%s should receive one of %v" , reqStr , er .Response .StatusCodes )
484
- }
485
- return fmt .Sprintf ("%s should receive a %d" , reqStr , er .Response .StatusCode )
480
+
481
+ return fmt .Sprintf ("%s should receive one of %v" , reqStr , er .Response .StatusCodes )
486
482
}
487
483
488
484
func setRedirectRequestDefaults (req * roundtripper.Request , cRes * roundtripper.CapturedResponse , expected * ExpectedResponse ) {
0 commit comments