@@ -1391,14 +1391,11 @@ Generate Julia code to call the XLA executable.
13911391
13921392# Arguments
13931393
1394- - `exec`: The XLA executable to call.
13951394- `flatten_names`: A list of `Symbol`s representing the names of the flattened linear arguments.
13961395- `donated_args_mask`: A list of `UInt8`s representing whether the argument is donated.
13971396- `nresults`: The number of results to expect.
13981397"""
13991398function codegen_xla_call (
1400- exec,
1401- device,
14021399 flatten_names,
14031400 donated_args_mask,
14041401 nresults,
@@ -1420,7 +1417,7 @@ function codegen_xla_call(
14201417 quote
14211418 GC. @preserve $ (flatten_names... ) begin
14221419 linearized_results = XLA. execute (
1423- $ exec,
1420+ thunk . exec,
14241421 ($ (flatten_buffer_refs... ),),
14251422 $ (Tuple (donated_args_mask)),
14261423 Val ($ nresults),
@@ -1433,8 +1430,8 @@ function codegen_xla_call(
14331430 quote
14341431 GC. @preserve $ (flatten_names... ) begin
14351432 linearized_results = XLA. execute_sharded (
1436- $ exec,
1437- $ ( device) ,
1433+ thunk . exec,
1434+ thunk . device,
14381435 ($ (flatten_buffer_refs... ),),
14391436 $ (Tuple (donated_args_mask)),
14401437 Val ($ nresults),
@@ -1600,8 +1597,6 @@ function compile(f, args; sync=false, kwargs...)
16001597 )
16011598
16021599 concretized_res_names, xla_call_code = codegen_xla_call (
1603- exec,
1604- device,
16051600 flatten_arg_names,
16061601 donated_args_mask,
16071602 length (linear_results),
@@ -1653,15 +1648,23 @@ function compile(f, args; sync=false, kwargs...)
16531648 end
16541649
16551650 return register_thunk (
1656- fname, Tuple{map (Core. Typeof, args)... }, body, f, mlir_fn_res. fnwrapped
1651+ fname,
1652+ Tuple{map (Core. Typeof, args)... },
1653+ body,
1654+ f,
1655+ mlir_fn_res. fnwrapped,
1656+ exec,
1657+ mlir_fn_res. is_sharded ? nothing : device,
16571658 )
16581659end
16591660
# Inspired by RuntimeGeneratedFunctions.jl: the expression body of every registered thunk
# is kept here, keyed by the thunk's `tag` symbol. `register_thunk` writes the entries.
const __thunk_body_cache = Dict{Symbol,Expr}()
16621663
"""
    Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy}

Callable wrapper around a compiled function.

Type parameters are positional, so their names here follow the order in which
`register_thunk` fills them and in which the `@generated` call method binds them:

- `FTy`: type of the wrapped function `f`.
- `tag`: `Symbol` under which the thunk body is stored in `__thunk_body_cache`.
- `ArgTypes`: tuple type of the arguments the thunk was registered for.
- `IsClosure`: `Bool` flag supplied at registration.
- `ExecTy`, `DeviceTy`: types of the `exec` and `device` fields.

# Fields

- `f`: the wrapped function.
- `exec`: the executable that the generated call code reads as `thunk.exec`.
- `device`: the device that the generated call code reads as `thunk.device`; may be
  `nothing` when no single device applies.
"""
struct Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy}
    f::FTy
    exec::ExecTy
    device::DeviceTy
end
16661669
"""
    MisMatchedThunkTypeError{ThunkTy,FoundTypes}

Exception thrown when a thunk of type `ThunkTy` is called with arguments whose tuple type
`FoundTypes` differs from the argument types the thunk was registered for. It has no
fields; all information is carried in the type parameters.
"""
struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
@@ -1687,14 +1690,16 @@ function Base.showerror(
16871690 )
16881691end
16891692
1690- @generated function (thunk:: Thunk{FTy,tag,ArgTypes,IsClosure} )(
1693+ @generated function (thunk:: Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy } )(
16911694 args...
1692- ) where {FTy,tag,ArgTypes,IsClosure}
1695+ ) where {FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy }
16931696 FoundTypes = Tuple{args... }
16941697 if ArgTypes != FoundTypes
16951698 return quote
16961699 throw (
1697- $ (MisMatchedThunkTypeError {Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes} ())
1700+ $ (MisMatchedThunkTypeError{
1701+ Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy},FoundTypes
1702+ }()),
16981703 )
16991704 end
17001705 end
@@ -1710,10 +1715,18 @@ end
17101715end
17111716
"""
    register_thunk(tag, argtys, body, f, isclosure, exec, device)

Store `body` in `__thunk_body_cache` under `tag` and return a `Thunk` wrapping `f`.

# Arguments

- `tag`: `Symbol` key for the cached body; also becomes a type parameter of the thunk.
  An existing entry under the same tag is overwritten.
- `argtys`: tuple type of the arguments the thunk is registered for.
- `body`: expression stored in the cache.
- `f`: the function to wrap.
- `isclosure`: `Bool` flag recorded as a type parameter of the thunk.
- `exec`: stored in the thunk's `exec` field.
- `device`: stored in the thunk's `device` field (may be `nothing`).

# Returns

A `Thunk` whose type parameters are, in order, `Core.Typeof(f)`, `tag`, `argtys`,
`isclosure`, `Core.Typeof(exec)` and `Core.Typeof(device)`.
"""
function register_thunk(
    tag::Symbol,
    @nospecialize(argtys::Type),
    body::Expr,
    @nospecialize(f),
    isclosure::Bool,
    exec,
    device,
)
    __thunk_body_cache[tag] = body
    ThunkTy = Thunk{
        Core.Typeof(f),tag,argtys,isclosure,Core.Typeof(exec),Core.Typeof(device)
    }
    return ThunkTy(f, exec, device)
end
17181731
17191732for cache_type in (:callcache , :sdycache )
0 commit comments