Skip to content

Commit fdb1f96

Browse files
committed
TUN-3557: Detect SSE if content-type starts with text/event-stream
1 parent 293b9af commit fdb1f96

File tree

5 files changed

+49
-18
lines changed

5 files changed

+49
-18
lines changed

connection/connection.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,13 @@ type ConnectedFuse interface {
5353
IsConnected() bool
5454
}
5555

56-
func uint8ToString(input uint8) string {
57-
return strconv.FormatUint(uint64(input), 10)
56+
func IsServerSentEvent(headers http.Header) bool {
57+
if contentType := headers.Get("content-type"); contentType != "" {
58+
return strings.HasPrefix(strings.ToLower(contentType), "text/event-stream")
59+
}
60+
return false
5861
}
5962

60-
func isServerSentEvent(headers http.Header) bool {
61-
return strings.ToLower(headers.Get("content-type")) == "text/event-stream"
63+
func uint8ToString(input uint8) string {
64+
return strconv.FormatUint(uint64(input), 10)
6265
}

connection/connection_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import (
55
"io"
66
"net/http"
77
"net/url"
8+
"testing"
89
"time"
910

1011
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
1112
"github.com/cloudflare/cloudflared/logger"
1213
"github.com/gobwas/ws/wsutil"
14+
"github.com/stretchr/testify/assert"
1315
)
1416

1517
const (
@@ -111,3 +113,40 @@ func (mcf mockConnectedFuse) Connected() {}
111113
func (mcf mockConnectedFuse) IsConnected() bool {
112114
return true
113115
}
116+
117+
func TestIsEventStream(t *testing.T) {
118+
tests := []struct {
119+
headers http.Header
120+
isEventStream bool
121+
}{
122+
{
123+
headers: newHeader("Content-Type", "text/event-stream"),
124+
isEventStream: true,
125+
},
126+
{
127+
headers: newHeader("content-type", "text/event-stream"),
128+
isEventStream: true,
129+
},
130+
{
131+
headers: newHeader("Content-Type", "text/event-stream; charset=utf-8"),
132+
isEventStream: true,
133+
},
134+
{
135+
headers: newHeader("Content-Type", "application/json"),
136+
isEventStream: false,
137+
},
138+
{
139+
headers: http.Header{},
140+
isEventStream: false,
141+
},
142+
}
143+
for _, test := range tests {
144+
assert.Equal(t, test.isEventStream, IsServerSentEvent(test.headers))
145+
}
146+
}
147+
148+
func newHeader(key, value string) http.Header {
149+
header := http.Header{}
150+
header.Add(key, value)
151+
return header
152+
}

connection/http2.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
167167
status = http.StatusOK
168168
}
169169
rp.w.WriteHeader(status)
170-
if isServerSentEvent(resp.Header) {
170+
if IsServerSentEvent(resp.Header) {
171171
rp.shouldFlush = true
172172
}
173173
if rp.shouldFlush {

hello/hello.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.H
189189

190190
func sseHandler(logger logger.Service) http.HandlerFunc {
191191
return func(w http.ResponseWriter, r *http.Request) {
192-
w.Header().Set("Content-Type", "text/event-stream")
192+
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
193193
flusher, ok := w.(http.Flusher)
194194
if !ok {
195195
w.WriteHeader(http.StatusInternalServerError)

origin/proxy.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule
9696
if err != nil {
9797
return nil, errors.Wrap(err, "Error writing response header")
9898
}
99-
if isEventStream(resp) {
99+
if connection.IsServerSentEvent(resp.Header) {
100100
c.logger.Debug("Detected Server-Side Events from Origin")
101101
c.writeEventStream(w, resp.Body)
102102
} else {
@@ -222,14 +222,3 @@ func findCfRayHeader(req *http.Request) string {
222222
func isLBProbeRequest(req *http.Request) bool {
223223
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
224224
}
225-
226-
func uint8ToString(input uint8) string {
227-
return strconv.FormatUint(uint64(input), 10)
228-
}
229-
230-
func isEventStream(response *http.Response) bool {
231-
if response.Header.Get("content-type") == "text/event-stream" {
232-
return true
233-
}
234-
return false
235-
}

0 commit comments

Comments
 (0)