@@ -10,6 +10,7 @@ use datafusion_udf_wasm_host::{
1010 WasmPermissions , WasmScalarUdf ,
1111 http:: { AllowCertainHttpRequests , Matcher } ,
1212} ;
13+ use wasmtime_wasi_http:: types:: DEFAULT_FORBIDDEN_HEADERS ;
1314use wiremock:: { Mock , MockServer , ResponseTemplate , matchers} ;
1415
1516use crate :: integration_tests:: {
@@ -128,63 +129,161 @@ def perform_request(url: str) -> str:
128129}
129130
130131#[ tokio:: test( flavor = "multi_thread" ) ]
131- async fn test_urllib3_happy_path ( ) {
132+ async fn test_integration ( ) {
132133 const CODE : & str = r#"
133134import urllib3
134135
135- def perform_request(method: str, url: str) -> str:
136- resp = urllib3.request(method, url)
136+ def _headers_str_to_dict(headers: str) -> dict[str, str]:
137+ headers_dct = {}
138+ if headers is not None:
139+ for k_v in headers.split(";"):
140+ [k, v] = k_v.split(":")
141+ headers_dct[k] = v
142+ return headers_dct
143+
144+ def _headers_dict_to_str(headers: dict[str, str]) -> str:
145+ headers = ";".join((
146+ f"{k}:{v}"
147+ for k, v in headers.items()
148+ if k not in ["content-length", "content-type", "date"]
149+ ))
150+ if not headers:
151+ return "n/a"
152+ else:
153+ return headers
154+
155+ def perform_request(method: str, url: str, headers: str | None) -> str:
156+ try:
157+ resp = urllib3.request(
158+ method=method,
159+ url=url,
160+ headers=_headers_str_to_dict(headers),
161+ )
162+ except Exception as e:
163+ return f"ERR: {e}"
137164
138165 resp_status = resp.status
139166 resp_body = resp.data.decode("utf-8")
167+ resp_headers = _headers_dict_to_str(resp.headers)
140168
141- return f"status={resp_status} body='{resp_body}'"
169+ return f"OK: status={resp_status} headers={resp_headers } body='{resp_body}'"
142170"# ;
143171
144- let cases = [
172+ let mut cases = vec ! [
145173 TestCase {
146- resp_body : "case_1" ,
174+ resp: Ok ( TestResponse {
175+ body: "case_1" ,
176+ ..Default :: default ( )
177+ } ) ,
147178 ..Default :: default ( )
148179 } ,
149180 TestCase {
150- resp_body : "case_2" ,
151181 method: "POST" ,
182+ resp: Ok ( TestResponse {
183+ body: "case_2" ,
184+ ..Default :: default ( )
185+ } ) ,
152186 ..Default :: default ( )
153187 } ,
154188 TestCase {
155- resp_body : "case_3" ,
156- path : "/foo" ,
189+ path: "/foo" . to_owned( ) ,
190+ resp: Ok ( TestResponse {
191+ body: "case_3" ,
192+ ..Default :: default ( )
193+ } ) ,
157194 ..Default :: default ( )
158195 } ,
159196 TestCase {
160- resp_body : "case_4" ,
161- path : "/201" ,
162- resp_status : 500 ,
197+ path: "/500" . to_owned( ) ,
198+ resp: Ok ( TestResponse {
199+ status: 500 ,
200+ body: "case_4" ,
201+ ..Default :: default ( )
202+ } ) ,
203+ ..Default :: default ( )
204+ } ,
205+ TestCase {
206+ path: "/headers_in" . to_owned( ) ,
207+ requ_headers: vec![
208+ ( "foo" . to_owned( ) , & [ "bar" ] ) ,
209+ ( "multi" . to_owned( ) , & [ "some" , "thing" ] ) ,
210+ ] ,
211+ resp: Ok ( TestResponse {
212+ body: "case_5" ,
213+ ..Default :: default ( )
214+ } ) ,
215+ ..Default :: default ( )
216+ } ,
217+ TestCase {
218+ path: "/headers_out" . to_owned( ) ,
219+ resp: Ok ( TestResponse {
220+ headers: vec![
221+ ( "foo" . to_owned( ) , & [ "bar" ] ) ,
222+ ( "multi" . to_owned( ) , & [ "some" , "thing" ] ) ,
223+ ] ,
224+ body: "case_6" ,
225+ ..Default :: default ( )
226+ } ) ,
227+ ..Default :: default ( )
228+ } ,
229+ TestCase {
230+ base: Some ( "http://test.com" ) ,
231+ resp: Err ( "HTTPConnectionPool(host='test.com', port=80): Max retries exceeded with url: / (Caused by ProtocolError('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_HttpRequestDenied')))" . to_owned( ) ) ,
163232 ..Default :: default ( )
164233 } ,
165234 ] ;
235+ cases. extend ( DEFAULT_FORBIDDEN_HEADERS . iter ( ) . map ( |h| TestCase {
236+ path : format ! ( "/forbidden_header/{h}" ) ,
237+ requ_headers : vec ! [ ( h. to_string( ) , & [ "foo" ] ) ] ,
238+ resp : Err ( "Err { value: HeaderError_Forbidden }" . to_owned ( ) ) ,
239+ ..Default :: default ( )
240+ } ) ) ;
166241
167242 let server = MockServer :: start ( ) . await ;
168243 let mut permissions = AllowCertainHttpRequests :: default ( ) ;
169244
170245 let mut builder_method = StringBuilder :: new ( ) ;
171246 let mut builder_url = StringBuilder :: new ( ) ;
247+ let mut builder_headers = StringBuilder :: new ( ) ;
172248 let mut builder_result = StringBuilder :: new ( ) ;
173249
174250 for case in & cases {
175- case. mock ( ) . mount ( & server) . await ;
251+ if let Some ( mock) = case. mock ( & server) {
252+ mock. mount ( & server) . await ;
253+ }
176254 permissions. allow ( case. matcher ( & server) ) ;
177255
178256 let TestCase {
257+ base,
179258 method,
180259 path,
181- resp_body ,
182- resp_status ,
260+ requ_headers ,
261+ resp ,
183262 } = case;
184263
185264 builder_method. append_value ( method) ;
186- builder_url. append_value ( format ! ( "{}{}" , server. uri( ) , path) ) ;
187- builder_result. append_value ( format ! ( "status={resp_status} body='{resp_body}'" ) ) ;
265+ builder_url. append_value ( format ! (
266+ "{}{}" ,
267+ base. map( |b| b. to_owned( ) ) . unwrap_or_else( || server. uri( ) ) ,
268+ path
269+ ) ) ;
270+ builder_headers. append_option ( headers_to_string ( requ_headers) ) ;
271+
272+ match resp {
273+ Ok ( TestResponse {
274+ status,
275+ headers,
276+ body,
277+ } ) => {
278+ let resp_headers = headers_to_string ( headers) . unwrap_or_else ( || "n/a" . to_owned ( ) ) ;
279+ builder_result. append_value ( format ! (
280+ "OK: status={status} headers={resp_headers} body='{body}'"
281+ ) ) ;
282+ }
283+ Err ( e) => {
284+ builder_result. append_value ( format ! ( "ERR: {e}" ) ) ;
285+ }
286+ }
188287 }
189288
190289 let udfs = WasmScalarUdf :: new (
@@ -202,10 +301,12 @@ def perform_request(method: str, url: str) -> str:
202301 args : vec ! [
203302 ColumnarValue :: Array ( Arc :: new( builder_method. finish( ) ) ) ,
204303 ColumnarValue :: Array ( Arc :: new( builder_url. finish( ) ) ) ,
304+ ColumnarValue :: Array ( Arc :: new( builder_headers. finish( ) ) ) ,
205305 ] ,
206306 arg_fields : vec ! [
207307 Arc :: new( Field :: new( "method" , DataType :: Utf8 , true ) ) ,
208308 Arc :: new( Field :: new( "url" , DataType :: Utf8 , true ) ) ,
309+ Arc :: new( Field :: new( "headers" , DataType :: Utf8 , true ) ) ,
209310 ] ,
210311 number_rows : cases. len ( ) ,
211312 return_field : Arc :: new ( Field :: new ( "r" , DataType :: Utf8 , true ) ) ,
@@ -216,20 +317,39 @@ def perform_request(method: str, url: str) -> str:
216317 assert_eq ! ( array. as_ref( ) , & builder_result. finish( ) as & dyn Array , ) ;
217318}
218319
320+ #[ derive( Debug , Clone ) ]
321+ struct TestResponse {
322+ status : u16 ,
323+ headers : Vec < ( String , & ' static [ & ' static str ] ) > ,
324+ body : & ' static str ,
325+ }
326+
327+ impl Default for TestResponse {
328+ fn default ( ) -> Self {
329+ Self {
330+ status : 200 ,
331+ headers : vec ! [ ] ,
332+ body : "" ,
333+ }
334+ }
335+ }
336+
219337struct TestCase {
338+ base : Option < & ' static str > ,
220339 method : & ' static str ,
221- path : & ' static str ,
222- resp_body : & ' static str ,
223- resp_status : u16 ,
340+ path : String ,
341+ requ_headers : Vec < ( String , & ' static [ & ' static str ] ) > ,
342+ resp : Result < TestResponse , String > ,
224343}
225344
226345impl Default for TestCase {
227346 fn default ( ) -> Self {
228347 Self {
348+ base : None ,
229349 method : "GET" ,
230- path : "/" ,
231- resp_body : "" ,
232- resp_status : 200 ,
350+ path : "/" . to_owned ( ) ,
351+ requ_headers : vec ! [ ] ,
352+ resp : Ok ( TestResponse :: default ( ) ) ,
233353 }
234354 }
235355}
@@ -243,17 +363,83 @@ impl TestCase {
243363 }
244364 }
245365
246- fn mock ( & self ) -> Mock {
366+ fn mock ( & self , server : & MockServer ) -> Option < Mock > {
247367 let Self {
368+ base,
248369 method,
249370 path,
250- resp_body ,
251- resp_status ,
371+ requ_headers ,
372+ resp ,
252373 } = self ;
374+ if base. is_some ( ) {
375+ return None ;
376+ }
377+
378+ let TestResponse {
379+ status : resp_status,
380+ headers : resp_headers,
381+ body : resp_body,
382+ } = resp. clone ( ) . unwrap_or_default ( ) ;
383+
384+ let mut builder = Mock :: given ( matchers:: method ( method) )
385+ . and ( matchers:: path ( path. as_str ( ) ) )
386+ . and ( NoForbiddenHeaders :: new (
387+ server. address ( ) . ip ( ) . to_string ( ) ,
388+ server. address ( ) . port ( ) ,
389+ ) ) ;
390+
391+ for ( k, v) in requ_headers {
392+ builder = builder. and ( matchers:: headers ( k. as_str ( ) , v. to_vec ( ) ) ) ;
393+ }
394+
395+ let mock = builder
396+ . respond_with (
397+ ResponseTemplate :: new ( resp_status)
398+ . set_body_string ( resp_body)
399+ . append_headers ( resp_headers. iter ( ) . map ( |( k, v) | ( k, v. join ( "," ) ) ) ) ,
400+ )
401+ . expect ( resp. is_ok ( ) as u64 ) ;
402+ Some ( mock)
403+ }
404+ }
405+
406+ fn headers_to_string ( headers : & [ ( String , & [ & str ] ) ] ) -> Option < String > {
407+ if headers. is_empty ( ) {
408+ None
409+ } else {
410+ let headers = headers
411+ . iter ( )
412+ . map ( |( k, v) | format ! ( "{k}:{}" , v. join( "," ) ) )
413+ . collect :: < Vec < _ > > ( ) ;
414+ Some ( headers. join ( ";" ) )
415+ }
416+ }
417+
418+ struct NoForbiddenHeaders {
419+ host : String ,
420+ port : u16 ,
421+ }
422+
423+ impl NoForbiddenHeaders {
424+ fn new ( host : String , port : u16 ) -> Self {
425+ Self { host, port }
426+ }
427+ }
428+
429+ impl wiremock:: Match for NoForbiddenHeaders {
430+ fn matches ( & self , request : & wiremock:: Request ) -> bool {
431+ // "host" is part of the forbidden headers that the client is not supposed to use, but it is set by our own
432+ // host HTTP lib
433+ let Some ( host_val) = request. headers . get ( http:: header:: HOST ) else {
434+ return false ;
435+ } ;
436+ if host_val. to_str ( ) . expect ( "always a string" ) != format ! ( "{}:{}" , self . host, self . port) {
437+ return false ;
438+ }
253439
254- Mock :: given ( matchers :: method ( method ) )
255- . and ( matchers :: path ( * path ) )
256- . respond_with ( ResponseTemplate :: new ( * resp_status ) . set_body_string ( * resp_body ) )
257- . expect ( 1 )
440+ DEFAULT_FORBIDDEN_HEADERS
441+ . iter ( )
442+ . filter ( |h| * h != http :: header :: HOST )
443+ . all ( |h| !request . headers . contains_key ( h ) )
258444 }
259445}
0 commit comments