Skip to content

Commit bd60907

Browse files
authored
Fix reverse mode closure issues (#1533)
* Fix custom reverse on closure * fix closure
1 parent 6c2b0d9 commit bd60907

File tree

2 files changed

+64
-9
lines changed

2 files changed

+64
-9
lines changed

src/rules/customrules.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,8 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
687687
end
688688
end
689689
end
690-
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
690+
691+
# push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
691692

692693
needsTape = !isghostty(TapeT) && !Core.Compiler.isconstType(TapeT)
693694

@@ -711,22 +712,37 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
711712

712713
swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(llvmf, i)))) for i in 1:length(collect(parameters(llvmf))))
713714

714-
_, sret, returnRoots = get_return_info(enzyme_custom_extract_mi(llvmf)[2])
715+
miRT = enzyme_custom_extract_mi(llvmf)[2]
716+
_, sret, returnRoots = get_return_info(miRT)
715717

716718
if !forward
719+
funcTy = rev_TT.parameters[isKWCall ? 4 : 2]
717720
if needsTape
718721
@assert tape != C_NULL
719-
tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4]))
720-
innerTy = value_type(parameters(llvmf)[tape_idx+(sret !== nothing)+(RT <: Active)])
722+
tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) + !isghostty(funcTy)
723+
trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself+(RT <: Active)
724+
innerTy = value_type(parameters(llvmf)[trueidx])
721725
if innerTy != value_type(tape)
722-
if isabstracttype(TapeT)
726+
if isabstracttype(TapeT) || TapeT == Tuple || TapeT.layout == C_NULL
723727
msg = sprint() do io
724728
println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))")
725729
println(io, "tape_idx=", tape_idx)
730+
println(io, "true_idx=", trueidx)
731+
println(io, "isKWCall=", isKWCall)
732+
println(io, "kwtup=", kwtup)
733+
println(io, "funcTy=", funcTy)
734+
println(io, "isghostty(funcTy)=", isghostty(funcTy))
735+
println(io, "miRT=", miRT)
726736
println(io, "sret=", sret)
737+
println(io, "returnRoots=", returnRoots)
738+
println(io, "swiftself=", swiftself)
727739
println(io, "RT=", RT)
728740
println(io, "tape=", tape)
729-
println(io, "llvmf=", string(llvmf))
741+
println(io, "llvmf=", string(LLVM.function_type(llvmf)))
742+
println(io, "TapeT=", TapeT)
743+
println(io, "mi=", mi)
744+
println(io, "ami=", ami)
745+
println(io, "rev_TT =", rev_TT)
730746
end
731747
throw(AssertionError(msg))
732748
end
@@ -749,7 +765,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
749765
val = LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B))
750766
else
751767
llety = convert(LLVMType, eltype(RT))
752-
ptr_val = invert_pointer(gutils, operands(orig)[1], B)
768+
ptr_val = invert_pointer(gutils, operands(orig)[1 + !isghostty(funcTy)], B)
753769
val = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llety)))
754770
for idx in 1:width
755771
ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1)
@@ -769,8 +785,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
769785
if any_jltypes(llty)
770786
emit_writebarrier!(B, get_julia_inner_types(B, al0, val))
771787
end
772-
773-
insert!(args, 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al)
788+
insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al)
774789
end
775790
end
776791

test/rrules.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,4 +305,44 @@ end
305305
@test dU[1] 7 * ( 3.0 + 4.0im )
306306
end
307307
end
308+
309+
310+
struct Closure
311+
v::Vector{Float64}
312+
end
313+
314+
function (cl::Closure)(x)
315+
val = cl.v[1] * x
316+
cl.v[1] = 0.0
317+
return val
318+
end
319+
320+
321+
function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure},
322+
::Type{<:Active}, args::Vararg{Active,N}) where {N}
323+
vec = copy(func.val.v)
324+
pval = func.val(args[1].val)
325+
primal = if EnzymeRules.needs_primal(config)
326+
pval
327+
else
328+
nothing
329+
end
330+
return AugmentedReturn(primal, nothing, vec)
331+
end
332+
333+
function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure},
334+
dret::Active, tape, args::Vararg{Active,N}) where {N}
335+
dargs = ntuple(Val(N)) do i
336+
7 * args[1].val * dret.val + tape[1] * 1000
337+
end
338+
return dargs
339+
end
340+
341+
@testset "Closure rule" begin
342+
cl = Closure([3.14])
343+
res = autodiff(Reverse, cl, Active, Active(2.7))[1][1]
344+
@test res 7 * 2.7 + 3.14 * 1000
345+
@test cl.v[1] 0.0
346+
end
347+
308348
end # ReverseRules

0 commit comments

Comments
 (0)