diff --git a/ext/EnzymeExt.jl b/ext/EnzymeExt.jl index da316f12d..40721d144 100644 --- a/ext/EnzymeExt.jl +++ b/ext/EnzymeExt.jl @@ -100,7 +100,7 @@ _augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) = function _create_tape_kernel( kernel::Kernel{CPU}, - ModifiedBetween, + Mode, FT, ctxTy, ndrange, @@ -108,7 +108,7 @@ function _create_tape_kernel( args2..., ) TapeType = EnzymeCore.tape_type( - ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), + Mode, FT, Const{Nothing}, Const{ctxTy}, @@ -121,7 +121,7 @@ end function _create_tape_kernel( kernel::Kernel{<:GPU}, - ModifiedBetween, + Mode, FT, ctxTy, ndrange, @@ -139,7 +139,7 @@ function _create_tape_kernel( EnzymeCore.compiler_job_from_backend(backend(kernel), typeof(() -> return), Tuple{}) TapeType = EnzymeCore.tape_type( job, - ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), + Mode, FT, Const{Nothing}, Const{ctxTy}, @@ -159,14 +159,14 @@ _create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev) function cpu_aug_fwd( ctx, f::FT, - ::Val{ModifiedBetween}, + mode::Mode, subtape, ::Val{TapeType}, args..., - ) where {ModifiedBetween, FT, TapeType} + ) where {Mode, FT, TapeType} # A2 = Const{Nothing} -- since f->Nothing forward, _ = EnzymeCore.autodiff_thunk( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + mode, Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, @@ -183,13 +183,13 @@ end function cpu_rev( ctx, f::FT, - ::Val{ModifiedBetween}, + mode::Mode, subtape, ::Val{TapeType}, args..., - ) where {ModifiedBetween, FT, TapeType} + ) where {Mode, FT, TapeType} _, reverse = EnzymeCore.autodiff_thunk( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + mode, Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, @@ -205,14 +205,14 @@ end function gpu_aug_fwd( ctx, f::FT, - ::Val{ModifiedBetween}, + mode::Mode, subtape, ::Val{TapeType}, args..., - ) where {ModifiedBetween, FT, TapeType} + ) where {Mode, FT, TapeType} # A2 = Const{Nothing} -- since f->Nothing forward, _ = EnzymeCore.autodiff_deferred_thunk( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + mode, TapeType, Const{Core.Typeof(f)}, Const{Nothing}, @@ -232,14 +232,14 @@ end function gpu_rev( ctx, f::FT, - ::Val{ModifiedBetween}, + mode::Mode, subtape, ::Val{TapeType}, args..., - ) where {ModifiedBetween, FT, TapeType} + ) where {Mode, FT, TapeType} # XXX: TapeType and A2 as args to autodiff_deferred_thunk _, reverse = EnzymeCore.autodiff_deferred_thunk( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + mode, TapeType, Const{Core.Typeof(f)}, Const{Nothing}, @@ -294,17 +294,17 @@ function EnzymeRules.augmented_primal( args[i] end end - + Mode = EnzymeCore.set_runtime_activity(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), config) TapeType, subtape, aug_kernel = _create_tape_kernel( kernel, - ModifiedBetween, + Mode, FT, ctxTy, ndrange, iterspace, args2..., ) - aug_kernel(f, ModifiedBetween, subtape, Val(TapeType), args2...; ndrange, workgroupsize) + aug_kernel(f, Mode, subtape, Val(TapeType), args2...; ndrange, workgroupsize) # TODO the fact that ctxTy is type unstable means this is all type unstable. # Since custom rules require a fixed return type, explicitly cast to Any, rather @@ -336,11 +336,11 @@ function EnzymeRules.reverse( f = kernel.f ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) - + Mode = EnzymeCore.set_runtime_activity(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), config) rev_kernel = _create_rev_kernel(kernel) rev_kernel( f, - ModifiedBetween, + Mode, subtape, Val(tape_type), args2...;