@@ -16,8 +16,19 @@ func handlerCausingNetworkError() http.Handler {
1616 return httphelpers .BrokenConnectionHandler ()
1717}
1818
19- func handlerCausingHTTPError (status int ) http.Handler {
20- return httphelpers .HandlerWithStatus (status )
19+ func handlerCausingHTTPError (status int , header * http.Header ) http.Handler {
20+ if header == nil {
21+ return httphelpers .HandlerWithStatus (status )
22+ }
23+
24+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
25+ for key , values := range * header {
26+ for _ , value := range values {
27+ w .Header ().Add (key , value )
28+ }
29+ }
30+ httphelpers .HandlerWithStatus (status ).ServeHTTP (w , r )
31+ })
2132}
2233
2334func shouldBeNetworkError (t * testing.T ) func (error ) {
@@ -28,9 +39,24 @@ func shouldBeNetworkError(t *testing.T) func(error) {
2839 }
2940}
3041
31- func shouldBeHTTPError (t * testing.T , status int ) func (error ) {
42+ func shouldBeHTTPError (t * testing.T , status int , header * http. Header ) func (error ) {
3243 return func (err error ) {
33- assert .Equal (t , SubscriptionError {Code : status }, err )
44+ switch e := err .(type ) {
45+ case SubscriptionError :
46+ assert .Equal (t , status , e .Code )
47+ if header != nil {
48+ for key , value := range * header {
49+ if v , ok := e .Header [key ]; ok {
50+ assert .Equal (t , v , value )
51+ } else {
52+ assert .Fail (t , "header not found" , "header %s not found in error headers" , key )
53+ }
54+ }
55+ }
56+ default :
57+ t .Errorf ("expected SubscriptionError, got %T" , e )
58+ return
59+ }
3460 }
3561}
3662
@@ -39,7 +65,10 @@ func TestStreamDoesNotRetryInitialConnectionByDefaultAfterNetworkError(t *testin
3965}
4066
4167func TestStreamDoesNotRetryInitialConnectionByDefaultAfterHTTPError (t * testing.T ) {
42- testStreamDoesNotRetryInitialConnectionByDefault (t , handlerCausingHTTPError (401 ), shouldBeHTTPError (t , 401 ))
68+ header := http.Header {
69+ "X-My-Header" : []string {"my-value" },
70+ }
71+ testStreamDoesNotRetryInitialConnectionByDefault (t , handlerCausingHTTPError (401 , & header ), shouldBeHTTPError (t , 401 , & header ))
4372}
4473
4574func testStreamDoesNotRetryInitialConnectionByDefault (t * testing.T , errorHandler http.Handler , checkError func (error )) {
@@ -135,7 +164,7 @@ func TestStreamErrorHandlerCanAllowRetryOfInitialConnectionAfterNetworkError(t *
135164}
136165
137166func TestStreamErrorHandlerCanAllowRetryOfInitialConnectionAfterHTTPError (t * testing.T ) {
138- testStreamErrorHandlerCanAllowRetryOfInitialConnection (t , handlerCausingHTTPError (401 ), shouldBeHTTPError (t , 401 ))
167+ testStreamErrorHandlerCanAllowRetryOfInitialConnection (t , handlerCausingHTTPError (401 , nil ), shouldBeHTTPError (t , 401 , nil ))
139168}
140169
141170func testStreamErrorHandlerCanAllowRetryOfInitialConnection (t * testing.T , errorHandler http.Handler , checkError func (error )) {
0 commit comments