Skip to content

Commit f9dd728

Browse files
authored
1.12: Fix custom rule return rooting (#2804)
* 1.12: Fix custom rule return rooting * fix
1 parent 04781c8 commit f9dd728

File tree

2 files changed

+168
-19
lines changed

2 files changed

+168
-19
lines changed

src/compiler.jl

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3805,7 +3805,7 @@ end
38053805
NullifySRetValue = 4
38063806
)
38073807

3808-
function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, sret::LLVM.Value, root_ty::LLVM.LLVMType, rootRet::Union{LLVM.Value, Nothing}, direction::SRetRootMovement)
3808+
function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, sret::LLVM.Value, root_ty::LLVM.LLVMType, rootRet::Union{LLVM.Value, Nothing}, direction::SRetRootMovement; must_cache::Bool = false)
38093809
count = 0
38103810
todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
38113811
Cuint[],
@@ -3845,12 +3845,18 @@ function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType,
38453845
if direction == SRetPointerToRootPointer
38463846
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
38473847
outloc = load!(builder, ty, outloc)
3848+
if must_cache
3849+
API.SetMustCache!(outloc)
3850+
end
38483851
store!(builder, outloc, loc)
38493852
elseif direction == SRetValueToRootPointer
38503853
outloc = Enzyme.API.e_extract_value!(builder, sret, path)
38513854
store!(builder, outloc, loc)
38523855
elseif direction == RootPointerToSRetValue
38533856
loc = load!(builder, ty, loc)
3857+
if must_cache
3858+
API.SetMustCache!(loc)
3859+
end
38543860
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
38553861
elseif direction == NullifySRetValue
38563862
loc = unsafe_to_llvm(builder, nothing)
@@ -3914,13 +3920,13 @@ function nullify_rooted_values!(builder::LLVM.IRBuilder, sret::LLVM.Value)
39143920
move_sret_tofrom_roots!(builder, jltype, sret, root_ty, nothing, NullifySRetValue)
39153921
end
39163922

3917-
function recombine_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value)::LLVM.Value
3923+
function recombine_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value; must_cache::Bool=false)::LLVM.Value
39183924
jltype = value_type(sret)
39193925
tracked = CountTrackedPointers(jltype)
39203926
@assert tracked.count > 0
39213927
@assert !tracked.all "Not tracked.all, jltype ($(string(jltype)))"
39223928
root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
3923-
move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, RootPointerToSRetValue)
3929+
move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, RootPointerToSRetValue; must_cache)
39243930
end
39253931

39263932
function extract_roots_from_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value)
@@ -3957,8 +3963,8 @@ function copy_floats_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::
39573963
end
39583964

39593965
if isa(ty, LLVM.FloatingPointType)
3960-
dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloc")
3961-
srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloc")
3966+
dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloc")
3967+
srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloc")
39623968
val = load!(builder, ty, srcloc)
39633969
st = store!(builder, val, dstloc)
39643970
continue
@@ -3995,6 +4001,67 @@ function copy_floats_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::
39954001
return nothing
39964002
end
39974003

4004+
function extract_nonjlvalues_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::LLVM.Value, src::LLVM.Value)
4005+
count = 0
4006+
todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
4007+
Cuint[],
4008+
jltype,
4009+
)]
4010+
function to_llvm(lst::Vector{Cuint})
4011+
vals = LLVM.Value[]
4012+
push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
4013+
for i in lst
4014+
push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
4015+
end
4016+
return vals
4017+
end
4018+
4019+
extracted = LLVM.Value[]
4020+
4021+
while length(todo) != 0
4022+
path, ty = popfirst!(todo)
4023+
4024+
if isa(ty, LLVM.PointerType)
4025+
if any_jltypes(ty)
4026+
continue
4027+
end
4028+
end
4029+
4030+
if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
4031+
for i = 1:length(ty)
4032+
npath = copy(path)
4033+
push!(npath, i - 1)
4034+
push!(todo, (npath, eltype(ty)))
4035+
end
4036+
continue
4037+
end
4038+
4039+
if isa(ty, LLVM.VectorType) && any_jltypes(ty)
4040+
for i = 1:size(ty)
4041+
npath = copy(path)
4042+
push!(npath, i - 1)
4043+
push!(todo, (npath, eltype(ty)))
4044+
end
4045+
continue
4046+
end
4047+
4048+
if isa(ty, LLVM.StructType) && any_jltypes(ty)
4049+
for (i, t) in enumerate(LLVM.elements(ty))
4050+
npath = copy(path)
4051+
push!(npath, i - 1)
4052+
push!(todo, (npath, t))
4053+
end
4054+
continue
4055+
end
4056+
4057+
dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloc")
4058+
val = Enzyme.API.e_extract_value!(builder, src, path)
4059+
st = store!(builder, val, dstloc)
4060+
end
4061+
4062+
return nothing
4063+
end
4064+
39984065

