Skip to content

Commit 4e23a04

Browse files
Move executable and device to thunk from expr (EnzymeAD#855)
* Move executable and device to thunk from expr * fix err * fix err * fixup * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 9177412 commit 4e23a04

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

src/Compiler.jl

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,14 +1391,11 @@ Generate Julia code to call the XLA executable.
13911391
13921392
# Arguments
13931393
1394-
- `exec`: The XLA executable to call.
13951394
- `flatten_names`: A list of `Symbol`s representing the names of the flattened linear arguments.
13961395
- `donated_args_mask`: A list of `UInt8`s representing whether the argument is donated.
13971396
- `nresults`: The number of results to expect.
13981397
"""
13991398
function codegen_xla_call(
1400-
exec,
1401-
device,
14021399
flatten_names,
14031400
donated_args_mask,
14041401
nresults,
@@ -1420,7 +1417,7 @@ function codegen_xla_call(
14201417
quote
14211418
GC.@preserve $(flatten_names...) begin
14221419
linearized_results = XLA.execute(
1423-
$exec,
1420+
thunk.exec,
14241421
($(flatten_buffer_refs...),),
14251422
$(Tuple(donated_args_mask)),
14261423
Val($nresults),
@@ -1433,8 +1430,8 @@ function codegen_xla_call(
14331430
quote
14341431
GC.@preserve $(flatten_names...) begin
14351432
linearized_results = XLA.execute_sharded(
1436-
$exec,
1437-
$(device),
1433+
thunk.exec,
1434+
thunk.device,
14381435
($(flatten_buffer_refs...),),
14391436
$(Tuple(donated_args_mask)),
14401437
Val($nresults),
@@ -1600,8 +1597,6 @@ function compile(f, args; sync=false, kwargs...)
16001597
)
16011598

16021599
concretized_res_names, xla_call_code = codegen_xla_call(
1603-
exec,
1604-
device,
16051600
flatten_arg_names,
16061601
donated_args_mask,
16071602
length(linear_results),
@@ -1653,15 +1648,23 @@ function compile(f, args; sync=false, kwargs...)
16531648
end
16541649

16551650
return register_thunk(
1656-
fname, Tuple{map(Core.Typeof, args)...}, body, f, mlir_fn_res.fnwrapped
1651+
fname,
1652+
Tuple{map(Core.Typeof, args)...},
1653+
body,
1654+
f,
1655+
mlir_fn_res.fnwrapped,
1656+
exec,
1657+
mlir_fn_res.is_sharded ? nothing : device,
16571658
)
16581659
end
16591660

16601661
# inspired by RuntimeGeneratedFunction.jl
16611662
const __thunk_body_cache = Dict{Symbol,Expr}()
16621663

1663-
struct Thunk{FTy,tag,IsClosure,ArgTypes}
1664+
struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
16641665
f::FTy
1666+
exec::ExecTy
1667+
device::DeviceTy
16651668
end
16661669

16671670
struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
@@ -1687,14 +1690,16 @@ function Base.showerror(
16871690
)
16881691
end
16891692

1690-
@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure})(
1693+
@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy})(
16911694
args...
1692-
) where {FTy,tag,ArgTypes,IsClosure}
1695+
) where {FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy}
16931696
FoundTypes = Tuple{args...}
16941697
if ArgTypes != FoundTypes
16951698
return quote
16961699
throw(
1697-
$(MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}())
1700+
$(MisMatchedThunkTypeError{
1701+
Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy},FoundTypes
1702+
}()),
16981703
)
16991704
end
17001705
end
@@ -1710,10 +1715,18 @@ end
17101715
end
17111716

17121717
function register_thunk(
1713-
tag::Symbol, @nospecialize(argtys::Type), body::Expr, @nospecialize(f), isclosure::Bool
1718+
tag::Symbol,
1719+
@nospecialize(argtys::Type),
1720+
body::Expr,
1721+
@nospecialize(f),
1722+
isclosure::Bool,
1723+
exec,
1724+
device,
17141725
)
17151726
__thunk_body_cache[tag] = body
1716-
return Thunk{Core.Typeof(f),tag,argtys,isclosure}(f)
1727+
return Thunk{Core.Typeof(f),tag,argtys,isclosure,Core.Typeof(exec),Core.Typeof(device)}(
1728+
f, exec, device
1729+
)
17171730
end
17181731

17191732
for cache_type in (:callcache, :sdycache)

0 commit comments

Comments
 (0)