1- use std:: sync:: Arc ;
1+ use std:: { sync:: Arc , time :: Duration } ;
22
33use arrow:: {
44 array:: { Array , StringArray , StringBuilder } ,
@@ -10,6 +10,7 @@ use datafusion_udf_wasm_host::{
1010 WasmPermissions , WasmScalarUdf ,
1111 http:: { AllowCertainHttpRequests , HttpRequestValidator , Matcher } ,
1212} ;
13+ use tokio:: runtime:: Handle ;
1314use wasmtime_wasi_http:: types:: DEFAULT_FORBIDDEN_HEADERS ;
1415use wiremock:: { Mock , MockServer , ResponseTemplate , matchers} ;
1516
@@ -568,6 +569,7 @@ where
568569 WasmScalarUdf :: new (
569570 python_component ( ) . await ,
570571 & WasmPermissions :: new ( ) . with_http ( permissions) ,
572+ Handle :: current ( ) ,
571573 code. to_owned ( ) ,
572574 )
573575 . await
@@ -582,3 +584,79 @@ where
582584 assert_eq ! ( udfs. len( ) , 1 ) ;
583585 udfs. into_iter ( ) . next ( ) . unwrap ( )
584586}
587+
588+ #[ test]
589+ fn test_io_runtime ( ) {
590+ const CODE : & str = r#"
591+ import urllib3
592+
593+ def perform_request(url: str) -> str:
594+ resp = urllib3.request("GET", url)
595+ return resp.data.decode("utf-8")
596+ "# ;
597+
598+ let rt_tmp = tokio:: runtime:: Builder :: new_current_thread ( )
599+ . build ( )
600+ . unwrap ( ) ;
601+ let rt_cpu = tokio:: runtime:: Builder :: new_multi_thread ( )
602+ . worker_threads ( 1 )
603+ // It would be nice if all the timeouts-related timers would also run within the within the I/O runtime, but
604+ // that requires some larger intervention (either upstream or with a custom WASI HTTP implementation).
605+ // Hence, we don't do that yet.
606+ . enable_time ( )
607+ . build ( )
608+ . unwrap ( ) ;
609+ let rt_io = tokio:: runtime:: Builder :: new_multi_thread ( )
610+ . worker_threads ( 1 )
611+ . enable_all ( )
612+ . build ( )
613+ . unwrap ( ) ;
614+
615+ let server = rt_io. block_on ( async {
616+ let server = MockServer :: start ( ) . await ;
617+ Mock :: given ( matchers:: any ( ) )
618+ . respond_with ( ResponseTemplate :: new ( 200 ) . set_body_string ( "hello world!" ) )
619+ . expect ( 1 )
620+ . mount ( & server)
621+ . await ;
622+ server
623+ } ) ;
624+
625+ // deliberately use a runtime what we are going to throw away later to prevent tricks like `Handle::current`
626+ let udf = rt_tmp. block_on ( async {
627+ let mut permissions = AllowCertainHttpRequests :: new ( ) ;
628+ permissions. allow ( Matcher {
629+ method : http:: Method :: GET ,
630+ host : server. address ( ) . ip ( ) . to_string ( ) . into ( ) ,
631+ port : server. address ( ) . port ( ) ,
632+ } ) ;
633+
634+ let udfs = WasmScalarUdf :: new (
635+ python_component ( ) . await ,
636+ & WasmPermissions :: new ( ) . with_http ( permissions) ,
637+ rt_io. handle ( ) . clone ( ) ,
638+ CODE . to_owned ( ) ,
639+ )
640+ . await
641+ . unwrap ( ) ;
642+ assert_eq ! ( udfs. len( ) , 1 ) ;
643+ udfs. into_iter ( ) . next ( ) . unwrap ( )
644+ } ) ;
645+ rt_tmp. shutdown_timeout ( Duration :: from_secs ( 1 ) ) ;
646+
647+ let array = rt_cpu. block_on ( async {
648+ udf. invoke_with_args ( ScalarFunctionArgs {
649+ args : vec ! [ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( server. uri( ) ) ) ) ] ,
650+ arg_fields : vec ! [ Arc :: new( Field :: new( "uri" , DataType :: Utf8 , true ) ) ] ,
651+ number_rows : 1 ,
652+ return_field : Arc :: new ( Field :: new ( "r" , DataType :: Utf8 , true ) ) ,
653+ } )
654+ . unwrap ( )
655+ . unwrap_array ( )
656+ } ) ;
657+
658+ assert_eq ! (
659+ array. as_ref( ) ,
660+ & StringArray :: from_iter( [ Some ( "hello world!" . to_owned( ) ) , ] ) as & dyn Array ,
661+ ) ;
662+ }
0 commit comments