Skip to content

Commit f6bd4aa

Browse files
committed
TUN-6676: Add suport for trailers in http2 connections
1 parent d2bc15e commit f6bd4aa

File tree

7 files changed

+89
-89
lines changed

7 files changed

+89
-89
lines changed

connection/connection.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,16 @@ const (
2424
LogFieldConnIndex = "connIndex"
2525
MaxGracePeriod = time.Minute * 3
2626
MaxConcurrentStreams = math.MaxUint32
27+
28+
contentTypeHeader = "content-type"
29+
sseContentType = "text/event-stream"
30+
grpcContentType = "application/grpc"
2731
)
2832

29-
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
33+
var (
34+
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
35+
flushableContentTypes = []string{sseContentType, grpcContentType}
36+
)
3037

3138
type Orchestrator interface {
3239
UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse
@@ -190,6 +197,7 @@ func (h *HTTPResponseReadWriteAcker) AckConnection(tracePropagation string) erro
190197

191198
type ResponseWriter interface {
192199
WriteRespHeaders(status int, header http.Header) error
200+
AddTrailer(trailerName, trailerValue string)
193201
io.Writer
194202
}
195203

@@ -198,10 +206,18 @@ type ConnectedFuse interface {
198206
IsConnected() bool
199207
}
200208

201-
func IsServerSentEvent(headers http.Header) bool {
202-
if contentType := headers.Get("content-type"); contentType != "" {
203-
return strings.HasPrefix(strings.ToLower(contentType), "text/event-stream")
209+
// Helper method to let the caller know what content-types should require a flush on every
210+
// write to a ResponseWriter.
211+
func shouldFlush(headers http.Header) bool {
212+
if contentType := headers.Get(contentTypeHeader); contentType != "" {
213+
contentType = strings.ToLower(contentType)
214+
for _, c := range flushableContentTypes {
215+
if strings.HasPrefix(contentType, c) {
216+
return true
217+
}
218+
}
204219
}
220+
205221
return false
206222
}
207223

connection/connection_test.go

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ import (
66
"io"
77
"math/rand"
88
"net/http"
9-
"testing"
109
"time"
1110

1211
"github.com/rs/zerolog"
13-
"github.com/stretchr/testify/assert"
1412

1513
"github.com/cloudflare/cloudflared/tracing"
1614
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@@ -197,40 +195,3 @@ func (mcf mockConnectedFuse) Connected() {}
197195
func (mcf mockConnectedFuse) IsConnected() bool {
198196
return true
199197
}
200-
201-
func TestIsEventStream(t *testing.T) {
202-
tests := []struct {
203-
headers http.Header
204-
isEventStream bool
205-
}{
206-
{
207-
headers: newHeader("Content-Type", "text/event-stream"),
208-
isEventStream: true,
209-
},
210-
{
211-
headers: newHeader("content-type", "text/event-stream"),
212-
isEventStream: true,
213-
},
214-
{
215-
headers: newHeader("Content-Type", "text/event-stream; charset=utf-8"),
216-
isEventStream: true,
217-
},
218-
{
219-
headers: newHeader("Content-Type", "application/json"),
220-
isEventStream: false,
221-
},
222-
{
223-
headers: http.Header{},
224-
isEventStream: false,
225-
},
226-
}
227-
for _, test := range tests {
228-
assert.Equal(t, test.isEventStream, IsServerSentEvent(test.headers))
229-
}
230-
}
231-
232-
func newHeader(key, value string) http.Header {
233-
header := http.Header{}
234-
header.Add(key, value)
235-
return header
236-
}

connection/h2mux.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ type h2muxRespWriter struct {
259259
*h2mux.MuxedStream
260260
}
261261

262+
func (rp *h2muxRespWriter) AddTrailer(trailerName, trailerValue string) {
263+
// do nothing. we don't support trailers over h2mux
264+
}
265+
262266
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
263267
headers := H1ResponseToH2ResponseHeaders(status, header)
264268
headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin})

connection/http2.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,12 @@ func (c *HTTP2Connection) close() {
191191
}
192192

193193
type http2RespWriter struct {
194-
r io.Reader
195-
w http.ResponseWriter
196-
flusher http.Flusher
197-
shouldFlush bool
198-
log *zerolog.Logger
194+
r io.Reader
195+
w http.ResponseWriter
196+
flusher http.Flusher
197+
shouldFlush bool
198+
statusWritten bool
199+
log *zerolog.Logger
199200
}
200201

201202
func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, log *zerolog.Logger) (*http2RespWriter, error) {
@@ -219,11 +220,20 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
219220
}, nil
220221
}
221222

