@@ -115,14 +115,14 @@ Base.unsigned(x::LLVMPtr) = UInt(x)
115115Base. signed (x:: LLVMPtr ) = Int (x)
116116
117117
118- # pointer type-preserving ccall
118+ # type-preserving ccall
119119
120120@generated function _typed_llvmcall (:: Val{intr} , rettyp, argtt, args... ) where {intr}
121121 # make types available for direct use in this generator
122122 rettyp = rettyp. parameters[1 ]
123123 argtt = argtt. parameters[1 ]
124124 argtyps = DataType[argtt. parameters... ]
125- argexprs = Expr [:(args[$ i]) for i in 1 : length (args)]
125+ argexprs = Any [:(args[$ i]) for i in 1 : length (args)]
126126
127127 # build IR that calls the intrinsic, casting types if necessary
128128 @dispose ctx= Context () begin
@@ -140,34 +140,77 @@ Base.signed(x::LLVMPtr) = Int(x)
140140 # reconstruct those so that we can accurately look up intrinsics.
141141 T_actual_args = LLVMType[]
142142 actual_args = LLVM. Value[]
143- for (arg, argtyp) in zip (parameters (llvm_f),argtyps)
143+ for (i, (arg, argtyp, argval)) in enumerate (zip (parameters (llvm_f), argtyps, args))
144+ # if the value is a Val, we'll try to emit it as a constant
145+ const_arg = if argval <: Val
146+ # also pass the actual value for the fallback path (and to simplify
147+ # construction of the LLVM function, where we can ignore constants)
148+ argexprs[i] = argval. parameters[1 ]
149+
150+ argval. parameters[1 ]
151+ else
152+ nothing
153+ end
154+
144155 if argtyp <: LLVMPtr
145156 # passed as i8*
146157 T,AS = argtyp. parameters
147158 actual_typ = LLVM. PointerType (convert (LLVMType, T; ctx), AS)
148- actual_arg = bitcast! (builder, arg, actual_typ)
159+ actual_arg = if const_arg == C_NULL
160+ LLVM. PointerNull (actual_typ)
161+ elseif const_arg != = nothing
162+ intptr = LLVM. ConstantInt (LLVM. Int64Type (ctx), Int (const_arg))
163+ const_inttoptr (intptr, actual_typ)
164+ else
165+ bitcast! (builder, arg, actual_typ)
166+ end
149167 elseif argtyp <: Ptr
150168 # passed as i64
151169 T = eltype (argtyp)
152170 actual_typ = LLVM. PointerType (convert (LLVMType, T; ctx))
171+ actual_arg = if const_arg == C_NULL
172+ LLVM. PointerNull (actual_typ)
173+ elseif const_arg != = nothing
174+ intptr = LLVM. ConstantInt (LLVM. Int64Type (ctx), Int (const_arg))
175+ const_inttoptr (intptr, actual_typ)
176+ else
177+ inttoptr! (builder, arg, actual_typ)
178+ end
153179 actual_arg = inttoptr! (builder, arg, actual_typ)
180+ elseif argtyp <: Bool
181+ # passed as i8
182+ T = eltype (argtyp)
183+ actual_typ = LLVM. Int1Type (ctx)
184+ actual_arg = if const_arg != = nothing
185+ LLVM. ConstantInt (actual_typ, const_arg)
186+ else
187+ trunc! (builder, arg, actual_typ)
188+ end
154189 else
155190 actual_typ = convert (LLVMType, argtyp; ctx)
156- actual_arg = arg
191+ actual_arg = if const_arg isa Integer
192+ LLVM. ConstantInt (actual_typ, argval. parameters[1 ])
193+ elseif const_arg isa AbstractFloat
194+ LLVM. ConstantFP (actual_typ, argval. parameters[1 ])
195+ else
196+ arg
197+ end
157198 end
158199 push! (T_actual_args, actual_typ)
159200 push! (actual_args, actual_arg)
160201 end
161202
162203 # same for the return type
163- if rettyp <: LLVMPtr
204+ T_ret_actual = if rettyp <: LLVMPtr
164205 T,AS = rettyp. parameters
165- T_ret_actual = LLVM. PointerType (convert (LLVMType, T; ctx), AS)
206+ LLVM. PointerType (convert (LLVMType, T; ctx), AS)
166207 elseif rettyp <: Ptr
167208 T = eltype (rettyp)
168- T_ret_actual = LLVM. PointerType (convert (LLVMType, T; ctx))
209+ LLVM. PointerType (convert (LLVMType, T; ctx))
210+ elseif rettyp <: Bool
211+ LLVM. Int1Type (ctx)
169212 else
170- T_ret_actual = T_ret
213+ T_ret
171214 end
172215
173216 intr_ft = LLVM. FunctionType (T_ret_actual, T_actual_args)
@@ -179,10 +222,14 @@ Base.signed(x::LLVMPtr) = Int(x)
179222 ret! (builder)
180223 else
181224 # also convert the return value
182- if rettyp <: LLVMPtr
183- rv = bitcast! (builder, rv, T_ret)
225+ rv = if rettyp <: LLVMPtr
226+ bitcast! (builder, rv, T_ret)
184227 elseif rettyp <: Ptr
185- rv = ptrtoint! (builder, rv, T_ret)
228+ ptrtoint! (builder, rv, T_ret)
229+ elseif rettyp <: Bool
230+ zext! (builder, rv, T_ret)
231+ else
232+ rv
186233 end
187234
188235 ret! (builder, rv)
@@ -193,12 +240,19 @@ Base.signed(x::LLVMPtr) = Int(x)
193240 end
194241end
195242
196- # perform a `ccall(intrinsic, llvmcall)` with accurate pointer types for calling intrinsic.
197- # this may be needed when selecting LLVM intrinsics, to avoid assertion failures, or when
198- # the back-end actually emits different code depending on types (e.g. SPIR-V and atomics).
199- #
200- # NOTE: this will become unnecessary when LLVM switches to typeless pointers,
201- # or if and when Julia goes back to emitting exact types when passing pointers.
243+ """
244+ @typed_ccall(intrinsic, llvmcall, rettyp, (argtyps...), args...)
245+
246+ Perform a `ccall` while more accurately preserving argument types like LLVM expects them:
247+
248+ - `Bool`s are passed as `i1`, not `i8`;
249+ - Pointers (both `Ptr` and `Core.LLVMPtr`) are passed as typed pointers (instead of resp.
250+ `i8*` and `i64`);
251+ - `Val`-typed arguments will be passed as constants, if supported.
252+
253+ These features can be useful to call LLVM intrinsics, which may expect a specific set of
254+ argument types.
255+ """
202256macro typed_ccall (intrinsic, cc, rettyp, argtyps, args... )
203257 # destructure and validate the arguments
204258 cc == :llvmcall || error (" Can only use @typed_ccall with the llvmcall calling convention" )
@@ -210,7 +264,13 @@ macro typed_ccall(intrinsic, cc, rettyp, argtyps, args...)
210264 :($ var = $ arg)
211265 end
212266 arg_exprs = map (zip (vars,argtyps. args)) do (var,typ)
213- :(Base. unsafe_convert ($ typ, Base. cconvert ($ typ, $ var)))
267+ quote
268+ if $ var isa Val
269+ Val (Base. unsafe_convert ($ typ, Base. cconvert ($ typ, typeof ($ var). parameters[1 ])))
270+ else
271+ Base. unsafe_convert ($ typ, Base. cconvert ($ typ, $ var))
272+ end
273+ end
214274 end
215275
216276 esc (quote
0 commit comments