@@ -115,14 +115,14 @@ Base.unsigned(x::LLVMPtr) = UInt(x)
115
115
Base. signed (x:: LLVMPtr ) = Int (x)
116
116
117
117
118
- # pointer type-preserving ccall
118
+ # type-preserving ccall
119
119
120
120
@generated function _typed_llvmcall (:: Val{intr} , rettyp, argtt, args... ) where {intr}
121
121
# make types available for direct use in this generator
122
122
rettyp = rettyp. parameters[1 ]
123
123
argtt = argtt. parameters[1 ]
124
124
argtyps = DataType[argtt. parameters... ]
125
- argexprs = Expr [:(args[$ i]) for i in 1 : length (args)]
125
+ argexprs = Any [:(args[$ i]) for i in 1 : length (args)]
126
126
127
127
# build IR that calls the intrinsic, casting types if necessary
128
128
@dispose ctx= Context () begin
@@ -140,34 +140,77 @@ Base.signed(x::LLVMPtr) = Int(x)
140
140
# reconstruct those so that we can accurately look up intrinsics.
141
141
T_actual_args = LLVMType[]
142
142
actual_args = LLVM. Value[]
143
- for (arg, argtyp) in zip (parameters (llvm_f),argtyps)
143
+ for (i, (arg, argtyp, argval)) in enumerate (zip (parameters (llvm_f), argtyps, args))
144
+ # if the value is a Val, we'll try to emit it as a constant
145
+ const_arg = if argval <: Val
146
+ # also pass the actual value for the fallback path (and to simplify
147
+ # construction of the LLVM function, where we can ignore constants)
148
+ argexprs[i] = argval. parameters[1 ]
149
+
150
+ argval. parameters[1 ]
151
+ else
152
+ nothing
153
+ end
154
+
144
155
if argtyp <: LLVMPtr
145
156
# passed as i8*
146
157
T,AS = argtyp. parameters
147
158
actual_typ = LLVM. PointerType (convert (LLVMType, T; ctx), AS)
148
- actual_arg = bitcast! (builder, arg, actual_typ)
159
+ actual_arg = if const_arg == C_NULL
160
+ LLVM. PointerNull (actual_typ)
161
+ elseif const_arg != = nothing
162
+ intptr = LLVM. ConstantInt (LLVM. Int64Type (ctx), Int (const_arg))
163
+ const_inttoptr (intptr, actual_typ)
164
+ else
165
+ bitcast! (builder, arg, actual_typ)
166
+ end
149
167
elseif argtyp <: Ptr
150
168
# passed as i64
151
169
T = eltype (argtyp)
152
170
actual_typ = LLVM. PointerType (convert (LLVMType, T; ctx))
171
+ actual_arg = if const_arg == C_NULL
172
+ LLVM. PointerNull (actual_typ)
173
+ elseif const_arg != = nothing
174
+ intptr = LLVM. ConstantInt (LLVM. Int64Type (ctx), Int (const_arg))
175
+ const_inttoptr (intptr, actual_typ)
176
+ else
177
+ inttoptr! (builder, arg, actual_typ)
178
+ end
153
179
actual_arg = inttoptr! (builder, arg, actual_typ)
180
+ elseif argtyp <: Bool
181
+ # passed as i8
182
+ T = eltype (argtyp)
183
+ actual_typ = LLVM. Int1Type (ctx)
184
+ actual_arg = if const_arg != = nothing
185
+ LLVM. ConstantInt (actual_typ, const_arg)
186
+ else
187
+ trunc! (builder, arg, actual_typ)
188
+ end
154
189
else
155
190
actual_typ = convert (LLVMType, argtyp; ctx)
156
- actual_arg = arg
191
+ actual_arg = if const_arg isa Integer
192
+ LLVM. ConstantInt (actual_typ, argval. parameters[1 ])
193
+ elseif const_arg isa AbstractFloat
194
+ LLVM. ConstantFP (actual_typ, argval. parameters[1 ])
195
+ else
196
+ arg
197
+ end
157
198
end
158
199
push! (T_actual_args, actual_typ)
159
200
push! (actual_args, actual_arg)
160
201
end
161
202
162
203
# same for the return type
163
- if rettyp <: LLVMPtr
204
+ T_ret_actual = if rettyp <: LLVMPtr
164
205
T,AS = rettyp. parameters
165
- T_ret_actual = LLVM. PointerType (convert (LLVMType, T; ctx), AS)
206
+ LLVM. PointerType (convert (LLVMType, T; ctx), AS)
166
207
elseif rettyp <: Ptr
167
208
T = eltype (rettyp)
168
- T_ret_actual = LLVM. PointerType (convert (LLVMType, T; ctx))
209
+ LLVM. PointerType (convert (LLVMType, T; ctx))
210
+ elseif rettyp <: Bool
211
+ LLVM. Int1Type (ctx)
169
212
else
170
- T_ret_actual = T_ret
213
+ T_ret
171
214
end
172
215
173
216
intr_ft = LLVM. FunctionType (T_ret_actual, T_actual_args)
@@ -179,10 +222,14 @@ Base.signed(x::LLVMPtr) = Int(x)
179
222
ret! (builder)
180
223
else
181
224
# also convert the return value
182
- if rettyp <: LLVMPtr
183
- rv = bitcast! (builder, rv, T_ret)
225
+ rv = if rettyp <: LLVMPtr
226
+ bitcast! (builder, rv, T_ret)
184
227
elseif rettyp <: Ptr
185
- rv = ptrtoint! (builder, rv, T_ret)
228
+ ptrtoint! (builder, rv, T_ret)
229
+ elseif rettyp <: Bool
230
+ zext! (builder, rv, T_ret)
231
+ else
232
+ rv
186
233
end
187
234
188
235
ret! (builder, rv)
@@ -193,12 +240,19 @@ Base.signed(x::LLVMPtr) = Int(x)
193
240
end
194
241
end
195
242
196
- # perform a `ccall(intrinsic, llvmcall)` with accurate pointer types for calling intrinsic.
197
- # this may be needed when selecting LLVM intrinsics, to avoid assertion failures, or when
198
- # the back-end actually emits different code depending on types (e.g. SPIR-V and atomics).
199
- #
200
- # NOTE: this will become unnecessary when LLVM switches to typeless pointers,
201
- # or if and when Julia goes back to emitting exact types when passing pointers.
243
+ """
244
+ @typed_ccall(intrinsic, llvmcall, rettyp, (argtyps...), args...)
245
+
246
+ Perform a `ccall` while more accurately preserving argument types like LLVM expects them:
247
+
248
+ - `Bool`s are passed as `i1`, not `i8`;
249
+ - Pointers (both `Ptr` and `Core.LLVMPtr`) are passed as typed pointers (instead of resp.
250
+ `i8*` and `i64`);
251
+ - `Val`-typed arguments will be passed as constants, if supported.
252
+
253
+ These features can be useful to call LLVM intrinsics, which may expect a specific set of
254
+ argument types.
255
+ """
202
256
macro typed_ccall (intrinsic, cc, rettyp, argtyps, args... )
203
257
# destructure and validate the arguments
204
258
cc == :llvmcall || error (" Can only use @typed_ccall with the llvmcall calling convention" )
@@ -210,7 +264,13 @@ macro typed_ccall(intrinsic, cc, rettyp, argtyps, args...)
210
264
:($ var = $ arg)
211
265
end
212
266
arg_exprs = map (zip (vars,argtyps. args)) do (var,typ)
213
- :(Base. unsafe_convert ($ typ, Base. cconvert ($ typ, $ var)))
267
+ quote
268
+ if $ var isa Val
269
+ Val (Base. unsafe_convert ($ typ, Base. cconvert ($ typ, typeof ($ var). parameters[1 ])))
270
+ else
271
+ Base. unsafe_convert ($ typ, Base. cconvert ($ typ, $ var))
272
+ end
273
+ end
214
274
end
215
275
216
276
esc (quote
0 commit comments