Skip to content

Commit a889bb6

Browse files
authored
Mixed activity for getfield (#1535)
* Mixed activity for getfield * bump ver * fixup runtime iterate for mixed * fix iter * mixedduplicated return * fixup * fix * try inference fix re ref * try more * Update Project.toml * Update jitrules.jl
1 parent 835b6d5 commit a889bb6

File tree

8 files changed

+651
-174
lines changed

8 files changed

+651
-174
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Enzyme"
22
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
4-
version = "0.12.13"
4+
version = "0.12.14"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2020
CEnum = "0.4, 0.5"
2121
ChainRulesCore = "1"
2222
EnzymeCore = "0.7.5"
23-
Enzyme_jll = "0.0.121"
23+
Enzyme_jll = "0.0.122"
2424
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
2525
LLVM = "6.1, 7"
2626
ObjectFile = "0.4"

src/Enzyme.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ end
6464
arg = @inbounds args[i]
6565
if arg isa Active
6666
return true
67+
elseif arg isa MixedDuplicated
68+
return true
69+
elseif arg isa BatchMixedDuplicated
70+
return true
6771
else
6872
return false
6973
end
@@ -95,6 +99,10 @@ end
9599
end
96100

97101
@inline same_or_one_rec(current) = current
102+
@inline same_or_one_rec(current, arg::BatchMixedDuplicated{T, N}, args...) where {T,N} =
103+
same_or_one_rec(same_or_one_helper(current, N), args...)
104+
@inline same_or_one_rec(current, arg::Type{BatchMixedDuplicated{T, N}}, args...) where {T,N} =
105+
same_or_one_rec(same_or_one_helper(current, N), args...)
98106
@inline same_or_one_rec(current, arg::BatchDuplicatedFunc{T, N}, args...) where {T,N} =
99107
same_or_one_rec(same_or_one_helper(current, N), args...)
100108
@inline same_or_one_rec(current, arg::Type{BatchDuplicatedFunc{T, N}}, args...) where {T,N} =
@@ -844,6 +852,12 @@ result, ∂v, ∂A
844852
else
845853
BatchDuplicatedNoNeed{eltype(A2), width}
846854
end
855+
elseif A2 <: MixedDuplicated && width != 1
856+
if A2 isa UnionAll
857+
BatchMixedDuplicated{T, width} where T
858+
else
859+
BatchMixedDuplicated{eltype(A2), width}
860+
end
847861
else
848862
A2
849863
end

src/compiler.jl

Lines changed: 102 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,13 @@ end
543543
return res
544544
end
545545

546+
# check if a value is guaranteed not to contain active[register] data
547+
# (i.e. it is neither mixed nor active)
548+
@inline function guaranteed_nonactive(::Type{T}) where T
549+
rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing))
550+
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
551+
end
552+
546553
@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode))
547554

548555
@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T}
@@ -555,6 +562,8 @@ end
555562
else
556563
if ActReg == ActiveState
557564
return Active{T}
565+
elseif ActReg == MixedState
566+
return MixedDuplicated{T}
558567
else
559568
return Duplicated{T}
560569
end
@@ -2494,7 +2503,7 @@ function store_nonjl_types!(B, startval, p)
24942503
return
24952504
end
24962505

2497-
function get_julia_inner_types(B, p, startvals...; added=[])
2506+
function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[])
24982507
T_jlvalue = LLVM.StructType(LLVMType[])
24992508
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
25002509
vals = LLVM.Value[]
@@ -2547,8 +2556,20 @@ function get_julia_inner_types(B, p, startvals...; added=[])
25472556
end
25482557
continue
25492558
end
2550-
GPUCompiler.@safe_warn "Enzyme illegal subtype", ty, cur, SI, p, v
2551-
@assert false
2559+
if isa(ty, LLVM.IntegerType)
2560+
continue
2561+
end
2562+
if isa(ty, LLVM.FloatingPointType)
2563+
continue
2564+
end
2565+
msg = sprint() do io
2566+
println(io, "Enzyme illegal subtype")
2567+
println(io, "ty=", ty)
2568+
println(io, "cur=", cur)
2569+
println(io, "p=", p)
2570+
println(io, "startvals=", startvals)
2571+
end
2572+
throw(AssertionError(msg))
25522573
end
25532574
return vals
25542575
end
@@ -3474,7 +3495,11 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
34743495
# If requested, the shadow return value of the function
34753496
# For each active (non duplicated) argument
34763497
# The adjoint of that argument
3477-
retType = convert(API.CDIFFE_TYPE, rt)
3498+
retType = if rt <: MixedDuplicated || rt <: BatchMixedDuplicated
3499+
API.DFT_OUT_DIFF
3500+
else
3501+
convert(API.CDIFFE_TYPE, rt)
3502+
end
34783503

34793504
rules = Dict{String, API.CustomRuleType}(
34803505
"jl_array_copy" => @cfunction(inout_rule,
@@ -3513,7 +3538,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
35133538

35143539
if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient
35153540
returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType))
3516-
shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED)
3541+
shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED || rt <: MixedDuplicated || rt <: BatchMixedDuplicated)
35173542
returnUsed &= returnPrimal
35183543
augmented = API.EnzymeCreateAugmentedPrimal(
35193544
logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed,
@@ -3679,16 +3704,20 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
36793704
end
36803705

36813706
# API.DFT_OUT_DIFF
3682-
if is_adjoint && rettype <: Active
3683-
@assert !sret_union
3684-
if allocatedinline(actualRetType) != allocatedinline(literal_rt)
3685-
throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)"))
3686-
end
3687-
if !allocatedinline(actualRetType)
3688-
throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)"))
3707+
if is_adjoint
3708+
if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
3709+
@assert !sret_union
3710+
if allocatedinline(actualRetType) != allocatedinline(literal_rt)
3711+
throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)"))
3712+
end
3713+
if rettype <: Active
3714+
if !allocatedinline(actualRetType)
3715+
throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)"))
3716+
end
3717+
end
3718+
dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active))))
3719+
push!(T_wrapperargs, dretTy)
36893720
end
3690-
dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType)))
3691-
push!(T_wrapperargs, dretTy)
36923721
end
36933722

