@@ -1155,6 +1155,12 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
11551155 parameter_attributes (f, arg. codegen. i),
11561156 StringAttribute (" enzymejl_parmtype_ref" , string (UInt (arg. cc))),
11571157 )
1158+ if arg. rooted_typ != = nothing
1159+ push! (
1160+ parameter_attributes (f, arg. codegen. i),
1161+ StringAttribute (" enzymejl_rooted_typ" , string (convert (UInt, unsafe_to_pointer (arg. rooted_typ))))
1162+ )
1163+ end
11581164
11591165 byref = arg. cc
11601166
@@ -1559,6 +1565,23 @@ function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType),
15591565 nothing
15601566 end
15611567 else
1568+ if Ty == Core. SimpleVector
1569+ @assert count === nothing
1570+ @assert isa (prev, LLVM. CallInst)
1571+ @assert LLVM. name (LLVM. called_operand (prev):: LLVM.Function ) == " julia.gc_alloc_obj"
1572+ sz = operands (prev)[2 ]
1573+ sz = sub! (B, sz, LLVM. ConstantInt (Int (sizeof (Ptr{Cvoid}))))
1574+ T_jlvalue = LLVM. StructType (LLVM. LLVMType[])
1575+ T_prjlvalue = LLVM. PointerType (T_jlvalue, Tracked)
1576+ prev = addrspacecast! (B, prev, LLVM. PointerType (T_jlvalue, Derived))
1577+ prev = bitcast! (B, prev, LLVM. PointerType (T_prjlvalue, Derived))
1578+ gep = LLVM. gep! (B, T_prjlvalue, prev, LLVM. Value[LLVM. ConstantInt (Int64 (1 ))])
1579+ zeroAll = false
1580+ atomic = true
1581+ zero_allocation (B, Any, T_prjlvalue, prev, LLVM. ConstantInt (sizeof (Ptr{Cvoid})), sz, zeroAll, atomic)
1582+ return
1583+ end
1584+
15621585 if fieldcount (Ty) == 0
15631586 error (" Error handling recursive stores for $Ty which has a fieldcount of 0" )
15641587 end
@@ -3895,7 +3918,7 @@ function recombine_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM
38953918 jltype = value_type (sret)
38963919 tracked = CountTrackedPointers (jltype)
38973920 @assert tracked. count > 0
3898- @assert ! tracked. all
3921+ @assert ! tracked. all " Not tracked.all, jltype ( $( string (jltype)) ) "
38993922 root_ty = convert (LLVMType, AnyArray (Int (tracked. count)))
39003923 move_sret_tofrom_roots! (builder, jltype, sret, root_ty, roots, RootPointerToSRetValue)
39013924end
@@ -3904,11 +3927,74 @@ function extract_roots_from_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, ro
39043927 jltype = value_type (sret)
39053928 tracked = CountTrackedPointers (jltype)
39063929 @assert tracked. count > 0
3907- @assert ! tracked. all
3930+ @assert ! tracked. all " Not tracked.all, jltype ( $( string (jltype)) ) "
39083931 root_ty = convert (LLVMType, AnyArray (Int (tracked. count)))
39093932 move_sret_tofrom_roots! (builder, jltype, sret, root_ty, roots, SRetValueToRootPointer)
39103933end
39113934
3935+ function copy_floats_into! (builder:: LLVM.IRBuilder , jltype:: LLVM.LLVMType , dst:: LLVM.Value , src:: LLVM.Value )
3936+ count = 0
3937+ todo = Tuple{Vector{Cuint},LLVM. LLVMType}[(
3938+ Cuint[],
3939+ jltype,
3940+ )]
3941+ function to_llvm (lst:: Vector{Cuint} )
3942+ vals = LLVM. Value[]
3943+ push! (vals, LLVM. ConstantInt (LLVM. IntType (64 ), 0 ))
3944+ for i in lst
3945+ push! (vals, LLVM. ConstantInt (LLVM. IntType (32 ), i))
3946+ end
3947+ return vals
3948+ end
3949+
3950+ extracted = LLVM. Value[]
3951+
3952+ while length (todo) != 0
3953+ path, ty = popfirst! (todo)
3954+
3955+ if isa (ty, LLVM. PointerType) || isa (ty, LLVM. IntegerType)
3956+ continue
3957+ end
3958+
3959+ if isa (ty, LLVM. FloatingPointType)
3960+ dstloc = inbounds_gep! (builder, jltype, dst, to_llvm (path), " dstloc" )
3961+ srcloc = inbounds_gep! (builder, jltype, src, to_llvm (path), " srcloc" )
3962+ val = load! (builder, ty, srcloc)
3963+ st = store! (builder, val, dstloc)
3964+ continue
3965+ end
3966+
3967+ if isa (ty, LLVM. ArrayType)
3968+ for i = 1 : length (ty)
3969+ npath = copy (path)
3970+ push! (npath, i - 1 )
3971+ push! (todo, (npath, eltype (ty)))
3972+ end
3973+ continue
3974+ end
3975+
3976+ if isa (ty, LLVM. VectorType)
3977+ for i = 1 : size (ty)
3978+ npath = copy (path)
3979+ push! (npath, i - 1 )
3980+ push! (todo, (npath, eltype (ty)))
3981+ end
3982+ continue
3983+ end
3984+
3985+ if isa (ty, LLVM. StructType)
3986+ for (i, t) in enumerate (LLVM. elements (ty))
3987+ npath = copy (path)
3988+ push! (npath, i - 1 )
3989+ push! (todo, (npath, t))
3990+ end
3991+ continue
3992+ end
3993+ end
3994+
3995+ return nothing
3996+ end
3997+
39123998
39133999# Modified from GPUCompiler/src/irgen.jl:365 lower_byval
39144000function lower_convention (
@@ -4314,6 +4400,15 @@ function lower_convention(
43144400 string (UInt (GPUCompiler. BITS_VALUE)),
43154401 ),
43164402 )
4403+ if arg. rooted_typ != = nothing
4404+ push! (
4405+ parameter_attributes (wrapper_f, wrapper_idx - 1 ),
4406+ StringAttribute (
4407+ " enzymejl_rooted_typ" ,
4408+ string (convert (UInt, unsafe_to_pointer (arg. rooted_typ))),
4409+ ),
4410+ )
4411+ end
43174412 elseif arg. arg_i in raisedArgs
43184413 wrapparm = load! (builder, convert (LLVMType, arg. typ), wrapparm)
43194414 ctx = LLVM. context (wrapparm)
@@ -4342,6 +4437,15 @@ function lower_convention(
43424437 string (UInt (GPUCompiler. BITS_REF)),
43434438 ),
43444439 )
4440+ if arg. rooted_typ != = nothing
4441+ push! (
4442+ parameter_attributes (wrapper_f, wrapper_idx - 1 ),
4443+ StringAttribute (
4444+ " enzymejl_rooted_typ" ,
4445+ string (convert (UInt, unsafe_to_pointer (arg. rooted_typ)))
4446+ ),
4447+ )
4448+ end
43454449 else
43464450 push! (wrapper_args, wrapparm)
43474451 for attr in collect (parameter_attributes (entry_f, arg. codegen. i))
0 commit comments