Skip to content

Commit ffcc20c

Browse files
authored
Fix const-only apply iterate (#1526)
* Fix const-only apply iterate * fix ct * Fix mixed activity for type unstable * Update jitrules.jl * Update jitrules.jl * wip tuple * fix batch tuple generation * Ensure runtime store error * fix * cleanup * ignore 1.8 * newstructv * ignore test
1 parent 86da3cd commit ffcc20c

File tree

7 files changed

+808
-159
lines changed

7 files changed

+808
-159
lines changed

src/compiler.jl

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ end
474474
end
475475

476476
@assert !Base.isabstracttype(T)
477-
if !(Base.isconcretetype(T) || is_concrete_tuple(T) || T isa UnionAll)
477+
if !(Base.isconcretetype(T) || (T <: Tuple && T != Tuple) || T isa UnionAll)
478478
throw(AssertionError("Type $T is not concrete type or concrete tuple"))
479479
end
480480

@@ -515,7 +515,7 @@ end
515515
return active_reg_inner(T, (), world)
516516
end
517517

518-
@inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T}
518+
Base.@pure @inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T}
519519
seen = ()
520520

521521
# check if it could contain an active
@@ -3342,6 +3342,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
33423342
world = job.world
33433343
interp = GPUCompiler.get_interpreter(job)
33443344
rt = job.config.params.rt
3345+
@assert eltype(rt) != Union{}
3346+
33453347
shadow_init = job.config.params.shadowInit
33463348
ctx = context(mod)
33473349
dl = string(LLVM.datalayout(mod))
@@ -3546,6 +3548,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
35463548
pactualRetType = actualRetType
35473549
sret_union = is_sret_union(actualRetType)
35483550
literal_rt = eltype(rettype)
3551+
@assert literal_rt != Union{}
35493552
sret_union_rt = is_sret_union(literal_rt)
35503553
@assert sret_union == sret_union_rt
35513554
if sret_union
@@ -3684,9 +3687,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
36843687
end
36853688
end
36863689

3687-
combinedReturn = Tuple{sret_types...}
3688-
if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types)
3689-
combinedReturn = AnonymousStruct(combinedReturn)
3690+
combinedReturn = if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types)
3691+
AnonymousStruct(Tuple{sret_types...})
3692+
else
3693+
Tuple{sret_types...}
36903694
end
36913695

36923696
uses_sret = is_sret(combinedReturn)
@@ -4794,14 +4798,19 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
47944798
libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true,
47954799
strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing)
47964800
params = job.config.params
4801+
if params.run_enzyme
4802+
@assert eltype(params.rt) != Union{}
4803+
end
47974804
expectedTapeType = params.expectedTapeType
47984805
mode = params.mode
47994806
TT = params.TT
48004807
width = params.width
48014808
abiwrap = params.abiwrap
48024809
primal = job.source
48034810
modifiedBetween = params.modifiedBetween
4804-
@assert length(modifiedBetween) == length(TT.parameters)
4811+
if length(modifiedBetween) != length(TT.parameters)
4812+
throw(AssertionError("length(modifiedBetween) [aka $(length(modifiedBetween))] != length(TT.parameters) [aka $(length(TT.parameters))] at TT=$TT"))
4813+
end
48054814
returnPrimal = params.returnPrimal
48064815

48074816
if !(params.rt <: Const)
@@ -5297,6 +5306,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
52975306
end
52985307

52995308
@assert actualRetType !== nothing
5309+
if params.run_enzyme
5310+
@assert actualRetType != Union{}
5311+
end
53005312

53015313
if must_wrap
53025314
llvmfn = primalf
@@ -5838,7 +5850,11 @@ end
58385850
end
58395851

58405852
push!(ccexprs, argexpr)
5841-
if !(FA <: Const)
5853+
if (FA <: Active)
5854+
return quote
5855+
error("Cannot have function with Active annotation, $FA")
5856+
end
5857+
elseif !(FA <: Const)
58425858
argexpr = :(fn.dval)
58435859
if isboxed
58445860
push!(types, Any)
@@ -6274,9 +6290,16 @@ end
62746290
compile_result = cached_compilation(job)
62756291
if !run_enzyme
62766292
ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World}
6277-
return quote
6278-
Base.@_inline_meta
6279-
$ErrT($(compile_result.adjoint))
6293+
if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
6294+
return quote
6295+
Base.@_inline_meta
6296+
($ErrT($(compile_result.adjoint)), $ErrT($(compile_result.adjoint)))
6297+
end
6298+
else
6299+
return quote
6300+
Base.@_inline_meta
6301+
$ErrT($(compile_result.adjoint))
6302+
end
62806303
end
62816304
elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
62826305
TapeType = compile_result.TapeType

0 commit comments

Comments
 (0)