39994066
# Modified from GPUCompiler/src/irgen.jl:365 lower_byval
40004067
function lower_convention(

src/rules/customrules.jl

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,9 @@ end
971971
end
972972
LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1), attr)
973973
res = load!(B, sty, sret)
974+
if returnRoots !== nothing && VERSION >= v"1.12"
975+
res = recombine_value!(B, res, returnRoots; must_cache=true)
976+
end
974977
end
975978
if swiftself
976979
attr = EnumAttribute("swiftself")
@@ -1000,9 +1003,19 @@ end
10001003
if RT <: Const
10011004
if needsPrimal
10021005
@assert RealRt == fwd_RT
1003-
if get_return_info(RealRt)[2] !== nothing
1006+
_, prim_sret, prim_roots = get_return_info(RealRt)
1007+
if prim_sret !== nothing
10041008
val = new_from_original(gutils, operands(orig)[1])
1005-
store!(B, res, val)
1009+
1010+
if prim_roots !== nothing && VERSION >= v"1.12"
1011+
extract_nonjlvalues_into!(B, value_type(res), val, res)
1012+
1013+
rval = new_from_original(gutils, operands(orig)[2])
1014+
1015+
extract_roots_from_value!(B, res, rval)
1016+
else
1017+
store!(B, res, val)
1018+
end
10061019
else
10071020
normalV = res.ref
10081021
end
@@ -1016,12 +1029,28 @@ end
10161029
ST = NTuple{Int(width),ST}
10171030
end
10181031
@assert ST == fwd_RT
1019-
if get_return_info(RealRt)[2] !== nothing
1032+
_, prim_sret, prim_roots = get_return_info(RealRt)
1033+
if prim_sret !== nothing
10201034
dval_ptr = invert_pointer(gutils, operands(orig)[1], B)
1021-
for idx = 1:width
1035+
1036+
droots = if prim_roots !== nothing && VERSION >= v"1.12"
1037+
@assert !is_constant_value(gutils, operands(orig)[2])
1038+
invert_pointer(gutils, operands(orig)[2], B)
1039+
end
1040+
1041+
for idx = 1:width
10221042
ev = (width == 1) ? dval : extract_value!(B, dval, idx - 1)
10231043
pev = (width == 1) ? dval_ptr : extract_value!(B, dval_ptr, idx - 1)
1024-
store!(B, res, pev)
1044+
1045+
if prim_roots !== nothing && VERSION >= v"1.12"
1046+
extract_nonjlvalues_into!(B, value_type(ev), pev, ev)
1047+
1048+
rval = (width == 1) ? droots : extract_value!(B, droots, idx - 1)
1049+
1050+
extract_roots_from_value!(B, ev, rval)
1051+
else
1052+
store!(B, ev, pev)
1053+
end
10251054
end
10261055
else
10271056
shadowV = res.ref
@@ -1033,16 +1062,42 @@ end
10331062
BatchDuplicated{RealRt,Int(width)}
10341063
end
10351064
@assert ST == fwd_RT
1036-
if get_return_info(RealRt)[2] !== nothing
1065+
1066+
_, prim_sret, prim_roots = get_return_info(RealRt)
1067+
if prim_sret !== nothing
10371068
val = new_from_original(gutils, operands(orig)[1])
1038-
store!(B, extract_value!(B, res, 0), val)
1069+
1070+
res0 = extract_value!(B, res, 0)
1071+
if prim_roots !== nothing && VERSION >= v"1.12"
1072+
extract_nonjlvalues_into!(B, value_type(res0), val, res0)
1073+
1074+
rval = new_from_original(gutils, operands(orig)[2])
1075+
1076+
extract_roots_from_value!(B, res0, rval)
1077+
else
1078+
store!(B, res0, val)
1079+
end
10391080

