Skip to content

Commit eb79121

Browse files
authored
Fix kw rrule closure argument index (#1543)
1 parent e980433 commit eb79121

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

src/rules/customrules.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
734734
end
735735
end
736736

737-
# push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
737+
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
738738

739739
needsTape = !isghostty(TapeT) && !Core.Compiler.isconstType(TapeT)
740740

@@ -765,11 +765,11 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
765765
funcTy = rev_TT.parameters[isKWCall ? 4 : 2]
766766
if needsTape
767767
@assert tape != C_NULL
768-
tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) + !isghostty(funcTy)
769-
trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself+(RT <: Active)
768+
tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup)) + !isghostty(funcTy)
769+
trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself + (RT <: Active)
770770
innerTy = value_type(parameters(llvmf)[trueidx])
771771
if innerTy != value_type(tape)
772-
if isabstracttype(TapeT) || TapeT == Tuple || TapeT.layout == C_NULL
772+
if isabstracttype(TapeT) || TapeT == Tuple || TapeT.layout == C_NULL || TapeT == Array
773773
msg = sprint() do io
774774
println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))")
775775
println(io, "tape_idx=", tape_idx)
@@ -831,7 +831,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
831831
if any_jltypes(llty)
832832
emit_writebarrier!(B, get_julia_inner_types(B, al0, val))
833833
end
834-
insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al)
834+
insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup)), al)
835835
end
836836
end
837837

test/kwrrules.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,47 @@ g4(x, y) = f_kw4(x; y)
111111
@test autodiff(Reverse, g4, Active(2.0), Const(42.0))[1][1] 42004.0
112112
@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, g4, Active(2.0), Active(42.0))[1]
113113

114+
struct Closure2
115+
v::Vector{Float64}
116+
str::String
117+
end
118+
119+
function (cl::Closure2)(x; width=7)
120+
val = cl.v[1] * x * width
121+
cl.v[1] = 0.0
122+
return val
123+
end
124+
125+
function wrapclos(cl, x)
126+
cl(x; width=9)
127+
end
128+
129+
function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure2},
130+
::Type{<:Active}, args::Vararg{Active,N}; width=7) where {N}
131+
vec = copy(func.val.v)
132+
pval = func.val(args[1].val)
133+
primal = if EnzymeRules.needs_primal(config)
134+
pval
135+
else
136+
nothing
137+
end
138+
return AugmentedReturn(primal, nothing, vec)
139+
end
140+
141+
function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure2},
142+
dret::Active, tape, args::Vararg{Active,N}; width=7) where {N}
143+
dargs = ntuple(Val(N)) do i
144+
7 * args[1].val * dret.val + tape[1] * 1000 + width * 100000
145+
end
146+
return dargs
147+
end
148+
149+
@testset "KWClosure rule" begin
150+
cl = Closure2([3.14], "3.14")
151+
res = autodiff(Reverse, wrapclos, Active, Const(cl), Active(2.7))[1][2]
152+
@test res 7 * 2.7 + 3.14 * 1000 + 9 * 100000
153+
@test cl.v[1] 0.0
154+
end
155+
114156
end # KWReverseRules
115157

0 commit comments

Comments
 (0)