Skip to content

Commit 84d918d

Browse files
Verify forwarded proto in service test
1 parent d7444e5 commit 84d918d

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

internal/server/service_test.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@ func TestService_RedirectToHTTPSWhenTLSRequired(t *testing.T) {
4444
}
4545

4646
func TestService_DontRedirectToHTTPSWhenTLSAndPlainHTTPAllowed(t *testing.T) {
47-
service := testCreateService(t, []string{"example.com"}, ServiceOptions{TLSEnabled: true, TLSDisableRedirect: true}, defaultTargetOptions)
47+
var forwardedProto string
48+
49+
service := testCreateServiceWithHandler(t, []string{"example.com"}, ServiceOptions{TLSEnabled: true, TLSDisableRedirect: true}, defaultTargetOptions,
50+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51+
forwardedProto = r.Header.Get("X-Forwarded-Proto")
52+
}),
53+
)
4854

4955
require.True(t, service.options.TLSEnabled)
5056

@@ -53,12 +59,14 @@ func TestService_DontRedirectToHTTPSWhenTLSAndPlainHTTPAllowed(t *testing.T) {
5359
service.ServeHTTP(w, req)
5460

5561
require.Equal(t, http.StatusOK, w.Result().StatusCode)
62+
assert.Equal(t, "http", forwardedProto)
5663

5764
req = httptest.NewRequest(http.MethodGet, "https://example.com", nil)
5865
w = httptest.NewRecorder()
5966
service.ServeHTTP(w, req)
6067

6168
require.Equal(t, http.StatusOK, w.Result().StatusCode)
69+
assert.Equal(t, "https", forwardedProto)
6270
}
6371

6472
func TestService_UseStaticTLSCertificateWhenConfigured(t *testing.T) {
@@ -154,7 +162,13 @@ func TestService_MarshallingState(t *testing.T) {
154162
}
155163

156164
func testCreateService(t *testing.T, hosts []string, options ServiceOptions, targetOptions TargetOptions) *Service {
157-
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
165+
return testCreateServiceWithHandler(t, hosts, options, targetOptions,
166+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
167+
)
168+
}
169+
170+
func testCreateServiceWithHandler(t *testing.T, hosts []string, options ServiceOptions, targetOptions TargetOptions, handler http.Handler) *Service {
171+
server := httptest.NewServer(handler)
158172
t.Cleanup(server.Close)
159173

160174
serverURL, err := url.Parse(server.URL)

0 commit comments

Comments
 (0)