@@ -2450,6 +2450,50 @@ else
24502450 end
24512451end
24522452
2453+ function store_nonjl_types! (B, startval, p)
2454+ T_jlvalue = LLVM. StructType (LLVMType[])
2455+ T_prjlvalue = LLVM. PointerType (T_jlvalue, Tracked)
2456+ vals = LLVM. Value[]
2457+ if p != nothing
2458+ push! (vals, p)
2459+ end
2460+ todo = Tuple{Tuple, LLVM. Value}[((), startval)]
2461+ while length (todo) != 0
2462+ path, cur = popfirst! (todo)
2463+ ty = value_type (cur)
2464+ if isa (ty, LLVM. PointerType)
2465+ if any_jltypes (ty)
2466+ continue
2467+ end
2468+ end
2469+ if isa (ty, LLVM. ArrayType)
2470+ if any_jltypes (ty)
2471+ for i= 1 : length (ty)
2472+ ev = extract_value! (B, cur, i- 1 )
2473+ push! (todo, ((path... , i- 1 ), ev))
2474+ end
2475+ continue
2476+ end
2477+ end
2478+ if isa (ty, LLVM. StructType)
2479+ if any_jltypes (ty)
2480+ for (i, t) in enumerate (LLVM. elements (ty))
2481+ ev = extract_value! (B, cur, i- 1 )
2482+ push! (todo, ((path... , i- 1 ), ev))
2483+ end
2484+ continue
2485+ end
2486+ end
2487+ parray = LLVM. Value[LLVM. ConstantInt (LLVM. IntType (64 ), 0 )]
2488+ for v in path
2489+ push! (parray, LLVM. ConstantInt (LLVM. IntType (32 ), v))
2490+ end
2491+ gptr = gep! (B, value_type (startval), p, parray)
2492+ st = store! (B, cur, gptr)
2493+ end
2494+ return
2495+ end
2496+
24532497function get_julia_inner_types (B, p, startvals... ; added= [])
24542498 T_jlvalue = LLVM. StructType (LLVMType[])
24552499 T_prjlvalue = LLVM. PointerType (T_jlvalue, Tracked)
@@ -3404,7 +3448,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
34043448 else
34053449 push! (args_activity, API. DFT_OUT_DIFF)
34063450 end
3407- elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc
3451+ elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc || T <: MixedDuplicated || T <: BatchMixedDuplicated
34083452 push! (args_activity, API. DFT_DUP_ARG)
34093453 elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed
34103454 push! (args_activity, API. DFT_DUP_NONEED)
@@ -3588,7 +3632,6 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
35883632
35893633 isboxed = GPUCompiler. deserves_argbox (source_typ)
35903634 llvmT = isboxed ? T_prjlvalue : convert (LLVMType, source_typ)
3591-
35923635 push! (T_wrapperargs, llvmT)
35933636
35943637 if T <: Const || T <: BatchDuplicatedFunc
@@ -3617,6 +3660,11 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
36173660 if is_adjoint && i != 1
36183661 push! (ActiveRetTypes, Nothing)
36193662 end
3663+ elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
3664+ push! (T_wrapperargs, LLVM. LLVMType (API. EnzymeGetShadowType (width, T_prjlvalue)))
3665+ if is_adjoint && i != 1
3666+ push! (ActiveRetTypes, Nothing)
3667+ end
36203668 else
36213669 error (" calling convention should be annotated, got $T " )
36223670 end
@@ -3799,7 +3847,23 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
37993847 if isghostty (T′) || Core. Compiler. isconstType (T′)
38003848 continue
38013849 end
3802- push! (realparms, params[i])
3850+
3851+ isboxed = GPUCompiler. deserves_argbox (T′)
3852+
3853+ llty = value_type (params[i])
3854+
3855+ convty = convert (LLVMType, T′; allow_boxed= true )
3856+
3857+ if (T <: MixedDuplicated || T <: BatchMixedDuplicated ) && ! isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType))
3858+ al = emit_allocobj! (builder, Base. RefValue{T′})
3859+ al = bitcast! (builder, al, LLVM. PointerType (llty, addrspace (value_type (al))))
3860+ store! (builder, params[i], al)
3861+ al = addrspacecast! (builder, al, LLVM. PointerType (llty, Derived))
3862+ push! (realparms, al)
3863+ else
3864+ push! (realparms, params[i])
3865+ end
3866+
38033867 i += 1
38043868 if T <: Const
38053869 elseif T <: Active
@@ -3827,6 +3891,34 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
38273891 elseif T <: Duplicated || T <: DuplicatedNoNeed
38283892 push! (realparms, params[i])
38293893 i += 1
3894+ elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
3895+ parmsi = params[i]
3896+
3897+ if T <: BatchMixedDuplicated
3898+ if GPUCompiler. deserves_argbox (NTuple{width, Base. RefValue{T′}})
3899+ njlvalue = LLVM. ArrayType (Int (width), T_prjlvalue)
3900+ parmsi = bitcast! (builder, parmsi, LLVM. PointerType (njlvalue, addrspace (value_type (parmsi))))
3901+ parmsi = load! (builder, njlvalue, parmsi)
3902+ end
3903+ end
3904+
3905+ isboxed = GPUCompiler. deserves_argbox (T′)
3906+
3907+ resty = isboxed ? llty : LLVM. PointerType (llty, Derived)
3908+
3909+ ival = UndefValue (LLVM. LLVMType (API. EnzymeGetShadowType (width, resty)))
3910+ for idx in 1 : width
3911+ pv = (width == 1 ) ? parmsi : extract_value! (builder, parmsi, idx- 1 )
3912+ pv = bitcast! (builder, pv, LLVM. PointerType (llty, addrspace (value_type (pv))))
3913+ pv = addrspacecast! (builder, pv, LLVM. PointerType (llty, Derived))
3914+ if isboxed
3915+ pv = load! (builder, llty, pv, " mixedboxload" )
3916+ end
3917+ ival = (width == 1 ) ? pv : insert_value! (builder, ival, pv, idx- 1 )
3918+ end
3919+
3920+ push! (realparms, ival)
3921+ i += 1
38303922 elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed
38313923 isboxed = GPUCompiler. deserves_argbox (NTuple{width, T′})
38323924 val = params[i]
@@ -4357,6 +4449,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
43574449
43584450 # generate the wrapper function type & definition
43594451 wrapper_types = LLVM. LLVMType[]
4452+ wrapper_attrs = Vector{LLVM. Attribute}[]
43604453 _, sret, returnRoots = get_return_info (actualRetType)
43614454 sret_union = is_sret_union (actualRetType)
43624455
@@ -4391,31 +4484,44 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
43914484
43924485 if swiftself
43934486 push! (wrapper_types, value_type (parameters (entry_f)[1 + sret+ returnRoots]))
4487+ push! (wrapper_attrs, LLVM. Attribute[EnumAttribute (" swiftself" )])
43944488 end
43954489
43964490 boxedArgs = Set {Int} ()
43974491 loweredArgs = Set {Int} ()
4492+ raisedArgs = Set {Int} ()
43984493
43994494 for arg in args
44004495 typ = arg. codegen. typ
44014496 if GPUCompiler. deserves_argbox (arg. typ)
44024497 push! (boxedArgs, arg. arg_i)
44034498 push! (wrapper_types, typ)
4499+ push! (wrapper_attrs, LLVM. Attribute[])
44044500 elseif arg. cc != GPUCompiler. BITS_REF
4405- push! (wrapper_types, typ)
4501+ if TT != nothing && (TT. parameters[arg. arg_i] <: MixedDuplicated || TT. parameters[arg. arg_i] <: BatchMixedDuplicated )
4502+ push! (boxedArgs, arg. arg_i)
4503+ push! (raisedArgs, arg. arg_i)
4504+ push! (wrapper_types, LLVM. PointerType (typ, Derived))
4505+ push! (wrapper_attrs, LLVM. Attribute[EnumAttribute (" noalias" )])
4506+ else
4507+ push! (wrapper_types, typ)
4508+ push! (wrapper_attrs, LLVM. Attribute[])
4509+ end
44064510 else
44074511 # bits ref, and not boxed
4408- # if TT.parameters[arg.arg_i] <: Const
4409- # push!(boxedArgs, arg.arg_i)
4410- # push!(wrapper_types, typ)
4411- # else
4512+ if TT != nothing && (TT. parameters[arg. arg_i] <: MixedDuplicated || TT. parameters[arg. arg_i] <: BatchMixedDuplicated )
4513+ push! (boxedArgs, arg. arg_i)
4514+ push! (wrapper_types, typ)
4515+ push! (wrapper_attrs, LLVM. Attribute[EnumAttribute (" noalias" )])
4516+ else
44124517 push! (wrapper_types, eltype (typ))
4518+ push! (wrapper_attrs, LLVM. Attribute[])
44134519 push! (loweredArgs, arg. arg_i)
4414- # end
4520+ end
44154521 end
44164522 end
44174523
4418- if length (loweredArgs) == 0 && ! sret && ! sret_union
4524+ if length (loweredArgs) == 0 && length (raisedArgs) == 0 && ! sret && ! sret_union
44194525 return entry_f, returnRoots, boxedArgs, loweredArgs
44204526 end
44214527
@@ -4436,8 +4542,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
44364542 end
44374543 push! (function_attributes (wrapper_f), EnumAttribute (" returns_twice" ))
44384544 push! (function_attributes (entry_f), EnumAttribute (" returns_twice" ))
4439- if swiftself
4440- push! (parameter_attributes (wrapper_f, 1 ), EnumAttribute (" swiftself" ))
4545+ for (i, v) in enumerate (wrapper_attrs)
4546+ for attr in v
4547+ push! (parameter_attributes (wrapper_f, i), attr)
4548+ end
44414549 end
44424550
44434551 seen = TypeTreeTable ()
@@ -4463,6 +4571,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
44634571 parm = ops[arg. codegen. i]
44644572 if arg. arg_i in loweredArgs
44654573 push! (nops, load! (builder, convert (LLVMType, arg. typ), parm))
4574+ elseif arg. arg_i in raisedArgs
4575+ obj = emit_allocobj! (builder, arg. typ)
4576+ bc = bitcast! (builder, obj, LLVM. PointerType (value_type (parm), addrspace (value_type (obj))))
4577+ store! (builder, parm, bc)
4578+ addr = addrspacecast! (builder, bc, LLVM. PointerType (value_type (parm), Derived))
4579+ push! (nops, addr)
44664580 else
44674581 push! (nops, parm)
44684582 end
@@ -4547,6 +4661,13 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
45474661 push! (parameter_attributes (wrapper_f, arg. codegen. i- sret- returnRoots), StringAttribute (" enzyme_type" , string (typetree (arg. typ, ctx, dl, seen))))
45484662 push! (parameter_attributes (wrapper_f, arg. codegen. i- sret- returnRoots), StringAttribute (" enzymejl_parmtype" , string (convert (UInt, unsafe_to_pointer (arg. typ)))))
45494663 push! (parameter_attributes (wrapper_f, arg. codegen. i- sret- returnRoots), StringAttribute (" enzymejl_parmtype_ref" , string (UInt (GPUCompiler. BITS_REF))))
4664+ elseif arg. arg_i in raisedArgs
4665+ wrapparm = load! (builder, convert (LLVMType, arg. typ), wrapparm)
4666+ ctx = LLVM. context (wrapparm)
4667+ push! (wrapper_args, wrapparm)
4668+ push! (parameter_attributes (wrapper_f, arg. codegen. i- sret- returnRoots), StringAttribute (" enzyme_type" , string (typetree (Base. RefValue{arg. typ}, ctx, dl, seen))))
4669+ push! (parameter_attributes (wrapper_f, arg. codegen. i- sret- returnRoots), StringAttribute (" enzymejl_parmtype" , string (convert (UInt, unsafe_to_pointer (arg. typ)))))
4670+ push! (parameter_attributes (wrapper_f, arg. codegen. i- sret- returnRoots), StringAttribute (" enzymejl_parmtype_ref" , string (UInt (GPUCompiler. BITS_REF))))
45504671 else
45514672 push! (wrapper_args, wrapparm)
45524673 for attr in collect (parameter_attributes (entry_f, arg. codegen. i))
@@ -4626,6 +4747,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
46264747 elseif LLVM. return_type (entry_ft) == LLVM. VoidType ()
46274748 ret! (builder)
46284749 else
4750+ ctx = LLVM. context (wrapper_f)
46294751 push! (return_attributes (wrapper_f), StringAttribute (" enzyme_type" , string (typetree (actualRetType, ctx, dl, seen))))
46304752 push! (return_attributes (wrapper_f), StringAttribute (" enzymejl_parmtype" , string (convert (UInt, unsafe_to_pointer (actualRetType)))))
46314753 push! (return_attributes (wrapper_f), StringAttribute (" enzymejl_parmtype_ref" , string (UInt (GPUCompiler. BITS_REF))))
@@ -4687,7 +4809,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
46874809 if LLVM. API. LLVMVerifyFunction (wrapper_f, LLVM. API. LLVMReturnStatusAction) != 0
46884810 msg = sprint () do io
46894811 println (io, string (mod))
4690- println (io, LVM . API. LLVMVerifyFunction (wrapper_f, LLVM. API. LLVMPrintMessageAction))
4812+ println (io, LLVM . API. LLVMVerifyFunction (wrapper_f, LLVM. API. LLVMPrintMessageAction))
46914813 println (io, string (wrapper_f))
46924814 println (io, " parmsRemoved=" , parmsRemoved, " retRemoved=" , retRemoved, " prargs=" , prargs)
46934815 println (io, " Broken function" )
@@ -5966,6 +6088,35 @@ end
59666088 push! (ActiveRetTypes, Nothing)
59676089 end
59686090 push! (ccexprs, argexpr)
6091+ elseif T <: MixedDuplicated
6092+ if RawCall
6093+ argexpr = argexprs[i]
6094+ i+= 1
6095+ else
6096+ argexpr = Expr (:., expr, QuoteNode (:dval ))
6097+ end
6098+ push! (types, Any)
6099+ if is_adjoint
6100+ push! (ActiveRetTypes, Nothing)
6101+ end
6102+ push! (ccexprs, argexpr)
6103+ elseif T <: BatchMixedDuplicated
6104+ if RawCall
6105+ argexpr = argexprs[i]
6106+ i+= 1
6107+ else
6108+ argexpr = Expr (:., expr, QuoteNode (:dval ))
6109+ end
6110+ isboxedvec = GPUCompiler. deserves_argbox (NTuple{width, Base. RefValue{source_typ}})
6111+ if isboxedvec
6112+ push! (types, Any)
6113+ else
6114+ push! (types, NTuple{width, Base. RefValue{source_typ}})
6115+ end
6116+ if is_adjoint
6117+ push! (ActiveRetTypes, Nothing)
6118+ end
6119+ push! (ccexprs, argexpr)
59696120 else
59706121 error (" calling convention should be annotated, got $T " )
59716122 end
0 commit comments