@@ -98,7 +98,7 @@ describe("StreamableHTTPServerTransport", () => {
98
98
99
99
await transport . handleRequest ( req , mockResponse ) ;
100
100
101
- expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 404 ) ;
101
+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 404 , { } ) ;
102
102
// check if the error response is a valid JSON-RPC error format
103
103
expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"jsonrpc":"2.0"' ) ) ;
104
104
expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"error"' ) ) ;
@@ -115,7 +115,7 @@ describe("StreamableHTTPServerTransport", () => {
115
115
116
116
await transport . handleRequest ( req , mockResponse ) ;
117
117
118
- expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 400 ) ;
118
+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 400 , { } ) ;
119
119
expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"jsonrpc":"2.0"' ) ) ;
120
120
expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"message":"Bad Request: Mcp-Session-Id header is required"' ) ) ;
121
121
} ) ;
@@ -342,7 +342,7 @@ describe("StreamableHTTPServerTransport", () => {
342
342
343
343
await transport . handleRequest ( req , mockResponse ) ;
344
344
345
- expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 406 ) ;
345
+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith ( 406 , { } ) ;
346
346
expect ( mockResponse . end ) . toHaveBeenCalledWith ( expect . stringContaining ( '"jsonrpc":"2.0"' ) ) ;
347
347
} ) ;
348
348
@@ -788,4 +788,141 @@ describe("StreamableHTTPServerTransport", () => {
788
788
expect ( onMessageMock ) . not . toHaveBeenCalledWith ( requestBodyMessage ) ;
789
789
} ) ;
790
790
} ) ;
791
+
792
+ describe ( "Custom Headers" , ( ) => {
793
+ const customHeaders = {
794
+ "X-Custom-Header" : "custom-value" ,
795
+ "X-API-Version" : "1.0" ,
796
+ "Access-Control-Allow-Origin" : "*"
797
+ } ;
798
+
799
+ let transportWithHeaders : StreamableHTTPServerTransport ;
800
+ let mockResponse : jest . Mocked < ServerResponse > ;
801
+
802
+ beforeEach ( ( ) => {
803
+ transportWithHeaders = new StreamableHTTPServerTransport ( endpoint , { customHeaders } ) ;
804
+ mockResponse = createMockResponse ( ) ;
805
+ } ) ;
806
+
807
+ it ( "should include custom headers in SSE response" , async ( ) => {
808
+ const req = createMockRequest ( {
809
+ method : "GET" ,
810
+ headers : {
811
+ accept : "text/event-stream" ,
812
+ "mcp-session-id" : transportWithHeaders . sessionId
813
+ } ,
814
+ } ) ;
815
+
816
+ await transportWithHeaders . handleRequest ( req , mockResponse ) ;
817
+
818
+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
819
+ 200 ,
820
+ expect . objectContaining ( {
821
+ ...customHeaders ,
822
+ "Content-Type" : "text/event-stream" ,
823
+ "Cache-Control" : "no-cache" ,
824
+ "Connection" : "keep-alive" ,
825
+ "mcp-session-id" : transportWithHeaders . sessionId
826
+ } )
827
+ ) ;
828
+ } ) ;
829
+
830
+ it ( "should include custom headers in JSON response" , async ( ) => {
831
+ const message : JSONRPCMessage = {
832
+ jsonrpc : "2.0" ,
833
+ method : "test" ,
834
+ params : { } ,
835
+ id : 1 ,
836
+ } ;
837
+
838
+ const req = createMockRequest ( {
839
+ method : "POST" ,
840
+ headers : {
841
+ "content-type" : "application/json" ,
842
+ "accept" : "application/json" ,
843
+ "mcp-session-id" : transportWithHeaders . sessionId
844
+ } ,
845
+ body : JSON . stringify ( message ) ,
846
+ } ) ;
847
+
848
+ await transportWithHeaders . handleRequest ( req , mockResponse ) ;
849
+
850
+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
851
+ 200 ,
852
+ expect . objectContaining ( {
853
+ ...customHeaders ,
854
+ "Content-Type" : "application/json" ,
855
+ "mcp-session-id" : transportWithHeaders . sessionId
856
+ } )
857
+ ) ;
858
+ } ) ;
859
+
860
+ it ( "should include custom headers in error responses" , async ( ) => {
861
+ const req = createMockRequest ( {
862
+ method : "GET" ,
863
+ headers : {
864
+ accept : "text/event-stream" ,
865
+ "mcp-session-id" : "invalid-session-id"
866
+ } ,
867
+ } ) ;
868
+
869
+ await transportWithHeaders . handleRequest ( req , mockResponse ) ;
870
+
871
+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
872
+ 404 ,
873
+ expect . objectContaining ( customHeaders )
874
+ ) ;
875
+ } ) ;
876
+
877
+ it ( "should not override essential headers with custom headers" , async ( ) => {
878
+ const transportWithConflictingHeaders = new StreamableHTTPServerTransport ( endpoint , {
879
+ customHeaders : {
880
+ "Content-Type" : "text/plain" , // 尝试覆盖必要的 Content-Type 头
881
+ "X-Custom-Header" : "custom-value"
882
+ }
883
+ } ) ;
884
+
885
+ const req = createMockRequest ( {
886
+ method : "GET" ,
887
+ headers : {
888
+ accept : "text/event-stream" ,
889
+ "mcp-session-id" : transportWithConflictingHeaders . sessionId
890
+ } ,
891
+ } ) ;
892
+
893
+ await transportWithConflictingHeaders . handleRequest ( req , mockResponse ) ;
894
+
895
+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
896
+ 200 ,
897
+ expect . objectContaining ( {
898
+ "Content-Type" : "text/event-stream" , // 应该保持原有的 Content-Type
899
+ "X-Custom-Header" : "custom-value"
900
+ } )
901
+ ) ;
902
+ } ) ;
903
+
904
+ it ( "should work with empty custom headers" , async ( ) => {
905
+ const transportWithoutHeaders = new StreamableHTTPServerTransport ( endpoint ) ;
906
+
907
+ const req = createMockRequest ( {
908
+ method : "GET" ,
909
+ headers : {
910
+ accept : "text/event-stream" ,
911
+ "mcp-session-id" : transportWithoutHeaders . sessionId
912
+ } ,
913
+ } ) ;
914
+
915
+ await transportWithoutHeaders . handleRequest ( req , mockResponse ) ;
916
+
917
+ expect ( mockResponse . writeHead ) . toHaveBeenCalledWith (
918
+ 200 ,
919
+ expect . objectContaining ( {
920
+ "Content-Type" : "text/event-stream" ,
921
+ "Cache-Control" : "no-cache" ,
922
+ "Connection" : "keep-alive" ,
923
+ "mcp-session-id" : transportWithoutHeaders . sessionId
924
+ } )
925
+ ) ;
926
+ } ) ;
927
+ } ) ;
791
928
} ) ;
0 commit comments