223+
func (rp *http2RespWriter) AddTrailer(trailerName, trailerValue string) {
224+
if !rp.statusWritten {
225+
rp.log.Warn().Msg("Tried to add Trailer to response before status written. Ignoring...")
226+
return
227+
}
228+
229+
rp.w.Header().Add(http2.TrailerPrefix+trailerName, trailerValue)
230+
}
231+
222232
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
223233
dest := rp.w.Header()
224234
userHeaders := make(http.Header, len(header))
225235
for name, values := range header {
226-
// Since these are http2 headers, they're required to be lowercase
236+
// lowercase headers for simplicity check
227237
h2name := strings.ToLower(name)
228238

229239
if h2name == "content-length" {
@@ -234,7 +244,7 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro
234244

235245
if h2name == tracing.IntCloudflaredTracingHeader {
236246
// Add cf-int-cloudflared-tracing header outside of serialized userHeaders
237-
rp.w.Header()[tracing.CanonicalCloudflaredTracingHeader] = values
247+
dest[tracing.CanonicalCloudflaredTracingHeader] = values
238248
continue
239249
}
240250

@@ -247,18 +257,21 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro
247257

248258
// Perform user header serialization and set them in the single header
249259
dest.Set(CanonicalResponseUserHeaders, SerializeHeaders(userHeaders))
260+
250261
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
251262
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
252263
if status == http.StatusSwitchingProtocols {
253264
status = http.StatusOK
254265
}
255266
rp.w.WriteHeader(status)
256-
if IsServerSentEvent(header) {
267+
if shouldFlush(header) {
257268
rp.shouldFlush = true
258269
}
259270
if rp.shouldFlush {
260271
rp.flusher.Flush()
261272
}
273+
274+
rp.statusWritten = true
262275
return nil
263276
}
264277

connection/quic.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter
329329
return httpResponseAdapter{s}
330330
}
331331

332+
func (hrw httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
333+
// we do not support trailers over QUIC
334+
}
335+
332336
func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
333337
metadata := make([]quicpogs.Metadata, 0)
334338
metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})

proxy/proxy.go

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package proxy
22

