@@ -411,15 +411,26 @@ func TestConnections(t *testing.T) {
411411 originService func (* testing.T , net.Listener )
412412 eyeballService connection.ResponseWriter
413413 connectionType connection.Type
414+ requestHeaders http.Header
414415 wantMessage []byte
416+ wantHeaders http.Header
415417 }{
416418 {
417419 name : "ws-ws proxy" ,
418420 ingressServicePrefix : "ws://" ,
419421 originService : runEchoWSService ,
420422 eyeballService : newWSRespWriter ([]byte ("test1" ), replayer ),
421423 connectionType : connection .TypeWebsocket ,
422- wantMessage : []byte ("test1" ),
424+ requestHeaders : map [string ][]string {
425+ "Test-Cloudflared-Echo" : []string {"Echo" },
426+ },
427+ wantMessage : []byte ("echo-test1" ),
428+ wantHeaders : map [string ][]string {
429+ "Connection" : []string {"Upgrade" },
430+ "Sec-Websocket-Accept" : []string {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w=" },
431+ "Upgrade" : []string {"websocket" },
432+ "Test-Cloudflared-Echo" : []string {"Echo" },
433+ },
423434 },
424435 {
425436 name : "tcp-tcp proxy" ,
@@ -430,15 +441,25 @@ func TestConnections(t *testing.T) {
430441 replayer ,
431442 ),
432443 connectionType : connection .TypeTCP ,
433- wantMessage : []byte ("echo-test2" ),
444+ requestHeaders : map [string ][]string {
445+ "Cf-Cloudflared-Proxy-Src" : []string {"non-blank-value" },
446+ },
447+ wantMessage : []byte ("echo-test2" ),
448+ wantHeaders : http.Header {},
434449 },
435450 {
436451 name : "tcp-ws proxy" ,
437452 ingressServicePrefix : "ws://" ,
438453 originService : runEchoWSService ,
439454 eyeballService : newPipedWSWriter (& mockTCPRespWriter {}, []byte ("test3" )),
440- connectionType : connection .TypeTCP ,
441- wantMessage : []byte ("test3" ),
455+ requestHeaders : map [string ][]string {
456+ "Cf-Cloudflared-Proxy-Src" : []string {"non-blank-value" },
457+ },
458+ connectionType : connection .TypeTCP ,
459+ wantMessage : []byte ("echo-test3" ),
460+ // We expect no headers here because they are sent back via
461+ // the stream.
462+ wantHeaders : http.Header {},
442463 },
443464 {
444465 name : "ws-tcp proxy" ,
@@ -447,14 +468,12 @@ func TestConnections(t *testing.T) {
447468 eyeballService : newWSRespWriter ([]byte ("test4" ), replayer ),
448469 connectionType : connection .TypeWebsocket ,
449470 wantMessage : []byte ("echo-test4" ),
471+ wantHeaders : http.Header {},
450472 },
451473 }
452474
453475 for _ , test := range tests {
454476 t .Run (test .name , func (t * testing.T ) {
455- if test .skip {
456- t .Skip ("todo: skipping a failing test. THis should be fixed before merge" )
457- }
458477 ctx , cancel := context .WithCancel (context .Background ())
459478 ln , err := net .Listen ("tcp" , "127.0.0.1:0" )
460479 require .NoError (t , err )
@@ -466,29 +485,41 @@ func TestConnections(t *testing.T) {
466485 proxy := NewOriginProxy (ingressRule , ingress .NewWarpRoutingService (), testTags , logger )
467486 req , err := http .NewRequest (http .MethodGet , test .ingressServicePrefix + ln .Addr ().String (), nil )
468487 require .NoError (t , err )
469- req .Header .Set ("Cf-Cloudflared-Proxy-Src" , "non-blank-value" )
488+ reqHeaders := make (http.Header )
489+ for k , vs := range test .requestHeaders {
490+ reqHeaders [k ] = vs
491+ }
492+ req .Header = reqHeaders
470493
471494 if pipedWS , ok := test .eyeballService .(* pipedWSWriter ); ok {
472495 go func () {
473496 resp := pipedWS .roundtrip (test .ingressServicePrefix + ln .Addr ().String ())
474497 replayer .Write (resp )
475498 }()
476499 }
500+
477501 err = proxy .Proxy (test .eyeballService , req , test .connectionType )
478502 require .NoError (t , err )
479503
480504 cancel ()
481505 assert .Equal (t , test .wantMessage , replayer .Bytes ())
506+ respPrinter := test .eyeballService .(responsePrinter )
507+ assert .Equal (t , test .wantHeaders , respPrinter .printRespHeaders ())
482508 replayer .rw .Reset ()
483509 })
484510 }
485511}
486512
513+ type responsePrinter interface {
514+ printRespHeaders () http.Header
515+ }
516+
487517type pipedWSWriter struct {
488518 dialer gorillaWS.Dialer
489519 wsConn net.Conn
490520 pipedConn net.Conn
491521 respWriter connection.ResponseWriter
522+ respHeaders http.Header
492523 messageToWrite []byte
493524}
494525
@@ -547,14 +578,21 @@ func (p *pipedWSWriter) WriteErrorResponse() {
547578}
548579
549580func (p * pipedWSWriter ) WriteRespHeaders (status int , header http.Header ) error {
581+ p .respHeaders = header
550582 return nil
551583}
552584
585+ // printRespHeaders is a test function to read respHeaders
586+ func (p * pipedWSWriter ) printRespHeaders () http.Header {
587+ return p .respHeaders
588+ }
589+
553590type wsRespWriter struct {
554- w io.Writer
555- pr * io.PipeReader
556- pw * io.PipeWriter
557- code int
591+ w io.Writer
592+ pr * io.PipeReader
593+ pw * io.PipeWriter
594+ respHeaders http.Header
595+ code int
558596}
559597
560598// newWSRespWriter uses wsutil.WriteClientText to generate websocket frames.
@@ -589,13 +627,19 @@ func (w *wsRespWriter) Write(p []byte) (int, error) {
589627}
590628
591629func (w * wsRespWriter ) WriteRespHeaders (status int , header http.Header ) error {
630+ w .respHeaders = header
592631 w .code = status
593632 return nil
594633}
595634
596635func (w * wsRespWriter ) WriteErrorResponse () {
597636}
598637
638+ // printRespHeaders is a test function to read respHeaders
639+ func (w * wsRespWriter ) printRespHeaders () http.Header {
640+ return w .respHeaders
641+ }
642+
599643func runEchoTCPService (t * testing.T , l net.Listener ) {
600644 go func () {
601645 for {
@@ -628,7 +672,13 @@ func runEchoWSService(t *testing.T, l net.Listener) {
628672 }
629673
630674 var ws = func (w http.ResponseWriter , r * http.Request ) {
631- conn , err := upgrader .Upgrade (w , r , nil )
675+ header := make (http.Header )
676+ for k , vs := range r .Header {
677+ if k == "Test-Cloudflared-Echo" {
678+ header [k ] = vs
679+ }
680+ }
681+ conn , err := upgrader .Upgrade (w , r , header )
632682 require .NoError (t , err )
633683 defer conn .Close ()
634684
@@ -637,8 +687,9 @@ func runEchoWSService(t *testing.T, l net.Listener) {
637687 if err != nil {
638688 return
639689 }
640-
641- if err := conn .WriteMessage (messageType , p ); err != nil {
690+ data := []byte ("echo-" )
691+ data = append (data , p ... )
692+ if err := conn .WriteMessage (messageType , data ); err != nil {
642693 return
643694 }
644695 }
@@ -672,10 +723,11 @@ type tcpWrappedWs struct {
672723}
673724
674725type mockTCPRespWriter struct {
675- w io.Writer
676- pr io.Reader
677- pw * io.PipeWriter
678- code int
726+ w io.Writer
727+ pr io.Reader
728+ pw * io.PipeWriter
729+ respHeaders http.Header
730+ code int
679731}
680732
681733func newTCPRespWriter (data []byte , w io.Writer ) * mockTCPRespWriter {
@@ -701,6 +753,12 @@ func (m *mockTCPRespWriter) WriteErrorResponse() {
701753}
702754
703755func (m * mockTCPRespWriter ) WriteRespHeaders (status int , header http.Header ) error {
756+ m .respHeaders = header
704757 m .code = status
705758 return nil
706759}
760+
761+ // printRespHeaders is a test function to read respHeaders
762+ func (m * mockTCPRespWriter ) printRespHeaders () http.Header {
763+ return m .respHeaders
764+ }
0 commit comments