Skip to content

Commit 15f9bb1

Browse files
authored
MixedDuplicated for custom rules (#1534)
* MixedDuplicated for custom rules * more mixed duplicated * Handle mixed custom rule arg * starting batching * fix * fix tests * simplify mixed activity use
1 parent fb6f959 commit 15f9bb1

File tree

9 files changed

+508
-205
lines changed

9 files changed

+508
-205
lines changed

lib/EnzymeCore/src/EnzymeCore.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,32 @@ end
150150
@inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N
151151

152152

"""
    MixedDuplicated(x, ∂f_∂x)

Like [`Duplicated`](@ref), except `x` may contain both active [immutable] and duplicated
[mutable] data which is differentiable. Only used within custom rules.

The shadow `∂f_∂x` is supplied as a `Base.RefValue` wrapping a value of the same type as
`x`. The optional third positional argument `check` is accepted, but this constructor does
not consult it.
"""
struct MixedDuplicated{T} <: Annotation{T}
    val::T
    dval::Base.RefValue{T}

    @inline function MixedDuplicated(
        x::T1, dx::Base.RefValue{T1}, check::Bool=true
    ) where {T1}
        return new{T1}(x, dx)
    end
end
164+
"""
    BatchMixedDuplicated(x, ∂f_∂xs)

Like [`MixedDuplicated`](@ref), except contains several shadows to compute derivatives
for all at once. Only used within custom rules.

The shadows `∂f_∂xs` are supplied as an `NTuple{N}` of `Base.RefValue`s, each wrapping a
value of the same type as `x`; `N` becomes the second type parameter. The optional third
positional argument `check` is accepted, but this constructor does not consult it.
"""
struct BatchMixedDuplicated{T,N} <: Annotation{T}
    val::T
    dval::NTuple{N,Base.RefValue{T}}

    @inline function BatchMixedDuplicated(
        x::T1, dx::NTuple{N,Base.RefValue{T1}}, check::Bool=true
    ) where {T1,N}
        return new{T1,N}(x, dx)
    end
end
# The batch size of a `BatchMixedDuplicated` is its second type parameter `N`, i.e. the
# number of shadows it carries. It is available from an instance or from the type itself.
@inline function batch_size(::BatchMixedDuplicated{T,N}) where {T,N}
    return N
end
@inline function batch_size(::Type{BatchMixedDuplicated{T,N}}) where {T,N}
    return N
end
178+
153179
"""
154180
abstract type ABI
155181

src/Enzyme.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated,
1111
import EnzymeCore: BatchDuplicatedFunc
1212
export BatchDuplicatedFunc
1313

14+
import EnzymeCore: MixedDuplicated, BatchMixedDuplicated
15+
export MixedDuplicated, BatchMixedDuplicated
16+
1417
import EnzymeCore: batch_size, get_func
1518
export batch_size, get_func
1619

src/compiler.jl

Lines changed: 164 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2450,6 +2450,50 @@ else
24502450
end
24512451
end
24522452

"""
    store_nonjl_types!(B, startval, p)

Walk the LLVM value `startval` and, using the builder `B`, store every part of it that
holds no Julia-tracked data into the corresponding position of the memory addressed by `p`.

Arrays and structs for which `any_jltypes` is true are split with `extract_value!` and
their elements are visited in turn. Pointer-typed values for which `any_jltypes` is true
are skipped without emitting a store. Every other value is written with `store!` to the
address computed by a `gep!` over `value_type(startval)` based at `p`, indexed by the path
of element indices that leads to the value.

Returns `nothing`.
"""
function store_nonjl_types!(B, startval, p)
    # Worklist of (index path from `startval`, value found at that path), processed in
    # first-in, first-out order.
    todo = Tuple{Tuple, LLVM.Value}[((), startval)]
    while !isempty(todo)
        path, cur = popfirst!(todo)
        ty = value_type(cur)

        # Pointers that carry Julia-tracked data are not stored by this routine.
        if isa(ty, LLVM.PointerType) && any_jltypes(ty)
            continue
        end

        # Aggregates containing Julia-tracked data are split so each element is examined
        # on its own; aggregates without any fall through and are stored as a whole.
        if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
            for i in 1:length(ty)
                ev = extract_value!(B, cur, i - 1)
                push!(todo, ((path..., i - 1), ev))
            end
            continue
        end
        if isa(ty, LLVM.StructType) && any_jltypes(ty)
            for (i, _) in enumerate(LLVM.elements(ty))
                ev = extract_value!(B, cur, i - 1)
                push!(todo, ((path..., i - 1), ev))
            end
            continue
        end

        # GEP indices: a leading i64 0 for the object `p` points to, followed by one i32
        # index per step of the path down to the current element.
        indices = LLVM.Value[LLVM.ConstantInt(LLVM.IntType(64), 0)]
        for v in path
            push!(indices, LLVM.ConstantInt(LLVM.IntType(32), v))
        end
        gptr = gep!(B, value_type(startval), p, indices)
        store!(B, cur, gptr)
    end
    return nothing
end
2496+
24532497
function get_julia_inner_types(B, p, startvals...; added=[])
24542498
T_jlvalue = LLVM.StructType(LLVMType[])
24552499
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
@@ -3404,7 +3448,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
34043448
else
34053449
push!(args_activity, API.DFT_OUT_DIFF)
34063450
end
3407-
elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc
3451+
elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc || T <: MixedDuplicated || T <: BatchMixedDuplicated
34083452
push!(args_activity, API.DFT_DUP_ARG)
34093453
elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed
34103454
push!(args_activity, API.DFT_DUP_NONEED)
@@ -3588,7 +3632,6 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
35883632

35893633
isboxed = GPUCompiler.deserves_argbox(source_typ)
35903634
llvmT = isboxed ? T_prjlvalue : convert(LLVMType, source_typ)
3591-
35923635
push!(T_wrapperargs, llvmT)
35933636

35943637
if T <: Const || T <: BatchDuplicatedFunc
@@ -3617,6 +3660,11 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
36173660
if is_adjoint && i != 1
36183661
push!(ActiveRetTypes, Nothing)
36193662
end
3663+
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
3664+
push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)))
3665+
if is_adjoint && i != 1
3666+
push!(ActiveRetTypes, Nothing)
3667+
end
36203668
else
36213669
error("calling convention should be annotated, got $T")
36223670
end
@@ -3799,7 +3847,23 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
37993847
if isghostty(T′) || Core.Compiler.isconstType(T′)
38003848
continue
38013849
end
3802-
push!(realparms, params[i])
3850+
3851+
isboxed = GPUCompiler.deserves_argbox(T′)
3852+
3853+
llty = value_type(params[i])
3854+
3855+
convty = convert(LLVMType, T′; allow_boxed=true)
3856+
3857+
if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType))
3858+
al = emit_allocobj!(builder, Base.RefValue{T′})
3859+
al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
3860+
store!(builder, params[i], al)
3861+
al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived))
3862+
push!(realparms, al)
3863+
else
3864+
push!(realparms, params[i])
3865+
end
3866+
38033867
i += 1
38043868
if T <: Const
38053869
elseif T <: Active
@@ -3827,6 +3891,34 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
38273891
elseif T <: Duplicated || T <: DuplicatedNoNeed
38283892
push!(realparms, params[i])
38293893
i += 1
3894+
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
3895+
parmsi = params[i]
3896+
3897+
if T <: BatchMixedDuplicated
3898+
if GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{T′}})
3899+
njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue)
3900+
parmsi = bitcast!(builder, parmsi, LLVM.PointerType(njlvalue, addrspace(value_type(parmsi))))
3901+
parmsi = load!(builder, njlvalue, parmsi)
3902+
end
3903+
end
3904+
3905+
isboxed = GPUCompiler.deserves_argbox(T′)
3906+
3907+
resty = isboxed ? llty : LLVM.PointerType(llty, Derived)
3908+
3909+
ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, resty)))
3910+
for idx in 1:width
3911+
pv = (width == 1) ? parmsi : extract_value!(builder, parmsi, idx-1)
3912+
pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv))))
3913+
pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived))
3914+
if isboxed
3915+
pv = load!(builder, llty, pv, "mixedboxload")
3916+
end
3917+
ival = (width == 1 ) ? pv : insert_value!(builder, ival, pv, idx-1)
3918+
end
3919+
3920+
push!(realparms, ival)
3921+
i += 1
38303922
elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed
38313923
isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′})
38323924
val = params[i]
@@ -4357,6 +4449,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
43574449

43584450
# generate the wrapper function type & definition
43594451
wrapper_types = LLVM.LLVMType[]
4452+
wrapper_attrs = Vector{LLVM.Attribute}[]
43604453
_, sret, returnRoots = get_return_info(actualRetType)
43614454
sret_union = is_sret_union(actualRetType)
43624455

@@ -4391,31 +4484,44 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
43914484

43924485
if swiftself
43934486
push!(wrapper_types, value_type(parameters(entry_f)[1+sret+returnRoots]))
4487+
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("swiftself")])
43944488
end
43954489

43964490
boxedArgs = Set{Int}()
43974491
loweredArgs = Set{Int}()
4492+
raisedArgs = Set{Int}()
43984493

43994494
for arg in args
44004495
typ = arg.codegen.typ
44014496
if GPUCompiler.deserves_argbox(arg.typ)
44024497
push!(boxedArgs, arg.arg_i)
44034498
push!(wrapper_types, typ)
4499+
push!(wrapper_attrs, LLVM.Attribute[])
44044500
elseif arg.cc != GPUCompiler.BITS_REF
4405-
push!(wrapper_types, typ)
4501+
if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated)
4502+
push!(boxedArgs, arg.arg_i)
4503+
push!(raisedArgs, arg.arg_i)
4504+
push!(wrapper_types, LLVM.PointerType(typ, Derived))
4505+
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
4506+
else
4507+
push!(wrapper_types, typ)
4508+
push!(wrapper_attrs, LLVM.Attribute[])
4509+
end
44064510
else
44074511
# bits ref, and not boxed
4408-
# if TT.parameters[arg.arg_i] <: Const
4409-
# push!(boxedArgs, arg.arg_i)
4410-
# push!(wrapper_types, typ)
4411-
# else
4512+
if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated)
4513+
push!(boxedArgs, arg.arg_i)
4514+
push!(wrapper_types, typ)
4515+
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
4516+
else
44124517
push!(wrapper_types, eltype(typ))
4518+
push!(wrapper_attrs, LLVM.Attribute[])
44134519
push!(loweredArgs, arg.arg_i)
4414-
# end
4520+
end
44154521
end
44164522
end
44174523

4418-
if length(loweredArgs) == 0 && !sret && !sret_union
4524+
if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union
44194525
return entry_f, returnRoots, boxedArgs, loweredArgs
44204526
end
44214527

@@ -4436,8 +4542,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
44364542
end
44374543
push!(function_attributes(wrapper_f), EnumAttribute("returns_twice"))
44384544
push!(function_attributes(entry_f), EnumAttribute("returns_twice"))
4439-
if swiftself
4440-
push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself"))
4545+
for (i, v) in enumerate(wrapper_attrs)
4546+
for attr in v
4547+
push!(parameter_attributes(wrapper_f, i), attr)
4548+
end
44414549
end
44424550

44434551
seen = TypeTreeTable()
@@ -4463,6 +4571,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
44634571
parm = ops[arg.codegen.i]
44644572
if arg.arg_i in loweredArgs
44654573
push!(nops, load!(builder, convert(LLVMType, arg.typ), parm))
4574+
elseif arg.arg_i in raisedArgs
4575+
obj = emit_allocobj!(builder, arg.typ)
4576+
bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj))))
4577+
store!(builder, parm, bc)
4578+
addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived))
4579+
push!(nops, addr)
44664580
else
44674581
push!(nops, parm)
44684582
end
@@ -4547,6 +4661,13 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
45474661
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen))))
45484662
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
45494663
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
4664+
elseif arg.arg_i in raisedArgs
4665+
wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm)
4666+
ctx = LLVM.context(wrapparm)
4667+
push!(wrapper_args, wrapparm)
4668+
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen))))
4669+
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
4670+
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
45504671
else
45514672
push!(wrapper_args, wrapparm)
45524673
for attr in collect(parameter_attributes(entry_f, arg.codegen.i))
@@ -4626,6 +4747,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
46264747
elseif LLVM.return_type(entry_ft) == LLVM.VoidType()
46274748
ret!(builder)
46284749
else
4750+
ctx = LLVM.context(wrapper_f)
46294751
push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen))))
46304752
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType)))))
46314753
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
@@ -4687,7 +4809,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
46874809
if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0
46884810
msg = sprint() do io
46894811
println(io, string(mod))
4690-
println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
4812+
println(io, LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
46914813
println(io, string(wrapper_f))
46924814
println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs)
46934815
println(io, "Broken function")
@@ -5966,6 +6088,35 @@ end
59666088
push!(ActiveRetTypes, Nothing)
59676089
end
59686090
push!(ccexprs, argexpr)
6091+
elseif T <: MixedDuplicated
6092+
if RawCall
6093+
argexpr = argexprs[i]
6094+
i+=1
6095+
else
6096+
argexpr = Expr(:., expr, QuoteNode(:dval))
6097+
end
6098+
push!(types, Any)
6099+
if is_adjoint
6100+
push!(ActiveRetTypes, Nothing)
6101+
end
6102+
push!(ccexprs, argexpr)
6103+
elseif T <: BatchMixedDuplicated
6104+
if RawCall
6105+
argexpr = argexprs[i]
6106+
i+=1
6107+
else
6108+
argexpr = Expr(:., expr, QuoteNode(:dval))
6109+
end
6110+
isboxedvec = GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{source_typ}})
6111+
if isboxedvec
6112+
push!(types, Any)
6113+
else
6114+
push!(types, NTuple{width, Base.RefValue{source_typ}})
6115+
end
6116+
if is_adjoint
6117+
push!(ActiveRetTypes, Nothing)
6118+
end
6119+
push!(ccexprs, argexpr)
59696120
else
59706121
error("calling convention should be annotated, got $T")
59716122
end

0 commit comments

Comments
 (0)