36943723
data = Array{Int64}(undef, 3)
@@ -3730,6 +3759,12 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
37303759
else
37313760
push!(sret_types, AnonymousStruct(NTuple{width, literal_rt}))
37323761
end
3762+
elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
3763+
if width == 1
3764+
push!(sret_types, Base.RefValue{literal_rt})
3765+
else
3766+
push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{literal_rt}}))
3767+
end
37333768
end
37343769
else
37353770
@assert rettype <: Const || rettype <: Active
@@ -3953,7 +3988,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
39533988
end
39543989
end
39553990

3956-
if is_adjoint && rettype <: Active
3991+
if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated)
39573992
push!(realparms, params[i])
39583993
i += 1
39593994
end
@@ -3999,12 +4034,26 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
39994034
if data[i] != -1
40004035
eval = extract_value!(builder, val, data[i])
40014036
end
4037+
if i == 3
4038+
if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
4039+
ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)))
4040+
for idx in 1:width
4041+
pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1)
4042+
al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)})
4043+
llty = value_type(pv)
4044+
al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
4045+
store!(builder, pv, al)
4046+
emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv))
4047+
ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1)
4048+
end
4049+
eval = ival
4050+
end
4051+
end
40024052
eval = fixup_abi(i, eval)
40034053
ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)])
40044054
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
40054055
si = store!(builder, eval, ptr)
40064056
returnNum+=1
4007-
40084057
if i == 3 && shadow_init
40094058
shadows = LLVM.Value[]
40104059
if width == 1
@@ -5943,34 +5992,35 @@ end
59435992
end
59445993

59455994
if !RawCall && !(CC <: PrimalErrorThunk)
5946-
if rettype <: Active
5995+
if rettype <: Active
59475996
if length(argtypes) + is_adjoint + needs_tape != length(argexprs)
59485997
return quote
5949-
throw(MethodError($CC(fptr), $args))
5998+
throw(MethodError($CC(fptr), (fn, args...)))
5999+
end
6000+
end
6001+
elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
6002+
if length(argtypes) + is_adjoint * width + needs_tape != length(argexprs)
6003+
return quote
6004+
throw(MethodError($CC(fptr), (fn, args...)))
59506005
end
59516006
end
59526007
elseif rettype <: Const
59536008
if length(argtypes) + needs_tape != length(argexprs)
59546009
return quote
5955-
throw(MethodError($CC(fptr), $args))
6010+
throw(MethodError($CC(fptr), (fn, args...)))
59566011
end
59576012
end
59586013
else
59596014
if length(argtypes) + needs_tape != length(argexprs)
59606015
return quote
5961-
throw(MethodError($CC(fptr), $args))
6016+
throw(MethodError($CC(fptr), (fn, args...)))
59626017
end
59636018
end
59646019
end
59656020
end
59666021

59676022
types = DataType[]
59686023

5969-
if eltype(rettype) === Union{} && false
5970-
return quote
5971-
error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
5972-
end
5973-
end
59746024
if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType)
59756025
rrt = eltype(rettype)
59766026
error("Return type `$rrt` not marked Const, but is ghost or const type.")
@@ -6133,17 +6183,28 @@ end
61336183
end
61346184

61356185
# API.DFT_OUT_DIFF
6136-
if is_adjoint && rettype <: Active
6137-
# TODO handle batch width
6138-
@assert allocatedinline(jlRT)
6139-
j_drT = if width == 1
6140-
jlRT
6141-
else
6142-
NTuple{width, jlRT}
6186+
if is_adjoint
6187+
if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
6188+
# TODO handle batch width
6189+
if rettype <: Active
6190+
@assert allocatedinline(jlRT)
6191+
end
6192+
j_drT = if width == 1
6193+
jlRT
6194+
else
6195+
NTuple{width, jlRT}
6196+
end
6197+
push!(types, j_drT)
6198+
if width == 1 || rettype <: Active
6199+
push!(ccexprs, argexprs[i])
6200+
i+=1
6201+
else
6202+
push!(ccexprs, quote
6203+
($(argexprs[i:i+width-1]...),)
6204+
end)
6205+
i+=width
6206+
end
61436207
end
6144-
push!(types, j_drT)
6145-
push!(ccexprs, argexprs[i])
6146-
i+=1
61476208
end
61486209

61496210
if needs_tape
@@ -6181,8 +6242,12 @@ end
61816242
end
61826243
if rettype <: Duplicated || rettype <: DuplicatedNoNeed
61836244
push!(sret_types, jlRT)
6245+
elseif rettype <: MixedDuplicated
6246+
push!(sret_types, Base.RefValue{jlRT})
61846247
elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed
61856248
push!(sret_types, AnonymousStruct(NTuple{width, jlRT}))
6249+
elseif rettype <: BatchMixedDuplicated
6250+
push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{jlRT}}))
61866251
elseif CC <: AugmentedForwardThunk
61876252
push!(sret_types, Nothing)
61886253
elseif rettype <: Const
@@ -6406,6 +6471,8 @@ end
64066471
@inline remove_innerty(::Type{<:DuplicatedNoNeed}) = DuplicatedNoNeed
64076472
@inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated
64086473
@inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed
6474+
@inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated
6475+
@inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated
64096476

64106477
@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI}
64116478
JuliaContext() do ctx

0 commit comments

Comments
 (0)