|
474 | 474 | end |
475 | 475 |
|
476 | 476 | @assert !Base.isabstracttype(T) |
477 | | - if !(Base.isconcretetype(T) || is_concrete_tuple(T) || T isa UnionAll) |
| 477 | + if !(Base.isconcretetype(T) || (T <: Tuple && T != Tuple) || T isa UnionAll) |
478 | 478 | throw(AssertionError("Type $T is not concrete type or concrete tuple")) |
479 | 479 | end |
480 | 480 |
|
|
515 | 515 | return active_reg_inner(T, (), world) |
516 | 516 | end |
517 | 517 |
|
518 | | -@inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T} |
| 518 | +Base.@pure @inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T} |
519 | 519 | seen = () |
520 | 520 |
|
521 | 521 | # check if it could contain an active |
@@ -3342,6 +3342,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr |
3342 | 3342 | world = job.world |
3343 | 3343 | interp = GPUCompiler.get_interpreter(job) |
3344 | 3344 | rt = job.config.params.rt |
| 3345 | + @assert eltype(rt) != Union{} |
| 3346 | + |
3345 | 3347 | shadow_init = job.config.params.shadowInit |
3346 | 3348 | ctx = context(mod) |
3347 | 3349 | dl = string(LLVM.datalayout(mod)) |
@@ -3546,6 +3548,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, |
3546 | 3548 | pactualRetType = actualRetType |
3547 | 3549 | sret_union = is_sret_union(actualRetType) |
3548 | 3550 | literal_rt = eltype(rettype) |
| 3551 | + @assert literal_rt != Union{} |
3549 | 3552 | sret_union_rt = is_sret_union(literal_rt) |
3550 | 3553 | @assert sret_union == sret_union_rt |
3551 | 3554 | if sret_union |
@@ -3684,9 +3687,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, |
3684 | 3687 | end |
3685 | 3688 | end |
3686 | 3689 |
|
3687 | | - combinedReturn = Tuple{sret_types...} |
3688 | | - if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) |
3689 | | - combinedReturn = AnonymousStruct(combinedReturn) |
| 3690 | + combinedReturn = if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) |
| 3691 | + AnonymousStruct(Tuple{sret_types...}) |
| 3692 | + else |
| 3693 | + Tuple{sret_types...} |
3690 | 3694 | end |
3691 | 3695 |
|
3692 | 3696 | uses_sret = is_sret(combinedReturn) |
@@ -4794,14 +4798,19 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; |
4794 | 4798 | libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true, |
4795 | 4799 | strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing) |
4796 | 4800 | params = job.config.params |
| 4801 | + if params.run_enzyme |
| 4802 | + @assert eltype(params.rt) != Union{} |
| 4803 | + end |
4797 | 4804 | expectedTapeType = params.expectedTapeType |
4798 | 4805 | mode = params.mode |
4799 | 4806 | TT = params.TT |
4800 | 4807 | width = params.width |
4801 | 4808 | abiwrap = params.abiwrap |
4802 | 4809 | primal = job.source |
4803 | 4810 | modifiedBetween = params.modifiedBetween |
4804 | | - @assert length(modifiedBetween) == length(TT.parameters) |
| 4811 | + if length(modifiedBetween) != length(TT.parameters) |
| 4812 | + throw(AssertionError("length(modifiedBetween) [aka $(length(modifiedBetween))] != length(TT.parameters) [aka $(length(TT.parameters))] at TT=$TT")) |
| 4813 | + end |
4805 | 4814 | returnPrimal = params.returnPrimal |
4806 | 4815 |
|
4807 | 4816 | if !(params.rt <: Const) |
@@ -5297,6 +5306,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; |
5297 | 5306 | end |
5298 | 5307 |
|
5299 | 5308 | @assert actualRetType !== nothing |
| 5309 | + if params.run_enzyme |
| 5310 | + @assert actualRetType != Union{} |
| 5311 | + end |
5300 | 5312 |
|
5301 | 5313 | if must_wrap |
5302 | 5314 | llvmfn = primalf |
@@ -5838,7 +5850,11 @@ end |
5838 | 5850 | end |
5839 | 5851 |
|
5840 | 5852 | push!(ccexprs, argexpr) |
5841 | | - if !(FA <: Const) |
| 5853 | + if (FA <: Active) |
| 5854 | + return quote |
| 5855 | + error("Cannot have function with Active annotation, $FA") |
| 5856 | + end |
| 5857 | + elseif !(FA <: Const) |
5842 | 5858 | argexpr = :(fn.dval) |
5843 | 5859 | if isboxed |
5844 | 5860 | push!(types, Any) |
@@ -6274,9 +6290,16 @@ end |
6274 | 6290 | compile_result = cached_compilation(job) |
6275 | 6291 | if !run_enzyme |
6276 | 6292 | ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World} |
6277 | | - return quote |
6278 | | - Base.@_inline_meta |
6279 | | - $ErrT($(compile_result.adjoint)) |
| 6293 | + if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient |
| 6294 | + return quote |
| 6295 | + Base.@_inline_meta |
| 6296 | + ($ErrT($(compile_result.adjoint)), $ErrT($(compile_result.adjoint))) |
| 6297 | + end |
| 6298 | + else |
| 6299 | + return quote |
| 6300 | + Base.@_inline_meta |
| 6301 | + $ErrT($(compile_result.adjoint)) |
| 6302 | + end |
6280 | 6303 | end |
6281 | 6304 | elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient |
6282 | 6305 | TapeType = compile_result.TapeType |
|
0 commit comments