Skip to content

Commit eef5b78

Browse files
committed
TUN-3480: Support SSE with http2 connection, and add SSE handler to hello-world server
1 parent 6b86f81 commit eef5b78

File tree

7 files changed

+156
-62
lines changed

7 files changed

+156
-62
lines changed

connection/connection.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"io"
55
"net/http"
66
"strconv"
7+
"strings"
78
"time"
89

910
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@@ -55,3 +56,7 @@ type ConnectedFuse interface {
5556
func uint8ToString(input uint8) string {
5657
return strconv.FormatUint(uint64(input), 10)
5758
}
59+
60+
func isServerSentEvent(headers http.Header) bool {
61+
return strings.ToLower(headers.Get("content-type")) == "text/event-stream"
62+
}

connection/h2mux.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,14 @@ type h2muxRespWriter struct {
205205

206206
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
207207
headers := h2mux.H1ResponseToH2ResponseHeaders(resp)
208-
headers = append(headers, h2mux.Header{Name: responseMetaHeaderField, Value: responseSourceOrigin})
208+
headers = append(headers, h2mux.Header{Name: responseMetaHeaderField, Value: responseMetaHeaderOrigin})
209209
return rp.WriteHeaders(headers)
210210
}
211211

212212
func (rp *h2muxRespWriter) WriteErrorResponse(err error) {
213213
rp.WriteHeaders([]h2mux.Header{
214214
{Name: ":status", Value: "502"},
215-
{Name: responseMetaHeaderField, Value: responseSourceCloudflared},
215+
{Name: responseMetaHeaderField, Value: responseMetaHeaderCfd},
216216
})
217217
rp.Write([]byte("502 Bad Gateway"))
218218
}

connection/header.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type responseMetaHeader struct {
2525
func mustInitRespMetaHeader(src string) string {
2626
header, err := json.Marshal(responseMetaHeader{Source: src})
2727
if err != nil {
28-
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", responseSourceCloudflared, err))
28+
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
2929
}
3030
return string(header)
3131
}

connection/http2.go

Lines changed: 45 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package connection
22

33
import (
44
"context"
5-
"fmt"
5+
"errors"
66
"io"
77
"math"
88
"net"
@@ -23,6 +23,10 @@ const (
2323
controlStreamUpgrade = "control-stream"
2424
)
2525

26+
var (
27+
errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher")
28+
)
29+
2630
type HTTP2Connection struct {
2731
conn net.Conn
2832
server *http2.Server
@@ -37,7 +41,16 @@ type HTTP2Connection struct {
3741
connectedFuse ConnectedFuse
3842
}
3943

40-
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) *HTTP2Connection {
44+
func NewHTTP2Connection(
45+
conn net.Conn,
46+
config *Config,
47+
originURL *url.URL,
48+
namedTunnelConfig *NamedTunnelConfig,
49+
connOptions *tunnelpogs.ConnectionOptions,
50+
observer *Observer,
51+
connIndex uint8,
52+
connectedFuse ConnectedFuse,
53+
) *HTTP2Connection {
4154
return &HTTP2Connection{
4255
conn: conn,
4356
server: &http2.Server{
@@ -77,34 +90,33 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
7790
r: r.Body,
7891
w: w,
7992
}
93+
flusher, isFlusher := w.(http.Flusher)
94+
if !isFlusher {
95+
c.observer.Errorf("%T doesn't implement http.Flusher", w)
96+
respWriter.WriteErrorResponse(errNotFlusher)
97+
return
98+
}
99+
respWriter.flusher = flusher
80100
if isControlStreamUpgrade(r) {
101+
respWriter.shouldFlush = true
81102
err := c.serveControlStream(r.Context(), respWriter)
82103
if err != nil {
83104
respWriter.WriteErrorResponse(err)
84105
}
85106
} else if isWebsocketUpgrade(r) {
86-
wsRespWriter, err := newWSRespWriter(respWriter)
87-
if err != nil {
88-
respWriter.WriteErrorResponse(err)
89-
return
90-
}
107+
respWriter.shouldFlush = true
91108
stripWebsocketUpgradeHeader(r)
92-
c.config.OriginClient.Proxy(wsRespWriter, r, true)
109+
c.config.OriginClient.Proxy(respWriter, r, true)
93110
} else {
94111
c.config.OriginClient.Proxy(respWriter, r, false)
95112
}
96113
}
97114

98-
func (c *HTTP2Connection) serveControlStream(ctx context.Context, h2RespWriter *http2RespWriter) error {
99-
stream, err := newWSRespWriter(h2RespWriter)
100-
if err != nil {
101-
return err
102-
}
103-
104-
rpcClient := newRegistrationRPCClient(ctx, stream, c.observer)
115+
func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
116+
rpcClient := newRegistrationRPCClient(ctx, respWriter, c.observer)
105117
defer rpcClient.close()
106118

107-
if err = registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
119+
if err := registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
108120
return err
109121
}
110122
c.connectedFuse.Connected()
@@ -146,8 +158,10 @@ func (c *HTTP2Connection) close() {
146158
}
147159

148160
type http2RespWriter struct {
149-
r io.Reader
150-
w http.ResponseWriter
161+
r io.Reader
162+
w http.ResponseWriter
163+
flusher http.Flusher
164+
shouldFlush bool
151165
}
152166

153167
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
@@ -172,13 +186,19 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
172186

173187
// Perform user header serialization and set them in the single header
174188
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
175-
rp.setResponseMetaHeader(responseMetaHeaderCfd)
189+
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
176190
status := resp.StatusCode
177191
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
178192
if status == http.StatusSwitchingProtocols {
179193
status = http.StatusOK
180194
}
181195
rp.w.WriteHeader(status)
196+
if isServerSentEvent(resp.Header) {
197+
rp.shouldFlush = true
198+
}
199+
if rp.shouldFlush {
200+
rp.flusher.Flush()
201+
}
182202
return nil
183203
}
184204

@@ -195,43 +215,15 @@ func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
195215
return rp.r.Read(p)
196216
}
197217

198-
func (wr *http2RespWriter) Write(p []byte) (n int, err error) {
199-
return wr.w.Write(p)
200-
}
201-
202-
type wsRespWriter struct {
203-
*http2RespWriter
204-
flusher http.Flusher
205-
}
206-
207-
func newWSRespWriter(h2 *http2RespWriter) (*wsRespWriter, error) {
208-
flusher, ok := h2.w.(http.Flusher)
209-
if !ok {
210-
return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher")
211-
}
212-
return &wsRespWriter{
213-
h2,
214-
flusher,
215-
}, nil
216-
}
217-
218-
func (rw *wsRespWriter) WriteRespHeaders(resp *http.Response) (err error) {
219-
err = rw.http2RespWriter.WriteRespHeaders(resp)
220-
if err == nil {
221-
rw.flusher.Flush()
222-
}
223-
return
224-
}
225-
226-
func (rw *wsRespWriter) Write(p []byte) (n int, err error) {
227-
n, err = rw.http2RespWriter.Write(p)
228-
if err == nil {
229-
rw.flusher.Flush()
218+
func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
219+
n, err = rp.w.Write(p)
220+
if err == nil && rp.shouldFlush {
221+
rp.flusher.Flush()
230222
}
231-
return
223+
return n, err
232224
}
233225

234-
func (rw *wsRespWriter) Close() error {
226+
func (rp *http2RespWriter) Close() error {
235227
return nil
236228
}
237229

hello/hello.go

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ import (
1919
)
2020

2121
const (
22-
UptimeRoute = "/uptime"
23-
WSRoute = "/ws"
22+
UptimeRoute = "/uptime"
23+
WSRoute = "/ws"
24+
SSERoute = "/sse"
25+
defaultSSEFreq = time.Second * 10
2426
)
2527

2628
type templateData struct {
@@ -111,6 +113,7 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow
111113
muxer := http.NewServeMux()
112114
muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
113115
muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader))
116+
muxer.HandleFunc(SSERoute, sseHandler(logger))
114117
muxer.HandleFunc("/", rootHandler(serverName))
115118
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
116119
go func() {
@@ -182,6 +185,42 @@ func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.H
182185
}
183186
}
184187

188+
func sseHandler(logger logger.Service) http.HandlerFunc {
189+
return func(w http.ResponseWriter, r *http.Request) {
190+
w.Header().Set("Content-Type", "text/event-stream")
191+
flusher, ok := w.(http.Flusher)
192+
if !ok {
193+
w.WriteHeader(http.StatusInternalServerError)
194+
logger.Errorf("Can't support SSE. ResponseWriter %T doesn't implement http.Flusher interface", w)
195+
return
196+
}
197+
198+
freq := defaultSSEFreq
199+
if requestedFreq := r.URL.Query()["freq"]; len(requestedFreq) > 0 {
200+
parsedFreq, err := time.ParseDuration(requestedFreq[0])
201+
if err == nil {
202+
freq = parsedFreq
203+
}
204+
}
205+
logger.Infof("Server Sent Events every %s", freq)
206+
ticker := time.NewTicker(freq)
207+
counter := 0
208+
for {
209+
select {
210+
case <-r.Context().Done():
211+
return
212+
case <-ticker.C:
213+
}
214+
_, err := fmt.Fprintf(w, "%d\n\n", counter)
215+
if err != nil {
216+
return
217+
}
218+
flusher.Flush()
219+
counter++
220+
}
221+
}
222+
}
223+
185224
func rootHandler(serverName string) http.HandlerFunc {
186225
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
187226
return func(w http.ResponseWriter, r *http.Request) {

origin/proxy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt
9999
return nil, errors.Wrap(err, "Error writing response header")
100100
}
101101
if isEventStream(resp) {
102-
//h.observer.Debug("Detected Server-Side Events from Origin")
102+
c.logger.Debug("Detected Server-Side Events from Origin")
103103
c.writeEventStream(w, resp.Body)
104104
} else {
105105
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream

origin/proxy_test.go

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"net/url"
1313
"sync"
1414
"testing"
15+
"time"
1516

1617
"github.com/cloudflare/cloudflared/connection"
1718
"github.com/cloudflare/cloudflared/hello"
@@ -55,9 +56,9 @@ type mockWSRespWriter struct {
5556
reader io.Reader
5657
}
5758

58-
func newMockWSRespWriter(httpRespWriter *mockHTTPRespWriter, reader io.Reader) *mockWSRespWriter {
59+
func newMockWSRespWriter(reader io.Reader) *mockWSRespWriter {
5960
return &mockWSRespWriter{
60-
httpRespWriter,
61+
newMockHTTPRespWriter(),
6162
make(chan []byte),
6263
reader,
6364
}
@@ -77,6 +78,27 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
7778
return w.reader.Read(data)
7879
}
7980

81+
type mockSSERespWriter struct {
82+
*mockHTTPRespWriter
83+
writeNotification chan []byte
84+
}
85+
86+
func newMockSSERespWriter() *mockSSERespWriter {
87+
return &mockSSERespWriter{
88+
newMockHTTPRespWriter(),
89+
make(chan []byte),
90+
}
91+
}
92+
93+
func (w *mockSSERespWriter) Write(data []byte) (int, error) {
94+
w.writeNotification <- data
95+
return len(data), nil
96+
}
97+
98+
func (w *mockSSERespWriter) ReadBytes() []byte {
99+
return <-w.writeNotification
100+
}
101+
80102
func TestProxy(t *testing.T) {
81103
logger, err := logger.New()
82104
require.NoError(t, err)
@@ -112,6 +134,7 @@ func TestProxy(t *testing.T) {
112134
client := NewClient(proxyConfig, logger)
113135
t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL))
114136
t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS))
137+
t.Run("testProxySSE", testProxySSE(t, client, originURL))
115138
cancel()
116139
}
117140

@@ -135,7 +158,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL
135158
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil)
136159

137160
readPipe, writePipe := io.Pipe()
138-
respWriter := newMockWSRespWriter(newMockHTTPRespWriter(), readPipe)
161+
respWriter := newMockWSRespWriter(readPipe)
139162

140163
var wg sync.WaitGroup
141164
wg.Add(1)
@@ -167,3 +190,38 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL
167190
wg.Wait()
168191
}
169192
}
193+
194+
func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
195+
return func(t *testing.T) {
196+
var (
197+
pushCount = 50
198+
pushFreq = time.Duration(time.Millisecond * 10)
199+
)
200+
respWriter := newMockSSERespWriter()
201+
ctx, cancel := context.WithCancel(context.Background())
202+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s?freq=%s", originURL, hello.SSERoute, pushFreq), nil)
203+
require.NoError(t, err)
204+
205+
var wg sync.WaitGroup
206+
wg.Add(1)
207+
go func() {
208+
defer wg.Done()
209+
err = client.Proxy(respWriter, req, false)
210+
require.NoError(t, err)
211+
212+
require.Equal(t, http.StatusOK, respWriter.Code)
213+
}()
214+
215+
for i := 0; i < pushCount; i++ {
216+
line := respWriter.ReadBytes()
217+
expect := fmt.Sprintf("%d\n", i)
218+
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
219+
220+
line = respWriter.ReadBytes()
221+
require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line))
222+
}
223+
224+
cancel()
225+
wg.Wait()
226+
}
227+
}

0 commit comments

Comments
 (0)