Skip to content

Commit 4ae6187

Browse files
authored
Fix attribute copy (#2407)
* Fix attribute copy * mark roots * fix returnroots
1 parent 6584816 commit 4ae6187

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

src/compiler.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4232,6 +4232,22 @@ end
42324232

42334233
wrapper_f = LLVM.Function(mod, safe_name(LLVM.name(llvmfn) * "mustwrap"), FT)
42344234

4235+
for idx in 1:length(collect(parameters(llvmfn)))
4236+
for attr in collect(parameter_attributes(llvmfn, idx))
4237+
push!(parameter_attributes(wrapper_f, idx), attr)
4238+
end
4239+
end
4240+
4241+
for attr in collect(function_attributes(llvmfn))
4242+
push!(function_attributes(wrapper_f), attr)
4243+
end
4244+
4245+
for attr in collect(return_attributes(llvmfn))
4246+
push!(return_attributes(wrapper_f), attr)
4247+
end
4248+
4249+
mi, rt = enzyme_custom_extract_mi(primalf)
4250+
42354251
let builder = IRBuilder()
42364252
entry = BasicBlock(wrapper_f, "entry")
42374253
position!(builder, entry)
@@ -4248,7 +4264,7 @@ end
42484264
else
42494265
EnumAttribute("sret")
42504266
end)
4251-
for idx in length(collect(parameters(llvmfn)))
4267+
for idx in 1:length(collect(parameters(llvmfn)))
42524268
for attr in collect(parameter_attributes(llvmfn, idx))
42534269
if kind(attr) == sretkind
42544270
LLVM.API.LLVMAddCallSiteAttribute(
@@ -4260,6 +4276,14 @@ end
42604276
end
42614277
end
42624278

4279+
_, _, returnRoots = get_return_info(rt)
4280+
returnRoots = returnRoots !== nothing
4281+
if returnRoots
4282+
attr = StringAttribute("enzymejl_returnRoots", "")
4283+
push!(parameter_attributes(wrapper_f, 2), attr)
4284+
LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(2), attr)
4285+
end
4286+
42634287
if LLVM.return_type(FT) == LLVM.VoidType()
42644288
ret!(builder)
42654289
else
@@ -4270,7 +4294,6 @@ end
42704294
end
42714295
attributes = function_attributes(wrapper_f)
42724296
push!(attributes, StringAttribute("enzymejl_world", string(job.world)))
4273-
mi, rt = enzyme_custom_extract_mi(primalf)
42744297
push!(
42754298
attributes,
42764299
StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi)))),

src/compiler/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,18 @@ Base.@assume_effects :removable :foldable :nothrow function has_fn_attr(fn::LLVM
400400
return false
401401
end
402402

403+
Base.@assume_effects :removable :foldable :nothrow function has_arg_attr(fn::LLVM.Function, i::Int, attr::LLVM.StringAttribute)::Bool
404+
ekind = LLVM.kind(attr)
405+
for attr in collect(parameter_attributes(fn, i))
406+
if attr isa LLVM.StringAttribute
407+
if kind(attr) == ekind
408+
return true
409+
end
410+
end
411+
end
412+
return false
413+
end
414+
403415
function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction))
404416
@static if isdefined(LLVM, Symbol("erase!"))
405417
LLVM.erase!(inst)

src/rules/customrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1528,7 +1528,7 @@ end
15281528
fop = called_operand(orig)::LLVM.Function
15291529
for (i, v) in enumerate(operands(orig)[1:end-1])
15301530
if v == val
1531-
if !has_fn_attr(fop, StringAttribute("enzymejl_returnRoots"))
1531+
if !has_arg_attr(fop, i, StringAttribute("enzymejl_returnRoots"))
15321532
non_rooting_use = true
15331533
break
15341534
end

0 commit comments

Comments
 (0)