@@ -2,7 +2,7 @@ package connection
22
33import (
44 "context"
5- "fmt "
5+ "errors "
66 "io"
77 "math"
88 "net"
@@ -23,6 +23,10 @@ const (
2323 controlStreamUpgrade = "control-stream"
2424)
2525
26+ var (
27+ errNotFlusher = errors .New ("ResponseWriter doesn't implement http.Flusher" )
28+ )
29+
2630type HTTP2Connection struct {
2731 conn net.Conn
2832 server * http2.Server
@@ -37,7 +41,16 @@ type HTTP2Connection struct {
3741 connectedFuse ConnectedFuse
3842}
3943
40- func NewHTTP2Connection (conn net.Conn , config * Config , originURL * url.URL , namedTunnelConfig * NamedTunnelConfig , connOptions * tunnelpogs.ConnectionOptions , observer * Observer , connIndex uint8 , connectedFuse ConnectedFuse ) * HTTP2Connection {
44+ func NewHTTP2Connection (
45+ conn net.Conn ,
46+ config * Config ,
47+ originURL * url.URL ,
48+ namedTunnelConfig * NamedTunnelConfig ,
49+ connOptions * tunnelpogs.ConnectionOptions ,
50+ observer * Observer ,
51+ connIndex uint8 ,
52+ connectedFuse ConnectedFuse ,
53+ ) * HTTP2Connection {
4154 return & HTTP2Connection {
4255 conn : conn ,
4356 server : & http2.Server {
@@ -77,34 +90,33 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
7790 r : r .Body ,
7891 w : w ,
7992 }
93+ flusher , isFlusher := w .(http.Flusher )
94+ if ! isFlusher {
95+ c .observer .Errorf ("%T doesn't implement http.Flusher" , w )
96+ respWriter .WriteErrorResponse (errNotFlusher )
97+ return
98+ }
99+ respWriter .flusher = flusher
80100 if isControlStreamUpgrade (r ) {
101+ respWriter .shouldFlush = true
81102 err := c .serveControlStream (r .Context (), respWriter )
82103 if err != nil {
83104 respWriter .WriteErrorResponse (err )
84105 }
85106 } else if isWebsocketUpgrade (r ) {
86- wsRespWriter , err := newWSRespWriter (respWriter )
87- if err != nil {
88- respWriter .WriteErrorResponse (err )
89- return
90- }
107+ respWriter .shouldFlush = true
91108 stripWebsocketUpgradeHeader (r )
92- c .config .OriginClient .Proxy (wsRespWriter , r , true )
109+ c .config .OriginClient .Proxy (respWriter , r , true )
93110 } else {
94111 c .config .OriginClient .Proxy (respWriter , r , false )
95112 }
96113}
97114
98- func (c * HTTP2Connection ) serveControlStream (ctx context.Context , h2RespWriter * http2RespWriter ) error {
99- stream , err := newWSRespWriter (h2RespWriter )
100- if err != nil {
101- return err
102- }
103-
104- rpcClient := newRegistrationRPCClient (ctx , stream , c .observer )
115+ func (c * HTTP2Connection ) serveControlStream (ctx context.Context , respWriter * http2RespWriter ) error {
116+ rpcClient := newRegistrationRPCClient (ctx , respWriter , c .observer )
105117 defer rpcClient .close ()
106118
107- if err = registerConnection (ctx , rpcClient , c .namedTunnel , c .connOptions , c .connIndex , c .observer ); err != nil {
119+ if err : = registerConnection (ctx , rpcClient , c .namedTunnel , c .connOptions , c .connIndex , c .observer ); err != nil {
108120 return err
109121 }
110122 c .connectedFuse .Connected ()
@@ -146,8 +158,10 @@ func (c *HTTP2Connection) close() {
146158}
147159
148160type http2RespWriter struct {
149- r io.Reader
150- w http.ResponseWriter
161+ r io.Reader
162+ w http.ResponseWriter
163+ flusher http.Flusher
164+ shouldFlush bool
151165}
152166
153167func (rp * http2RespWriter ) WriteRespHeaders (resp * http.Response ) error {
@@ -172,13 +186,19 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
172186
173187 // Perform user header serialization and set them in the single header
174188 dest .Set (canonicalResponseUserHeadersField , h2mux .SerializeHeaders (userHeaders ))
175- rp .setResponseMetaHeader (responseMetaHeaderCfd )
189+ rp .setResponseMetaHeader (responseMetaHeaderOrigin )
176190 status := resp .StatusCode
177191 // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
178192 if status == http .StatusSwitchingProtocols {
179193 status = http .StatusOK
180194 }
181195 rp .w .WriteHeader (status )
196+ if isServerSentEvent (resp .Header ) {
197+ rp .shouldFlush = true
198+ }
199+ if rp .shouldFlush {
200+ rp .flusher .Flush ()
201+ }
182202 return nil
183203}
184204
@@ -195,43 +215,15 @@ func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
195215 return rp .r .Read (p )
196216}
197217
198- func (wr * http2RespWriter ) Write (p []byte ) (n int , err error ) {
199- return wr .w .Write (p )
200- }
201-
202- type wsRespWriter struct {
203- * http2RespWriter
204- flusher http.Flusher
205- }
206-
207- func newWSRespWriter (h2 * http2RespWriter ) (* wsRespWriter , error ) {
208- flusher , ok := h2 .w .(http.Flusher )
209- if ! ok {
210- return nil , fmt .Errorf ("ResponseWriter doesn't implement http.Flusher" )
211- }
212- return & wsRespWriter {
213- h2 ,
214- flusher ,
215- }, nil
216- }
217-
218- func (rw * wsRespWriter ) WriteRespHeaders (resp * http.Response ) (err error ) {
219- err = rw .http2RespWriter .WriteRespHeaders (resp )
220- if err == nil {
221- rw .flusher .Flush ()
222- }
223- return
224- }
225-
226- func (rw * wsRespWriter ) Write (p []byte ) (n int , err error ) {
227- n , err = rw .http2RespWriter .Write (p )
228- if err == nil {
229- rw .flusher .Flush ()
218+ func (rp * http2RespWriter ) Write (p []byte ) (n int , err error ) {
219+ n , err = rp .w .Write (p )
220+ if err == nil && rp .shouldFlush {
221+ rp .flusher .Flush ()
230222 }
231- return
223+ return n , err
232224}
233225
234- func (rw * wsRespWriter ) Close () error {
226+ func (rp * http2RespWriter ) Close () error {
235227 return nil
236228}
237229
0 commit comments