@@ -377,6 +377,77 @@ func TestForwardResponseMessage(t *testing.T) {
377377}
378378
379379func TestOutgoingHeaderMatcher (t * testing.T ) {
380+ t .Parallel ()
381+ msg := & pb.SimpleMessage {Id : "foo" }
382+ for _ , tc := range []struct {
383+ name string
384+ md runtime.ServerMetadata
385+ headers http.Header
386+ matcher runtime.HeaderMatcherFunc
387+ }{
388+ {
389+ name : "default matcher" ,
390+ md : runtime.ServerMetadata {
391+ HeaderMD : metadata .Pairs (
392+ "foo" , "bar" ,
393+ "baz" , "qux" ,
394+ ),
395+ },
396+ headers : http.Header {
397+ "Content-Type" : []string {"application/json" },
398+ "Grpc-Metadata-Foo" : []string {"bar" },
399+ "Grpc-Metadata-Baz" : []string {"qux" },
400+ },
401+ },
402+ {
403+ name : "custom matcher" ,
404+ md : runtime.ServerMetadata {
405+ HeaderMD : metadata .Pairs (
406+ "foo" , "bar" ,
407+ "baz" , "qux" ,
408+ ),
409+ },
410+ headers : http.Header {
411+ "Content-Type" : []string {"application/json" },
412+ "Custom-Foo" : []string {"bar" },
413+ },
414+ matcher : func (key string ) (string , bool ) {
415+ switch key {
416+ case "foo" :
417+ return "custom-foo" , true
418+ default :
419+ return "" , false
420+ }
421+ },
422+ },
423+ } {
424+ tc := tc
425+ t .Run (tc .name , func (t * testing.T ) {
426+ t .Parallel ()
427+ ctx := runtime .NewServerMetadataContext (context .Background (), tc .md )
428+
429+ req := httptest .NewRequest ("GET" , "http://example.com/foo" , nil )
430+ resp := httptest .NewRecorder ()
431+
432+ mux := runtime .NewServeMux (
433+ runtime .WithOutgoingHeaderMatcher (tc .matcher ),
434+ )
435+ runtime .ForwardResponseMessage (ctx , mux , & runtime.JSONPb {}, resp , req , msg )
436+
437+ w := resp .Result ()
438+ defer w .Body .Close ()
439+ if w .StatusCode != http .StatusOK {
440+ t .Fatalf ("StatusCode %d want %d" , w .StatusCode , http .StatusOK )
441+ }
442+
443+ if ! reflect .DeepEqual (w .Header , tc .headers ) {
444+ t .Fatalf ("Header %v want %v" , w .Header , tc .headers )
445+ }
446+ })
447+ }
448+ }
449+
450+ func TestOutgoingHeaderMatcherWithContentLength (t * testing.T ) {
380451 t .Parallel ()
381452 msg := & pb.SimpleMessage {Id : "foo" }
382453 for _ , tc := range []struct {
@@ -431,7 +502,11 @@ func TestOutgoingHeaderMatcher(t *testing.T) {
431502 req := httptest .NewRequest ("GET" , "http://example.com/foo" , nil )
432503 resp := httptest .NewRecorder ()
433504
434- runtime .ForwardResponseMessage (ctx , runtime .NewServeMux (runtime .WithOutgoingHeaderMatcher (tc .matcher )), & runtime.JSONPb {}, resp , req , msg )
505+ mux := runtime .NewServeMux (
506+ runtime .WithOutgoingHeaderMatcher (tc .matcher ),
507+ runtime .WithWriteContentLength (),
508+ )
509+ runtime .ForwardResponseMessage (ctx , mux , & runtime.JSONPb {}, resp , req , msg )
435510
436511 w := resp .Result ()
437512 defer w .Body .Close ()
@@ -529,7 +604,11 @@ func TestOutgoingTrailerMatcher(t *testing.T) {
529604 req .Header = tc .caller
530605 resp := httptest .NewRecorder ()
531606
532- runtime .ForwardResponseMessage (ctx , runtime .NewServeMux (runtime .WithOutgoingTrailerMatcher (tc .matcher )), & runtime.JSONPb {}, resp , req , msg )
607+ mux := runtime .NewServeMux (
608+ runtime .WithOutgoingTrailerMatcher (tc .matcher ),
609+ runtime .WithWriteContentLength (),
610+ )
611+ runtime .ForwardResponseMessage (ctx , mux , & runtime.JSONPb {}, resp , req , msg )
533612
534613 w := resp .Result ()
535614 _ , _ = io .Copy (io .Discard , w .Body )
0 commit comments