Skip to content

Commit 86da3cd

Browse files
authored
Reverse mode apply iterate (#1485)
* Reverse mode apply iterate * fixed * fixup * cleanup * debugging fixes * fixup * cleanup * fix tests * fix batch getfield rev * fix tests * more test fix * fix tuple fast path * fix * Update Project.toml * fix sym index rev * fix test * fixup * Fix unionall * fix * fix sym offset * ix constantarray * Update Project.toml
1 parent df9bff9 commit 86da3cd

File tree

9 files changed

+1145
-488
lines changed

9 files changed

+1145
-488
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Enzyme"
22
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
4-
version = "0.12.11"
4+
version = "0.12.12"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2020
CEnum = "0.4, 0.5"
2121
ChainRulesCore = "1"
2222
EnzymeCore = "0.7.4"
23-
Enzyme_jll = "0.0.119"
23+
Enzyme_jll = "0.0.121"
2424
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
2525
LLVM = "6.1, 7"
2626
ObjectFile = "0.4"

src/Enzyme.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ end
7474
end)...}
7575
end
7676

77+
@inline function vaEltypes(args::Type{Ty}) where {Ty <: Tuple}
78+
return Tuple{(ntuple(Val(length(Ty.parameters))) do i
79+
Base.@_inline_meta
80+
eltype(Ty.parameters[i])
81+
end)...}
82+
end
83+
7784
@inline function same_or_one_helper(current, next)
7885
if current == -1
7986
return next

src/compiler.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,6 @@ end
380380
end
381381

382382
@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret}
383-
384383
if T === Any
385384
return DupState
386385
end
@@ -422,7 +421,9 @@ end
422421
else
423422
inmi = GPUCompiler.methodinstance(typeof(EnzymeCore.EnzymeRules.inactive_type), Tuple{Type{T}}, world)
424423
args = Any[EnzymeCore.EnzymeRules.inactive_type, T];
425-
ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi)
424+
GC.@preserve T begin
425+
ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi)
426+
end
426427
end
427428

428429
if inactivety
@@ -480,11 +481,13 @@ end
480481
@static if VERSION < v"1.7.0"
481482
nT = T
482483
else
483-
nT = if is_concrete_tuple(T)
484+
nT = if T <: Tuple && T != Tuple && !(T isa UnionAll)
484485
Tuple{(ntuple(length(T.parameters)) do i
485486
Base.@_inline_meta
486487
sT = T.parameters[i]
487-
if sT isa Core.TypeofVararg
488+
if sT isa TypeVar
489+
Any
490+
elseif sT isa Core.TypeofVararg
488491
Any
489492
else
490493
sT

src/compiler/validation.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -743,19 +743,18 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width
743743
end
744744
end
745745

746-
seen = Dict{LLVM.Value,Tuple}()
746+
seen = Set{Tuple{LLVM.Value,Tuple}}()
747747
while length(todo) != 0
748748
cur, off = pop!(todo)
749749

750750
while isa(cur, LLVM.AddrSpaceCastInst) # || isa(cur, LLVM.BitCastInst)
751751
cur = operands(cur)[1]
752752
end
753753

754-
if cur in keys(seen)
755-
@assert seen[cur] == off
754+
if cur in seen
756755
continue
757756
end
758-
seen[cur] = off
757+
push!(seen, (cur, off))
759758

760759
if isa(cur, LLVM.PHIInst)
761760
for (v, _) in LLVM.incoming(cur)
@@ -781,7 +780,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width
781780

782781
# if inserting at the current desired offset, we have found the value we need
783782
if ind == off[1]
784-
push!(todo, (operands(cur)[2], -1))
783+
push!(todo, (operands(cur)[2], off[2:end]))
785784
# otherwise it must be inserted at a different point
786785
else
787786
push!(todo, (operands(cur)[1], off))
@@ -880,10 +879,15 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width
880879
end
881880
end
882881

882+
if isa(cur, LLVM.ConstantArray)
883+
push!(todo, (cur[off[1]], off[2:end]))
884+
continue
885+
end
886+
883887
msg = sprint() do io::IO
884888
println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])")
885889
println(io, string(enzymefn))
886-
println(io, "cur=", cur)
890+
println(io, "cur=", string(cur))
887891
println(io, "off=", off)
888892
end
889893
throw(AssertionError(msg))

0 commit comments

Comments
 (0)