Skip to content

Commit 39847a7

Browse files
TUN-7558: Flush on Writes for StreamBasedOriginProxy
In the streambased origin proxy flow (example ssh over access), there is a chance when we do not flush on http.ResponseWriter writes. This PR guarantees that the response writer passed to proxy stream has a flusher embedded after writes. This means we write much more often back to the ResponseWriter and are not waiting. Note, this is only something we do when proxyHTTP-ing to a StreamBasedOriginProxy because we do not want to have situations where we are not sending information that is needed by the other side (eyeball).
1 parent d1e338e commit 39847a7

File tree

7 files changed

+30
-9
lines changed

7 files changed

+30
-9
lines changed

connection/connection.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,16 @@ type ReadWriteAcker interface {
157157
type HTTPResponseReadWriteAcker struct {
158158
r io.Reader
159159
w ResponseWriter
160+
f http.Flusher
160161
req *http.Request
161162
}
162163

163164
// NewHTTPResponseReadWriterAcker returns a new instance of HTTPResponseReadWriteAcker.
164-
func NewHTTPResponseReadWriterAcker(w ResponseWriter, req *http.Request) *HTTPResponseReadWriteAcker {
165+
func NewHTTPResponseReadWriterAcker(w ResponseWriter, flusher http.Flusher, req *http.Request) *HTTPResponseReadWriteAcker {
165166
return &HTTPResponseReadWriteAcker{
166167
r: req.Body,
167168
w: w,
169+
f: flusher,
168170
req: req,
169171
}
170172
}
@@ -174,7 +176,11 @@ func (h *HTTPResponseReadWriteAcker) Read(p []byte) (int, error) {
174176
}
175177

176178
func (h *HTTPResponseReadWriteAcker) Write(p []byte) (int, error) {
177-
return h.w.Write(p)
179+
n, err := h.w.Write(p)
180+
if n > 0 {
181+
h.f.Flush()
182+
}
183+
return n, err
178184
}
179185

180186
// AckConnection acks an HTTP connection by sending a switch protocols status code that enables the caller to

connection/connection_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
130130
}
131131
wsCtx, cancel := context.WithCancel(r.Context())
132132
readPipe, writePipe := io.Pipe()
133-
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
133+
134+
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
134135
go func() {
135136
select {
136137
case <-wsCtx.Done():
@@ -175,7 +176,7 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
175176
}
176177
wsCtx, cancel := context.WithCancel(r.Context())
177178

178-
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
179+
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
179180

180181
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
181182
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}

connection/http2.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
142142
break
143143
}
144144

145-
rws := NewHTTPResponseReadWriterAcker(respWriter, r)
145+
rws := NewHTTPResponseReadWriterAcker(respWriter, respWriter, r)
146146
requestErr = originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{
147147
Dest: host,
148148
CFRay: FindCfRayHeader(r),
@@ -289,6 +289,10 @@ func (rp *http2RespWriter) Header() http.Header {
289289
return rp.respHeaders
290290
}
291291

292+
func (rp *http2RespWriter) Flush() {
293+
rp.flusher.Flush()
294+
}
295+
292296
func (rp *http2RespWriter) WriteHeader(status int) {
293297
if rp.hijacked() {
294298
rp.log.Warn().Msg("WriteHeader after hijack")

connection/quic.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,9 @@ func (hrw *httpResponseAdapter) Header() http.Header {
461461
return hrw.headers
462462
}
463463

464+
// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here.
465+
func (hrw *httpResponseAdapter) Flush() {}
466+
464467
func (hrw *httpResponseAdapter) WriteHeader(status int) {
465468
hrw.WriteRespHeaders(status, hrw.headers)
466469
}

orchestration/orchestrator_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ func proxyTCP(ctx context.Context, originProxy connection.OriginProxy, originAdd
450450
CFRay: "123",
451451
LBProbe: false,
452452
}
453-
rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req)
453+
rws := connection.NewHTTPResponseReadWriterAcker(respWriter, w.(http.Flusher), req)
454454

455455
return originProxy.ProxyTCP(ctx, rws, tcpReq)
456456
}

proxy/proxy.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,11 @@ func (p *Proxy) ProxyHTTP(
136136
if err != nil {
137137
return err
138138
}
139-
140-
rws := connection.NewHTTPResponseReadWriterAcker(w, req)
139+
flusher, ok := w.(http.Flusher)
140+
if !ok {
141+
return fmt.Errorf("response writer is not a flusher")
142+
}
143+
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
141144
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy); err != nil {
142145
rule, srv := ruleField(p.ingressRules, ruleNum)
143146
p.logRequestError(err, cfRay, "", rule, srv)

proxy/proxy_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ func TestConnections(t *testing.T) {
698698
}()
699699
}
700700
if test.args.connectionType == connection.TypeTCP {
701-
rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, req)
701+
rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, respWriter.(http.Flusher), req)
702702
err = proxy.ProxyTCP(ctx, rwa, &connection.TCPRequest{Dest: dest})
703703
} else {
704704
log := zerolog.Nop()
@@ -834,6 +834,8 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
834834
return nil
835835
}
836836

837+
func (w *wsRespWriter) Flush() {}
838+
837839
func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) {
838840
// do nothing
839841
}
@@ -873,6 +875,8 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
873875
return m.w.Write(p)
874876
}
875877

878+
func (m *mockTCPRespWriter) Flush() {}
879+
876880
func (m *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) {
877881
// do nothing
878882
}

0 commit comments

Comments
 (0)