|
543 | 543 | return res |
544 | 544 | end |
545 | 545 |
|
| 546 | +# Check whether values of type `T` are guaranteed not to contain active (register) data, |
| 547 | +# i.e. the activity state computed for `T` is neither mixed nor active (only AnyState or DupState). |
| 548 | +@inline function guaranteed_nonactive(::Type{T}) where T |
| 549 | + rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) |
| 550 | + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState |
| 551 | +end |
| 552 | + |
546 | 553 | @inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode)) |
547 | 554 |
|
548 | 555 | @inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} |
|
555 | 562 | else |
556 | 563 | if ActReg == ActiveState |
557 | 564 | return Active{T} |
| 565 | + elseif ActReg == MixedState |
| 566 | + return MixedDuplicated{T} |
558 | 567 | else |
559 | 568 | return Duplicated{T} |
560 | 569 | end |
@@ -2494,7 +2503,7 @@ function store_nonjl_types!(B, startval, p) |
2494 | 2503 | return |
2495 | 2504 | end |
2496 | 2505 |
|
2497 | | -function get_julia_inner_types(B, p, startvals...; added=[]) |
| 2506 | +function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[]) |
2498 | 2507 | T_jlvalue = LLVM.StructType(LLVMType[]) |
2499 | 2508 | T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) |
2500 | 2509 | vals = LLVM.Value[] |
@@ -2547,8 +2556,20 @@ function get_julia_inner_types(B, p, startvals...; added=[]) |
2547 | 2556 | end |
2548 | 2557 | continue |
2549 | 2558 | end |
2550 | | - GPUCompiler.@safe_warn "Enzyme illegal subtype", ty, cur, SI, p, v |
2551 | | - @assert false |
| 2559 | + if isa(ty, LLVM.IntegerType) |
| 2560 | + continue |
| 2561 | + end |
| 2562 | + if isa(ty, LLVM.FloatingPointType) |
| 2563 | + continue |
| 2564 | + end |
| 2565 | + msg = sprint() do io |
| 2566 | + println(io, "Enzyme illegal subtype") |
| 2567 | + println(io, "ty=", ty) |
| 2568 | + println(io, "cur=", cur) |
| 2569 | + println(io, "p=", p) |
| 2570 | + println(io, "startvals=", startvals) |
| 2571 | + end |
| 2572 | + throw(AssertionError(msg)) |
2552 | 2573 | end |
2553 | 2574 | return vals |
2554 | 2575 | end |
@@ -3474,7 +3495,11 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr |
3474 | 3495 | # If requested, the shadow return value of the function |
3475 | 3496 | # For each active (non duplicated) argument |
3476 | 3497 | # The adjoint of that argument |
3477 | | - retType = convert(API.CDIFFE_TYPE, rt) |
| 3498 | + retType = if rt <: MixedDuplicated || rt <: BatchMixedDuplicated |
| 3499 | + API.DFT_OUT_DIFF |
| 3500 | + else |
| 3501 | + convert(API.CDIFFE_TYPE, rt) |
| 3502 | + end |
3478 | 3503 |
|
3479 | 3504 | rules = Dict{String, API.CustomRuleType}( |
3480 | 3505 | "jl_array_copy" => @cfunction(inout_rule, |
@@ -3513,7 +3538,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr |
3513 | 3538 |
|
3514 | 3539 | if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient |
3515 | 3540 | returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) |
3516 | | - shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED) |
| 3541 | + shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED || rt <: MixedDuplicated || rt <: BatchMixedDuplicated) |
3517 | 3542 | returnUsed &= returnPrimal |
3518 | 3543 | augmented = API.EnzymeCreateAugmentedPrimal( |
3519 | 3544 | logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed, |
@@ -3679,16 +3704,20 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, |
3679 | 3704 | end |
3680 | 3705 |
|
3681 | 3706 | # API.DFT_OUT_DIFF |
3682 | | - if is_adjoint && rettype <: Active |
3683 | | - @assert !sret_union |
3684 | | - if allocatedinline(actualRetType) != allocatedinline(literal_rt) |
3685 | | - throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)")) |
3686 | | - end |
3687 | | - if !allocatedinline(actualRetType) |
3688 | | - throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)")) |
| 3707 | + if is_adjoint |
| 3708 | + if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated |
| 3709 | + @assert !sret_union |
| 3710 | + if allocatedinline(actualRetType) != allocatedinline(literal_rt) |
| 3711 | + throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)")) |
| 3712 | + end |
| 3713 | + if rettype <: Active |
| 3714 | + if !allocatedinline(actualRetType) |
| 3715 | + throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)")) |
| 3716 | + end |
| 3717 | + end |
| 3718 | + dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active)))) |
| 3719 | + push!(T_wrapperargs, dretTy) |
3689 | 3720 | end |
3690 | | - dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType))) |
3691 | | - push!(T_wrapperargs, dretTy) |
3692 | 3721 | end |
3693 | 3722 |
|
3694 | 3723 | data = Array{Int64}(undef, 3) |
@@ -3730,6 +3759,12 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, |
3730 | 3759 | else |
3731 | 3760 | push!(sret_types, AnonymousStruct(NTuple{width, literal_rt})) |
3732 | 3761 | end |
| 3762 | + elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated |
| 3763 | + if width == 1 |
| 3764 | + push!(sret_types, Base.RefValue{literal_rt}) |
| 3765 | + else |
| 3766 | + push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{literal_rt}})) |
| 3767 | + end |
3733 | 3768 | end |
3734 | 3769 | else |
3735 | 3770 | @assert rettype <: Const || rettype <: Active |
@@ -3953,7 +3988,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, |
3953 | 3988 | end |
3954 | 3989 | end |
3955 | 3990 |
|
3956 | | - if is_adjoint && rettype <: Active |
| 3991 | + if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated) |
3957 | 3992 | push!(realparms, params[i]) |
3958 | 3993 | i += 1 |
3959 | 3994 | end |
@@ -3999,12 +4034,26 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, |
3999 | 4034 | if data[i] != -1 |
4000 | 4035 | eval = extract_value!(builder, val, data[i]) |
4001 | 4036 | end |
| 4037 | + if i == 3 |
| 4038 | + if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated |
| 4039 | + ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) |
| 4040 | + for idx in 1:width |
| 4041 | + pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1) |
| 4042 | + al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}) |
| 4043 | + llty = value_type(pv) |
| 4044 | + al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) |
| 4045 | + store!(builder, pv, al) |
| 4046 | + emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv)) |
| 4047 | + ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1) |
| 4048 | + end |
| 4049 | + eval = ival |
| 4050 | + end |
| 4051 | + end |
4002 | 4052 | eval = fixup_abi(i, eval) |
4003 | 4053 | ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) |
4004 | 4054 | ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) |
4005 | 4055 | si = store!(builder, eval, ptr) |
4006 | 4056 | returnNum+=1 |
4007 | | - |
4008 | 4057 | if i == 3 && shadow_init |
4009 | 4058 | shadows = LLVM.Value[] |
4010 | 4059 | if width == 1 |
@@ -5943,34 +5992,35 @@ end |
5943 | 5992 | end |
5944 | 5993 |
|
5945 | 5994 | if !RawCall && !(CC <: PrimalErrorThunk) |
5946 | | - if rettype <: Active |
| 5995 | + if rettype <: Active |
5947 | 5996 | if length(argtypes) + is_adjoint + needs_tape != length(argexprs) |
5948 | 5997 | return quote |
5949 | | - throw(MethodError($CC(fptr), $args)) |
| 5998 | + throw(MethodError($CC(fptr), (fn, args...))) |
| 5999 | + end |
| 6000 | + end |
| 6001 | + elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated |
| 6002 | + if length(argtypes) + is_adjoint * width + needs_tape != length(argexprs) |
| 6003 | + return quote |
| 6004 | + throw(MethodError($CC(fptr), (fn, args...))) |
5950 | 6005 | end |
5951 | 6006 | end |
5952 | 6007 | elseif rettype <: Const |
5953 | 6008 | if length(argtypes) + needs_tape != length(argexprs) |
5954 | 6009 | return quote |
5955 | | - throw(MethodError($CC(fptr), $args)) |
| 6010 | + throw(MethodError($CC(fptr), (fn, args...))) |
5956 | 6011 | end |
5957 | 6012 | end |
5958 | 6013 | else |
5959 | 6014 | if length(argtypes) + needs_tape != length(argexprs) |
5960 | 6015 | return quote |
5961 | | - throw(MethodError($CC(fptr), $args)) |
| 6016 | + throw(MethodError($CC(fptr), (fn, args...))) |
5962 | 6017 | end |
5963 | 6018 | end |
5964 | 6019 | end |
5965 | 6020 | end |
5966 | 6021 |
|
5967 | 6022 | types = DataType[] |
5968 | 6023 |
|
5969 | | - if eltype(rettype) === Union{} && false |
5970 | | - return quote |
5971 | | - error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") |
5972 | | - end |
5973 | | - end |
5974 | 6024 | if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType) |
5975 | 6025 | rrt = eltype(rettype) |
5976 | 6026 | error("Return type `$rrt` not marked Const, but is ghost or const type.") |
@@ -6133,17 +6183,28 @@ end |
6133 | 6183 | end |
6134 | 6184 |
|
6135 | 6185 | # API.DFT_OUT_DIFF |
6136 | | - if is_adjoint && rettype <: Active |
6137 | | - # TODO handle batch width |
6138 | | - @assert allocatedinline(jlRT) |
6139 | | - j_drT = if width == 1 |
6140 | | - jlRT |
6141 | | - else |
6142 | | - NTuple{width, jlRT} |
| 6186 | + if is_adjoint |
| 6187 | + if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated |
| 6188 | + # TODO handle batch width |
| 6189 | + if rettype <: Active |
| 6190 | + @assert allocatedinline(jlRT) |
| 6191 | + end |
| 6192 | + j_drT = if width == 1 |
| 6193 | + jlRT |
| 6194 | + else |
| 6195 | + NTuple{width, jlRT} |
| 6196 | + end |
| 6197 | + push!(types, j_drT) |
| 6198 | + if width == 1 || rettype <: Active |
| 6199 | + push!(ccexprs, argexprs[i]) |
| 6200 | + i+=1 |
| 6201 | + else |
| 6202 | + push!(ccexprs, quote |
| 6203 | + ($(argexprs[i:i+width-1]...),) |
| 6204 | + end) |
| 6205 | + i+=width |
| 6206 | + end |
6143 | 6207 | end |
6144 | | - push!(types, j_drT) |
6145 | | - push!(ccexprs, argexprs[i]) |
6146 | | - i+=1 |
6147 | 6208 | end |
6148 | 6209 |
|
6149 | 6210 | if needs_tape |
@@ -6181,8 +6242,12 @@ end |
6181 | 6242 | end |
6182 | 6243 | if rettype <: Duplicated || rettype <: DuplicatedNoNeed |
6183 | 6244 | push!(sret_types, jlRT) |
| 6245 | + elseif rettype <: MixedDuplicated |
| 6246 | + push!(sret_types, Base.RefValue{jlRT}) |
6184 | 6247 | elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed |
6185 | 6248 | push!(sret_types, AnonymousStruct(NTuple{width, jlRT})) |
| 6249 | + elseif rettype <: BatchMixedDuplicated |
| 6250 | + push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{jlRT}})) |
6186 | 6251 | elseif CC <: AugmentedForwardThunk |
6187 | 6252 | push!(sret_types, Nothing) |
6188 | 6253 | elseif rettype <: Const |
@@ -6406,6 +6471,8 @@ end |
6406 | 6471 | @inline remove_innerty(::Type{<:DuplicatedNoNeed}) = DuplicatedNoNeed |
6407 | 6472 | @inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated |
6408 | 6473 | @inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed |
| 6474 | +@inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated |
| 6475 | +@inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated |
6409 | 6476 |
|
6410 | 6477 | @inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} |
6411 | 6478 | JuliaContext() do ctx |
|
0 commit comments