10401081
dval_ptr = invert_pointer(gutils, operands(orig)[1], B)
10411082
dval = extract_value!(B, res, 1)
1042-
for idx = 1:width
1083+
1084+
droots = if prim_roots !== nothing && VERSION >= v"1.12"
1085+
@assert !is_constant_value(gutils, operands(orig)[2])
1086+
invert_pointer(gutils, operands(orig)[2], B)
1087+
end
1088+
1089+
for idx = 1:width
10431090
ev = (width == 1) ? dval : extract_value!(B, dval, idx - 1)
10441091
pev = (width == 1) ? dval_ptr : extract_value!(B, dval_ptr, idx - 1)
1045-
store!(B, ev, pev)
1092+
if prim_roots !== nothing && VERSION >= v"1.12"
1093+
extract_nonjlvalues_into!(B, value_type(ev), pev, ev)
1094+
1095+
rval = (width == 1) ? droots : extract_value!(B, droots, idx - 1)
1096+
1097+
extract_roots_from_value!(B, ev, rval)
1098+
else
1099+
store!(B, ev, pev)
1100+
end
10461101
end
10471102
else
10481103
normalV = extract_value!(B, res, 0).ref
@@ -1781,6 +1836,9 @@ function enzyme_custom_common_rev(
17811836
)
17821837
res = load!(B, sty, sret)
17831838
API.SetMustCache!(res)
1839+
if returnRoots !== nothing && VERSION >= v"1.12"
1840+
res = recombine_value!(B, res, returnRoots; must_cache=true)
1841+
end
17841842
end
17851843
if swiftself
17861844
attr = EnumAttribute("swiftself")
@@ -1888,9 +1946,19 @@ function enzyme_custom_common_rev(
18881946
if needsPrimal
18891947
@assert !isghostty(RealRt)
18901948
normalV = extract_value!(B, resV, idx)
1891-
if get_return_info(RealRt)[2] !== nothing
1949+
_, prim_sret, prim_roots = get_return_info(RealRt)
1950+
if prim_sret !== nothing
18921951
val = new_from_original(gutils, operands(orig)[1])
1893-
store!(B, normalV, val)
1952+
1953+
if prim_roots !== nothing && VERSION >= v"1.12"
1954+
extract_nonjlvalues_into!(B, value_type(normalV), val, normalV)
1955+
1956+
rval = new_from_original(gutils, operands(orig)[2])
1957+
1958+
extract_roots_from_value!(B, normalV, rval)
1959+
else
1960+
store!(B, normalV, val)
1961+
end
18941962
else
18951963
@assert value_type(normalV) == value_type(orig)
18961964
normalV = normalV.ref
@@ -1901,16 +1969,30 @@ function enzyme_custom_common_rev(
19011969
if needsShadowJL
19021970
@assert !isghostty(RealRt)
19031971
shadowV = extract_value!(B, resV, idx)
1904-
if get_return_info(RealRt)[2] !== nothing
1972+
_, prim_sret, prim_roots = get_return_info(RealRt)
1973+
if prim_sret !== nothing
19051974
dval = invert_pointer(gutils, operands(orig)[1], B)
19061975

1907-
for idx = 1:width
1976+
droots = if prim_roots !== nothing && VERSION >= v"1.12"
1977+
@assert !is_constant_value(gutils, operands(orig)[2])
1978+
invert_pointer(gutils, operands(orig)[2], B)
1979+
end
1980+
1981+
for idx = 1:width
19081982
to_store =
19091983
(width == 1) ? shadowV : extract_value!(B, shadowV, idx - 1)
19101984

19111985
store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx - 1)
19121986

1913-
store!(B, to_store, store_ptr)
1987+
if prim_roots !== nothing && VERSION >= v"1.12"
1988+
extract_nonjlvalues_into!(B, value_type(to_store), store_ptr, to_store)
1989+
1990+
rval = (width == 1) ? droots : extract_value!(B, droots, idx - 1)
1991+
1992+
extract_roots_from_value!(B, to_store, rval)
1993+
else
1994+
store!(B, to_store, store_ptr)
1995+
end
19141996
end
19151997
shadowV = C_NULL
19161998
else

0 commit comments

Comments
 (0)