@@ -22,6 +22,7 @@ import (
22
22
"math/big"
23
23
"net"
24
24
"net/url"
25
+ "slices"
25
26
"strings"
26
27
"testing"
27
28
"time"
@@ -91,8 +92,9 @@ type ExpectedRequest struct {
91
92
92
93
// Response defines expected properties of a response from a backend.
93
94
type Response struct {
95
+ // Deprecated: Use StatusCodes instead, which supports matching against multiple status codes.
94
96
StatusCode int
95
- StatusCodes []int // alternative to StatusCode, allows multiple acceptable codes
97
+ StatusCodes []int
96
98
Headers map [string ]string
97
99
AbsentHeaders []string
98
100
Protocol string
@@ -139,8 +141,11 @@ func MakeRequest(t *testing.T, expected *ExpectedResponse, gwAddr, protocol, sch
139
141
expected .Request .Method = "GET"
140
142
}
141
143
142
- if expected .Response .StatusCode == 0 {
143
- expected .Response .StatusCode = 200
144
+ if expected .Response .StatusCode != 0 {
145
+ expected .Response .StatusCodes = append (expected .Response .StatusCodes , expected .Response .StatusCode )
146
+ }
147
+ if len (expected .Response .StatusCodes ) == 0 {
148
+ expected .Response .StatusCodes = []int {200 }
144
149
}
145
150
146
151
if expected .Request .Protocol == "" {
@@ -301,23 +306,14 @@ func WaitForConsistentFailureResponse(t *testing.T, r roundtripper.RoundTripper,
301
306
302
307
func CompareRoundTrip (t * testing.T , req * roundtripper.Request , cReq * roundtripper.CapturedRequest , cRes * roundtripper.CapturedResponse , expected ExpectedResponse ) error {
303
308
if roundtripper .IsTimeoutError (cRes .StatusCode ) {
304
- if roundtripper .IsTimeoutError (expected .Response .StatusCode ) {
305
- return nil
306
- }
307
- }
308
- if len (expected .Response .StatusCodes ) > 0 {
309
- matched := false
310
- for _ , code := range expected .Response .StatusCodes {
311
- if code == cRes .StatusCode {
312
- matched = true
313
- break
309
+ for _ , statusCode := range expected .Response .StatusCodes {
310
+ if roundtripper .IsTimeoutError (statusCode ) {
311
+ return nil
314
312
}
315
313
}
316
- if ! matched {
317
- return fmt .Errorf ("expected status code to be one of %v, got %d" , expected .Response .StatusCodes , cRes .StatusCode )
318
- }
319
- } else if expected .Response .StatusCode != cRes .StatusCode {
320
- return fmt .Errorf ("expected status code to be %d, got %d. CRes: %v" , expected .Response .StatusCode , cRes .StatusCode , cRes )
314
+ }
315
+ if ! slices .Contains (expected .Response .StatusCodes , cRes .StatusCode ) {
316
+ return fmt .Errorf ("expected status code to be one of %v, got %d. CRes: %v" , expected .Response .StatusCodes , cRes .StatusCode , cRes )
321
317
}
322
318
if expected .Response .Protocol != "" && expected .Response .Protocol != cRes .Protocol {
323
319
return fmt .Errorf ("expected protocol to be %s, got %s" , expected .Response .Protocol , cRes .Protocol )
@@ -479,10 +475,8 @@ func (er *ExpectedResponse) GetTestCaseName(i int) string {
479
475
if er .Backend != "" {
480
476
return fmt .Sprintf ("%s should go to %s" , reqStr , er .Backend )
481
477
}
482
- if len (er .Response .StatusCodes ) > 0 {
483
- return fmt .Sprintf ("%s should receive one of %v" , reqStr , er .Response .StatusCodes )
484
- }
485
- return fmt .Sprintf ("%s should receive a %d" , reqStr , er .Response .StatusCode )
478
+
479
+ return fmt .Sprintf ("%s should receive one of %v" , reqStr , er .Response .StatusCodes )
486
480
}
487
481
488
482
func setRedirectRequestDefaults (req * roundtripper.Request , cRes * roundtripper.CapturedResponse , expected * ExpectedResponse ) {
0 commit comments