@@ -11,18 +11,20 @@ module EnzymeExt
1111 EnzymeRules. inactive (:: Type{StaticSize} , x... ) = nothing
1212
1313 function fwd (ctx, f, args... )
14- EnzymeCore. autodiff_deferred (Forward, Const (f), Const, Const (ctx), args... )
14+ EnzymeCore. autodiff_deferred (Forward, Const (f), Const{Nothing} , Const (ctx), args... )
1515 return nothing
1616 end
1717
1818 function aug_fwd (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
19- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
19+ TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
20+ forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
2021 subtape[__groupindex (ctx)] = forward (Const (f), Const (ctx), args... )[1 ]
2122 return nothing
2223 end
2324
2425 function rev (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
25- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
26+ TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
27+ forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
2628 tp = subtape[__groupindex (ctx)]
2729 reverse (Const (f), Const (ctx), args... , tp)
2830 return nothing
@@ -92,7 +94,7 @@ module EnzymeExt
9294 end
9395
9496 # TODO in KA backends like CUDAKernels, etc have a version with a parent job type
95- TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map (Core. Typeof, args2)... )
97+ TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween), FT, Const{Nothing} , Const{ctxTy}, map (Core. Typeof, args2)... )
9698
9799
98100 subtape = Array {TapeType} (undef, size (blocks (iterspace)))
0 commit comments