Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ _augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) =

function _create_tape_kernel(
kernel::Kernel{CPU},
ModifiedBetween,
Mode,
FT,
ctxTy,
ndrange,
iterspace,
args2...,
)
TapeType = EnzymeCore.tape_type(
ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
Mode,
FT,
Const{Nothing},
Const{ctxTy},
Expand All @@ -121,7 +121,7 @@ end

function _create_tape_kernel(
kernel::Kernel{<:GPU},
ModifiedBetween,
Mode,
FT,
ctxTy,
ndrange,
Expand All @@ -139,7 +139,7 @@ function _create_tape_kernel(
EnzymeCore.compiler_job_from_backend(backend(kernel), typeof(() -> return), Tuple{})
TapeType = EnzymeCore.tape_type(
job,
ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
Mode,
FT,
Const{Nothing},
Const{ctxTy},
Expand All @@ -159,14 +159,14 @@ _create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)
function cpu_aug_fwd(
ctx,
f::FT,
::Val{ModifiedBetween},
mode::Mode,
subtape,
::Val{TapeType},
args...,
) where {ModifiedBetween, FT, TapeType}
) where {Mode, FT, TapeType}
# A2 = Const{Nothing} -- since f->Nothing
forward, _ = EnzymeCore.autodiff_thunk(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
mode,
Const{Core.Typeof(f)},
Const{Nothing},
Const{Core.Typeof(ctx)},
Expand All @@ -183,13 +183,13 @@ end
function cpu_rev(
ctx,
f::FT,
::Val{ModifiedBetween},
mode::Mode,
subtape,
::Val{TapeType},
args...,
) where {ModifiedBetween, FT, TapeType}
) where {Mode, FT, TapeType}
_, reverse = EnzymeCore.autodiff_thunk(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
mode,
Const{Core.Typeof(f)},
Const{Nothing},
Const{Core.Typeof(ctx)},
Expand All @@ -205,14 +205,14 @@ end
function gpu_aug_fwd(
ctx,
f::FT,
::Val{ModifiedBetween},
mode::Mode,
subtape,
::Val{TapeType},
args...,
) where {ModifiedBetween, FT, TapeType}
) where {Mode, FT, TapeType}
# A2 = Const{Nothing} -- since f->Nothing
forward, _ = EnzymeCore.autodiff_deferred_thunk(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
mode,
TapeType,
Const{Core.Typeof(f)},
Const{Nothing},
Expand All @@ -232,14 +232,14 @@ end
function gpu_rev(
ctx,
f::FT,
::Val{ModifiedBetween},
mode::Mode,
subtape,
::Val{TapeType},
args...,
) where {ModifiedBetween, FT, TapeType}
) where {Mode, FT, TapeType}
# XXX: TapeType and A2 as args to autodiff_deferred_thunk
_, reverse = EnzymeCore.autodiff_deferred_thunk(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
mode,
TapeType,
Const{Core.Typeof(f)},
Const{Nothing},
Expand Down Expand Up @@ -294,17 +294,17 @@ function EnzymeRules.augmented_primal(
args[i]
end
end

Mode = EnzymeCore.set_runtime_activity(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), config)
TapeType, subtape, aug_kernel = _create_tape_kernel(
kernel,
ModifiedBetween,
Mode,
FT,
ctxTy,
ndrange,
iterspace,
args2...,
)
aug_kernel(f, ModifiedBetween, subtape, Val(TapeType), args2...; ndrange, workgroupsize)
aug_kernel(f, Mode, subtape, Val(TapeType), args2...; ndrange, workgroupsize)

# TODO the fact that ctxTy is type unstable means this is all type unstable.
# Since custom rules require a fixed return type, explicitly cast to Any, rather
Expand Down Expand Up @@ -336,11 +336,11 @@ function EnzymeRules.reverse(
f = kernel.f

ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

Mode = EnzymeCore.set_runtime_activity(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), config)
rev_kernel = _create_rev_kernel(kernel)
rev_kernel(
f,
ModifiedBetween,
Mode,
subtape,
Val(tape_type),
args2...;
Expand Down
Loading