33
import (
4-
"bufio"
54
"context"
65
"fmt"
76
"io"
@@ -29,6 +28,8 @@ const (
2928
LogFieldRule = "ingressRule"
3029
LogFieldOriginService = "originService"
3130
LogFieldFlowID = "flowID"
31+
32+
trailerHeaderName = "Trailer"
3233
)
3334

3435
// Proxy represents a means to Proxy between cloudflared and the origin services.
@@ -207,15 +208,16 @@ func (p *Proxy) proxyHTTPRequest(
207208
tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode)
208209
defer resp.Body.Close()
209210

210-
// resp headers can be nil
211-
if resp.Header == nil {
212-
resp.Header = make(http.Header)
211+
headers := make(http.Header, len(resp.Header))
212+
// copy headers
213+
for k, v := range resp.Header {
214+
headers[k] = v
213215
}
214216

215217
// Add spans to response header (if available)
216-
tr.AddSpans(resp.Header)
218+
tr.AddSpans(headers)
217219

218-
err = w.WriteRespHeaders(resp.StatusCode, resp.Header)
220+
err = w.WriteRespHeaders(resp.StatusCode, headers)
219221
if err != nil {
220222
return errors.Wrap(err, "Error writing response header")
221223
}
@@ -236,12 +238,10 @@ func (p *Proxy) proxyHTTPRequest(
236238
return nil
237239
}
238240

239-
if connection.IsServerSentEvent(resp.Header) {
240-
p.log.Debug().Msg("Detected Server-Side Events from Origin")
241-
p.writeEventStream(w, resp.Body)
242-
} else {
243-
_, _ = cfio.Copy(w, resp.Body)
244-
}
241+
_, _ = cfio.Copy(w, resp.Body)
242+
243+
// copy trailers
244+
copyTrailers(w, resp)
245245

246246
p.logOriginResponse(resp, fields)
247247
return nil
@@ -296,26 +296,6 @@ func (wr *bidirectionalStream) Write(p []byte) (n int, err error) {
296296
return wr.writer.Write(p)
297297
}
298298

299-
func (p *Proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
300-
reader := bufio.NewReader(respBody)
301-
for {
302-
line, readErr := reader.ReadBytes('\n')
303-
304-
// We first try to write whatever we read even if an error occurred
305-
// The reason for doing it is to guarantee we really push everything to the eyeball side
306-
// before returning
307-
if len(line) > 0 {
308-
if _, writeErr := w.Write(line); writeErr != nil {
309-
return
310-
}
311-
}
312-
313-
if readErr != nil {
314-
return
315-
}
316-
}
317-
}
318-
319299
func (p *Proxy) appendTagHeaders(r *http.Request) {
320300
for _, tag := range p.tags {
321301
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
@@ -329,6 +309,14 @@ type logFields struct {
329309
flowID string
330310
}
331311

312+
func copyTrailers(w connection.ResponseWriter, response *http.Response) {
313+
for trailerHeader, trailerValues := range response.Trailer {
314+
for _, trailerValue := range trailerValues {
315+
w.AddTrailer(trailerHeader, trailerValue)
316+
}
317+
}
318+
}
319+
332320
func (p *Proxy) logRequest(r *http.Request, fields logFields) {
333321
if fields.cfRay != "" {
334322
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)

proxy/proxy_test.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222
"github.com/urfave/cli/v2"
2323
"golang.org/x/sync/errgroup"
2424

25+
"github.com/cloudflare/cloudflared/cfio"
26+
2527
"github.com/cloudflare/cloudflared/config"
2628
"github.com/cloudflare/cloudflared/connection"
2729
"github.com/cloudflare/cloudflared/hello"
@@ -62,6 +64,10 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) er
6264
return nil
6365
}
6466

67+
func (w *mockHTTPRespWriter) AddTrailer(trailerName, trailerValue string) {
68+
// do nothing
69+
}
70+
6571
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
6672
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
6773
}
@@ -117,7 +123,10 @@ func newMockSSERespWriter() *mockSSERespWriter {
117123
}
118124

119125
func (w *mockSSERespWriter) Write(data []byte) (int, error) {
120-
w.writeNotification <- data
126+
newData := make([]byte, len(data))
127+
copy(newData, data)
128+
129+
w.writeNotification <- newData
121130
return len(data), nil
122131
}
123132

@@ -256,11 +265,8 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
256265

257266
for i := 0; i < pushCount; i++ {
258267
line := responseWriter.ReadBytes()
259-
expect := fmt.Sprintf("%d\n", i)
268+
expect := fmt.Sprintf("%d\n\n", i)
260269
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
261-
262-
line = responseWriter.ReadBytes()
263-
require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line))
264270
}
265271

266272
cancel()
@@ -276,7 +282,7 @@ func testProxySSEAllData(proxy *Proxy) func(t *testing.T) {
276282
responseWriter := newMockSSERespWriter()
277283

278284
// responseWriter uses an unbuffered channel, so we call in a different go-routine
279-
go proxy.writeEventStream(responseWriter, eyeballReader)
285+
go cfio.Copy(responseWriter, eyeballReader)
280286

281287
result := string(<-responseWriter.writeNotification)
282288
require.Equal(t, "data\r\r", result)
@@ -825,6 +831,10 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
825831
return nil
826832
}
827833

834+
func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) {
835+
// do nothing
836+
}
837+
828838
// respHeaders is a test function to read respHeaders
829839
func (w *wsRespWriter) headers() http.Header {
830840
// Removing indeterminstic header because it cannot be asserted.
@@ -852,6 +862,10 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
852862
return m.w.Write(p)
853863
}
854864

865+
func (w *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) {
866+
// do nothing
867+
}
868+
855869
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
856870
m.responseHeaders = header
857871
m.code = status

0 commit comments

Comments
 (0)