Skip to content

Commit 1ff5fd3

Browse files
committed
TUN-5744: Add a test to make sure cloudflared uses scheme defined in ingress rule, not X-Forwarded-Proto header
1 parent 5b12e74 commit 1ff5fd3

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

ingress/origin_proxy_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,48 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
147147
respBody, err := ioutil.ReadAll(resp.Body)
148148
require.NoError(t, err)
149149
require.Equal(t, respBody, []byte(originURL.Host))
150+
}
151+
152+
// TestHTTPServiceUsesIngressRuleScheme makes sure httpService uses scheme defined in ingress rule and not by eyeball request
153+
func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
154+
handler := func(w http.ResponseWriter, r *http.Request) {
155+
require.NotNil(t, r.TLS)
156+
// Echo the X-Forwarded-Proto header for assertions
157+
w.Write([]byte(r.Header.Get("X-Forwarded-Proto")))
158+
}
159+
origin := httptest.NewTLSServer(http.HandlerFunc(handler))
160+
defer origin.Close()
150161

162+
originURL, err := url.Parse(origin.URL)
163+
require.NoError(t, err)
164+
require.Equal(t, "https", originURL.Scheme)
165+
166+
cfg := OriginRequestConfig{
167+
NoTLSVerify: true,
168+
}
169+
httpService := &httpService{
170+
url: originURL,
171+
}
172+
var wg sync.WaitGroup
173+
shutdownC := make(chan struct{})
174+
errC := make(chan error)
175+
require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg))
176+
177+
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
178+
protos := []string{"https", "http", "dne"}
179+
for _, p := range protos {
180+
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
181+
require.NoError(t, err)
182+
req.Header.Add("X-Forwarded-Proto", p)
183+
184+
resp, err := httpService.RoundTrip(req)
185+
require.NoError(t, err)
186+
require.Equal(t, http.StatusOK, resp.StatusCode)
187+
188+
respBody, err := ioutil.ReadAll(resp.Body)
189+
require.NoError(t, err)
190+
require.Equal(t, respBody, []byte(p))
191+
}
151192
}
152193

153194
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {

0 commit comments

Comments
 (0)