@@ -4232,6 +4232,22 @@ end
42324232
42334233 wrapper_f = LLVM. Function (mod, safe_name (LLVM. name (llvmfn) * " mustwrap" ), FT)
42344234
4235+ for idx in 1 : length (collect (parameters (llvmfn)))
4236+ for attr in collect (parameter_attributes (llvmfn, idx))
4237+ push! (parameter_attributes (wrapper_f, idx), attr)
4238+ end
4239+ end
4240+
4241+ for attr in collect (function_attributes (llvmfn))
4242+ push! (function_attributes (wrapper_f), attr)
4243+ end
4244+
4245+ for attr in collect (return_attributes (llvmfn))
4246+ push! (return_attributes (wrapper_f), attr)
4247+ end
4248+
4249+ mi, rt = enzyme_custom_extract_mi (primalf)
4250+
42354251 let builder = IRBuilder ()
42364252 entry = BasicBlock (wrapper_f, " entry" )
42374253 position! (builder, entry)
@@ -4248,7 +4264,7 @@ end
42484264 else
42494265 EnumAttribute (" sret" )
42504266 end )
4251- for idx in length (collect (parameters (llvmfn)))
4267+ for idx in 1 : length (collect (parameters (llvmfn)))
42524268 for attr in collect (parameter_attributes (llvmfn, idx))
42534269 if kind (attr) == sretkind
42544270 LLVM. API. LLVMAddCallSiteAttribute (
@@ -4260,6 +4276,14 @@ end
42604276 end
42614277 end
42624278
4279+ _, _, returnRoots = get_return_info (rt)
4280+ returnRoots = returnRoots != = nothing
4281+ if returnRoots
4282+ attr = StringAttribute (" enzymejl_returnRoots" , " " )
4283+ push! (parameter_attributes (wrapper_f, 2 ), attr)
4284+ LLVM. API. LLVMAddCallSiteAttribute (res, LLVM. API. LLVMAttributeIndex (2 ), attr)
4285+ end
4286+
42634287 if LLVM. return_type (FT) == LLVM. VoidType ()
42644288 ret! (builder)
42654289 else
@@ -4270,7 +4294,6 @@ end
42704294 end
42714295 attributes = function_attributes (wrapper_f)
42724296 push! (attributes, StringAttribute (" enzymejl_world" , string (job. world)))
4273- mi, rt = enzyme_custom_extract_mi (primalf)
42744297 push! (
42754298 attributes,
42764299 StringAttribute (" enzymejl_mi" , string (convert (UInt, pointer_from_objref (mi)))),
0 commit comments