Skip to content

Commit 8b65381

Browse files
authored
Save julia types on sret (#2127)
* Save julia types on sret * fix * lig
1 parent e9d303b commit 8b65381

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

src/absint.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,16 @@ function abs_typeof(
305305
end
306306
end
307307

308+
if isa(arg, LLVM.AllocaInst) || isa(arg, LLVM.CallInst)
309+
if haskey(metadata(arg), "enzymejl_allocart")
310+
mds = operands(metadata(arg)["enzymejl_allocart"])[1]::MDString
311+
mds = Base.convert(String, mds)
312+
ptr = reinterpret(Ptr{Cvoid}, parse(UInt, mds))
313+
RT = Base.unsafe_pointer_to_objref(ptr)
314+
return (true, RT, GPUCompiler.MUT_REF)
315+
end
316+
end
317+
308318
if isa(arg, LLVM.CallInst)
309319
fn = LLVM.called_operand(arg)
310320
nm = ""

src/compiler.jl

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,6 +1727,7 @@ end
17271727
else
17281728
"Unknown object of type" * " " * string(TT)
17291729
end
1730+
@assert !illegal
17301731
illegalVal = cur
17311732
illegal = true
17321733
return make_batched(ncur, prevbb)
@@ -1770,6 +1771,7 @@ end
17701771
end
17711772

17721773
cur2 = if changed
1774+
@assert !illegal
17731775
illegalVal = cur
17741776
illegal = true
17751777
# TODO replace with correct insertions/splats
@@ -1942,8 +1944,10 @@ end
19421944
return make_batched(ncur, prevbb)
19431945
end
19441946

1945-
illegal = true
1946-
illegalVal = cur
1947+
if !illegal
1948+
illegal = true
1949+
illegalVal = cur
1950+
end
19471951
return ncur
19481952
end
19491953

@@ -7070,10 +7074,48 @@ end
70707074
ctx = LLVM.context(mod)
70717075
for f in functions(mod), bb in blocks(f), inst in instructions(bb)
70727076
fn = isa(inst, LLVM.CallInst) ? LLVM.called_operand(inst) : nothing
7077+
7078+
if !API.HasFromStack(inst) && isa(inst, LLVM.AllocaInst)
7079+
7080+
calluse = nothing
7081+
for u in LLVM.uses(inst)
7082+
u = LLVM.user(u)
7083+
if isa(u, LLVM.CallInst) && operands(u)[1] == inst
7084+
7085+
sretkind = kind(if LLVM.version().major >= 12
7086+
TypeAttribute("sret", LLVM.Int32Type())
7087+
else
7088+
EnumAttribute("sret")
7089+
end)
7090+
hassret = false
7091+
llvmfn = LLVM.called_operand(u)
7092+
if llvmfn isa LLVM.Function
7093+
for attr in collect(parameter_attributes(llvmfn, 1))
7094+
if kind(attr) == sretkind
7095+
hassret = true
7096+
break
7097+
end
7098+
end
7099+
end
7100+
if hassret
7101+
calluse = u
7102+
end
7103+
end
7104+
end
7105+
if calluse isa LLVM.CallInst
7106+
_, RT = enzyme_custom_extract_mi(calluse, false)
7107+
if RT !== nothing
7108+
llrt, sret, returnRoots = get_return_info(RT)
7109+
if !(sret isa Nothing) && !is_sret_union(RT)
7110+
metadata(inst)["enzymejl_allocart"] = MDNode(LLVM.Metadata[MDString(string(convert(UInt, unsafe_to_pointer(RT))))])
7111+
end
7112+
end
7113+
end
7114+
end
70737115

70747116
if !API.HasFromStack(inst) &&
70757117
((isa(inst, LLVM.CallInst) &&
7076-
(!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst))
7118+
(!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst) || isa(inst, LLVM.AllocaInst))
70777119
legal, source_typ, byref = abs_typeof(inst)
70787120
codegen_typ = value_type(inst)
70797121
if legal

0 commit comments

Comments
 (0)