@@ -64,10 +64,164 @@ fn forward_headers(names: &[String], incoming: &HeaderMap, outgoing: &mut Header
64
64
#[ cfg( test) ]
65
65
mod tests {
66
66
use super :: * ;
67
+ use headers:: Authorization ;
68
+ use http:: Extensions ;
67
69
use reqwest:: header:: HeaderValue ;
68
70
71
+ use crate :: auth:: ValidToken ;
72
+
73
+ #[ test]
74
+ fn test_build_request_headers_includes_static_headers ( ) {
75
+ let mut static_headers = HeaderMap :: new ( ) ;
76
+ static_headers. insert ( "x-api-key" , HeaderValue :: from_static ( "static-key" ) ) ;
77
+ static_headers. insert ( "user-agent" , HeaderValue :: from_static ( "mcp-server" ) ) ;
78
+
79
+ let forward_header_names = vec ! [ ] ;
80
+ let incoming_headers = HeaderMap :: new ( ) ;
81
+ let extensions = Extensions :: new ( ) ;
82
+
83
+ let result = build_request_headers (
84
+ & static_headers,
85
+ & forward_header_names,
86
+ & incoming_headers,
87
+ & extensions,
88
+ false ,
89
+ ) ;
90
+
91
+ assert_eq ! ( result. get( "x-api-key" ) . unwrap( ) , "static-key" ) ;
92
+ assert_eq ! ( result. get( "user-agent" ) . unwrap( ) , "mcp-server" ) ;
93
+ }
94
+
95
+ #[ test]
96
+ fn test_build_request_headers_forwards_configured_headers ( ) {
97
+ let static_headers = HeaderMap :: new ( ) ;
98
+ let forward_header_names = vec ! [ "x-tenant-id" . to_string( ) , "x-trace-id" . to_string( ) ] ;
99
+
100
+ let mut incoming_headers = HeaderMap :: new ( ) ;
101
+ incoming_headers. insert ( "x-tenant-id" , HeaderValue :: from_static ( "tenant-123" ) ) ;
102
+ incoming_headers. insert ( "x-trace-id" , HeaderValue :: from_static ( "trace-456" ) ) ;
103
+ incoming_headers. insert ( "other-header" , HeaderValue :: from_static ( "ignored" ) ) ;
104
+
105
+ let extensions = Extensions :: new ( ) ;
106
+
107
+ let result = build_request_headers (
108
+ & static_headers,
109
+ & forward_header_names,
110
+ & incoming_headers,
111
+ & extensions,
112
+ false ,
113
+ ) ;
114
+
115
+ assert_eq ! ( result. get( "x-tenant-id" ) . unwrap( ) , "tenant-123" ) ;
116
+ assert_eq ! ( result. get( "x-trace-id" ) . unwrap( ) , "trace-456" ) ;
117
+ assert ! ( result. get( "other-header" ) . is_none( ) ) ;
118
+ }
119
+
120
+ #[ test]
121
+ fn test_build_request_headers_adds_oauth_token_when_enabled ( ) {
122
+ let static_headers = HeaderMap :: new ( ) ;
123
+ let forward_header_names = vec ! [ ] ;
124
+ let incoming_headers = HeaderMap :: new ( ) ;
125
+
126
+ let mut extensions = Extensions :: new ( ) ;
127
+ let token = ValidToken ( Authorization :: bearer ( "test-token" ) . unwrap ( ) ) ;
128
+ extensions. insert ( token) ;
129
+
130
+ let result = build_request_headers (
131
+ & static_headers,
132
+ & forward_header_names,
133
+ & incoming_headers,
134
+ & extensions,
135
+ false ,
136
+ ) ;
137
+
138
+ assert ! ( result. get( "authorization" ) . is_some( ) ) ;
139
+ assert_eq ! ( result. get( "authorization" ) . unwrap( ) , "Bearer test-token" ) ;
140
+ }
141
+
142
+ #[ test]
143
+ fn test_build_request_headers_skips_oauth_token_when_disabled ( ) {
144
+ let static_headers = HeaderMap :: new ( ) ;
145
+ let forward_header_names = vec ! [ ] ;
146
+ let incoming_headers = HeaderMap :: new ( ) ;
147
+
148
+ let mut extensions = Extensions :: new ( ) ;
149
+ let token = ValidToken ( Authorization :: bearer ( "test-token" ) . unwrap ( ) ) ;
150
+ extensions. insert ( token) ;
151
+
152
+ let result = build_request_headers (
153
+ & static_headers,
154
+ & forward_header_names,
155
+ & incoming_headers,
156
+ & extensions,
157
+ true ,
158
+ ) ;
159
+
160
+ assert ! ( result. get( "authorization" ) . is_none( ) ) ;
161
+ }
162
+
163
+ #[ test]
164
+ fn test_build_request_headers_forwards_mcp_session_id ( ) {
165
+ let static_headers = HeaderMap :: new ( ) ;
166
+ let forward_header_names = vec ! [ ] ;
167
+
168
+ let mut incoming_headers = HeaderMap :: new ( ) ;
169
+ incoming_headers. insert ( "mcp-session-id" , HeaderValue :: from_static ( "session-123" ) ) ;
170
+
171
+ let extensions = Extensions :: new ( ) ;
172
+
173
+ let result = build_request_headers (
174
+ & static_headers,
175
+ & forward_header_names,
176
+ & incoming_headers,
177
+ & extensions,
178
+ false ,
179
+ ) ;
180
+
181
+ assert_eq ! ( result. get( "mcp-session-id" ) . unwrap( ) , "session-123" ) ;
182
+ }
183
+
184
+ #[ test]
185
+ fn test_build_request_headers_combined_scenario ( ) {
186
+ // Static headers
187
+ let mut static_headers = HeaderMap :: new ( ) ;
188
+ static_headers. insert ( "x-api-key" , HeaderValue :: from_static ( "static-key" ) ) ;
189
+
190
+ // Forward specific headers
191
+ let forward_header_names = vec ! [ "x-tenant-id" . to_string( ) ] ;
192
+
193
+ // Incoming headers
194
+ let mut incoming_headers = HeaderMap :: new ( ) ;
195
+ incoming_headers. insert ( "x-tenant-id" , HeaderValue :: from_static ( "tenant-123" ) ) ;
196
+ incoming_headers. insert ( "mcp-session-id" , HeaderValue :: from_static ( "session-456" ) ) ;
197
+ incoming_headers. insert (
198
+ "ignored-header" ,
199
+ HeaderValue :: from_static ( "should-not-appear" ) ,
200
+ ) ;
201
+
202
+ // OAuth token
203
+ let mut extensions = Extensions :: new ( ) ;
204
+ let token = ValidToken ( Authorization :: bearer ( "oauth-token" ) . unwrap ( ) ) ;
205
+ extensions. insert ( token) ;
206
+
207
+ let result = build_request_headers (
208
+ & static_headers,
209
+ & forward_header_names,
210
+ & incoming_headers,
211
+ & extensions,
212
+ false ,
213
+ ) ;
214
+
215
+ // Verify all parts combined correctly
216
+ assert_eq ! ( result. get( "x-api-key" ) . unwrap( ) , "static-key" ) ;
217
+ assert_eq ! ( result. get( "x-tenant-id" ) . unwrap( ) , "tenant-123" ) ;
218
+ assert_eq ! ( result. get( "mcp-session-id" ) . unwrap( ) , "session-456" ) ;
219
+ assert_eq ! ( result. get( "authorization" ) . unwrap( ) , "Bearer oauth-token" ) ;
220
+ assert ! ( result. get( "ignored-header" ) . is_none( ) ) ;
221
+ }
222
+
69
223
#[ test]
70
- fn test_forward_no_headers_by_default ( ) {
224
+ fn test_forward_headers_no_headers_by_default ( ) {
71
225
let names: Vec < String > = vec ! [ ] ;
72
226
73
227
let mut incoming = HeaderMap :: new ( ) ;
@@ -81,7 +235,7 @@ mod tests {
81
235
}
82
236
83
237
#[ test]
84
- fn test_forward_only_allowed_headers ( ) {
238
+ fn test_forward_headers_only_specific_headers ( ) {
85
239
let names = vec ! [
86
240
"x-tenant-id" . to_string( ) , // Multi-tenancy
87
241
"x-trace-id" . to_string( ) , // Distributed tracing
@@ -112,7 +266,7 @@ mod tests {
112
266
}
113
267
114
268
#[ test]
115
- fn test_hop_by_hop_headers_blocked ( ) {
269
+ fn test_forward_headers_blocks_hop_by_hop_headers ( ) {
116
270
let names = vec ! [ "connection" . to_string( ) , "content-length" . to_string( ) ] ;
117
271
118
272
let mut incoming = HeaderMap :: new ( ) ;
@@ -128,7 +282,7 @@ mod tests {
128
282
}
129
283
130
284
#[ test]
131
- fn test_case_insensitive_matching ( ) {
285
+ fn test_forward_headers_case_insensitive_matching ( ) {
132
286
let names = vec ! [ "X-Tenant-ID" . to_string( ) ] ;
133
287
134
288
let mut incoming = HeaderMap :: new ( ) ;
0 commit comments