Skip to content

Commit b0d37ce

Browse files
authored
typed_ccall enhancements (#321)
* Pass Bool as i1 * Pass Val(...)-typed values as constants
1 parent d21df91 commit b0d37ce

File tree

3 files changed

+123
-19
lines changed

3 files changed

+123
-19
lines changed

docs/src/lib/interop.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,10 @@ LLVM.Interop.call_function
1919
```@docs
2020
LLVM.Interop.@asmcall
2121
```
22+
23+
24+
## LLVM type support
25+
26+
```@docs
27+
LLVM.Interop.@typed_ccall
28+
```

src/interop/pointer.jl

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,14 @@ Base.unsigned(x::LLVMPtr) = UInt(x)
115115
Base.signed(x::LLVMPtr) = Int(x)
116116

117117

118-
# pointer type-preserving ccall
118+
# type-preserving ccall
119119

120120
@generated function _typed_llvmcall(::Val{intr}, rettyp, argtt, args...) where {intr}
121121
# make types available for direct use in this generator
122122
rettyp = rettyp.parameters[1]
123123
argtt = argtt.parameters[1]
124124
argtyps = DataType[argtt.parameters...]
125-
argexprs = Expr[:(args[$i]) for i in 1:length(args)]
125+
argexprs = Any[:(args[$i]) for i in 1:length(args)]
126126

127127
# build IR that calls the intrinsic, casting types if necessary
128128
@dispose ctx=Context() begin
@@ -140,34 +140,77 @@ Base.signed(x::LLVMPtr) = Int(x)
140140
# reconstruct those so that we can accurately look up intrinsics.
141141
T_actual_args = LLVMType[]
142142
actual_args = LLVM.Value[]
143-
for (arg, argtyp) in zip(parameters(llvm_f),argtyps)
143+
for (i, (arg, argtyp, argval)) in enumerate(zip(parameters(llvm_f), argtyps, args))
144+
# if the value is a Val, we'll try to emit it as a constant
145+
const_arg = if argval <: Val
146+
# also pass the actual value for the fallback path (and to simplify
147+
# construction of the LLVM function, where we can ignore constants)
148+
argexprs[i] = argval.parameters[1]
149+
150+
argval.parameters[1]
151+
else
152+
nothing
153+
end
154+
144155
if argtyp <: LLVMPtr
145156
# passed as i8*
146157
T,AS = argtyp.parameters
147158
actual_typ = LLVM.PointerType(convert(LLVMType, T; ctx), AS)
148-
actual_arg = bitcast!(builder, arg, actual_typ)
159+
actual_arg = if const_arg == C_NULL
160+
LLVM.PointerNull(actual_typ)
161+
elseif const_arg !== nothing
162+
intptr = LLVM.ConstantInt(LLVM.Int64Type(ctx), Int(const_arg))
163+
const_inttoptr(intptr, actual_typ)
164+
else
165+
bitcast!(builder, arg, actual_typ)
166+
end
149167
elseif argtyp <: Ptr
150168
# passed as i64
151169
T = eltype(argtyp)
152170
actual_typ = LLVM.PointerType(convert(LLVMType, T; ctx))
171+
actual_arg = if const_arg == C_NULL
172+
LLVM.PointerNull(actual_typ)
173+
elseif const_arg !== nothing
174+
intptr = LLVM.ConstantInt(LLVM.Int64Type(ctx), Int(const_arg))
175+
const_inttoptr(intptr, actual_typ)
176+
else
177+
inttoptr!(builder, arg, actual_typ)
178+
end
153179
actual_arg = inttoptr!(builder, arg, actual_typ)
180+
elseif argtyp <: Bool
181+
# passed as i8
182+
T = eltype(argtyp)
183+
actual_typ = LLVM.Int1Type(ctx)
184+
actual_arg = if const_arg !== nothing
185+
LLVM.ConstantInt(actual_typ, const_arg)
186+
else
187+
trunc!(builder, arg, actual_typ)
188+
end
154189
else
155190
actual_typ = convert(LLVMType, argtyp; ctx)
156-
actual_arg = arg
191+
actual_arg = if const_arg isa Integer
192+
LLVM.ConstantInt(actual_typ, argval.parameters[1])
193+
elseif const_arg isa AbstractFloat
194+
LLVM.ConstantFP(actual_typ, argval.parameters[1])
195+
else
196+
arg
197+
end
157198
end
158199
push!(T_actual_args, actual_typ)
159200
push!(actual_args, actual_arg)
160201
end
161202

162203
# same for the return type
163-
if rettyp <: LLVMPtr
204+
T_ret_actual = if rettyp <: LLVMPtr
164205
T,AS = rettyp.parameters
165-
T_ret_actual = LLVM.PointerType(convert(LLVMType, T; ctx), AS)
206+
LLVM.PointerType(convert(LLVMType, T; ctx), AS)
166207
elseif rettyp <: Ptr
167208
T = eltype(rettyp)
168-
T_ret_actual = LLVM.PointerType(convert(LLVMType, T; ctx))
209+
LLVM.PointerType(convert(LLVMType, T; ctx))
210+
elseif rettyp <: Bool
211+
LLVM.Int1Type(ctx)
169212
else
170-
T_ret_actual = T_ret
213+
T_ret
171214
end
172215

173216
intr_ft = LLVM.FunctionType(T_ret_actual, T_actual_args)
@@ -179,10 +222,14 @@ Base.signed(x::LLVMPtr) = Int(x)
179222
ret!(builder)
180223
else
181224
# also convert the return value
182-
if rettyp <: LLVMPtr
183-
rv = bitcast!(builder, rv, T_ret)
225+
rv = if rettyp <: LLVMPtr
226+
bitcast!(builder, rv, T_ret)
184227
elseif rettyp <: Ptr
185-
rv = ptrtoint!(builder, rv, T_ret)
228+
ptrtoint!(builder, rv, T_ret)
229+
elseif rettyp <: Bool
230+
zext!(builder, rv, T_ret)
231+
else
232+
rv
186233
end
187234

188235
ret!(builder, rv)
@@ -193,12 +240,19 @@ Base.signed(x::LLVMPtr) = Int(x)
193240
end
194241
end
195242

196-
# perform a `ccall(intrinsic, llvmcall)` with accurate pointer types for calling intrinsic.
197-
# this may be needed when selecting LLVM intrinsics, to avoid assertion failures, or when
198-
# the back-end actually emits different code depending on types (e.g. SPIR-V and atomics).
199-
#
200-
# NOTE: this will become unnecessary when LLVM switches to typeless pointers,
201-
# or if and when Julia goes back to emitting exact types when passing pointers.
243+
"""
244+
@typed_ccall(intrinsic, llvmcall, rettyp, (argtyps...), args...)
245+
246+
Perform a `ccall` while more accurately preserving argument types like LLVM expects them:
247+
248+
- `Bool`s are passed as `i1`, not `i8`;
249+
- Pointers (both `Ptr` and `Core.LLVMPtr`) are passed as typed pointers (instead of resp.
250+
`i8*` and `i64`);
251+
- `Val`-typed arguments will be passed as constants, if supported.
252+
253+
These features can be useful to call LLVM intrinsics, which may expect a specific set of
254+
argument types.
255+
"""
202256
macro typed_ccall(intrinsic, cc, rettyp, argtyps, args...)
203257
# destructure and validate the arguments
204258
cc == :llvmcall || error("Can only use @typed_ccall with the llvmcall calling convention")
@@ -210,7 +264,13 @@ macro typed_ccall(intrinsic, cc, rettyp, argtyps, args...)
210264
:($var = $arg)
211265
end
212266
arg_exprs = map(zip(vars,argtyps.args)) do (var,typ)
213-
:(Base.unsafe_convert($typ, Base.cconvert($typ, $var)))
267+
quote
268+
if $var isa Val
269+
Val(Base.unsafe_convert($typ, Base.cconvert($typ, typeof($var).parameters[1])))
270+
else
271+
Base.unsafe_convert($typ, Base.cconvert($typ, $var))
272+
end
273+
end
214274
end
215275

216276
esc(quote

test/interop.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,43 @@ end
254254

255255
# test return nothing
256256
LLVM.Interop.@typed_ccall("llvm.donothing", llvmcall, Cvoid, ())
257+
258+
# test return Bool
259+
expect_bool(val, expected_val) = LLVM.Interop.@typed_ccall("llvm.expect.i1", llvmcall, Bool, (Bool,Bool), val, expected_val)
260+
@test expect_bool(true, false)
261+
262+
# test return non-special type
263+
expect_int(val, expected_val) = LLVM.Interop.@typed_ccall("llvm.expect.i64", llvmcall, Int, (Int,Int), val, expected_val)
264+
@test expect_int(42, 0) == 42
265+
266+
# test passing constant values
267+
let
268+
a = [42]
269+
b = [0]
270+
memcpy(dst, src, len) = LLVM.Interop.@typed_ccall("llvm.memcpy.p0.p0.i64", llvmcall, Cvoid, (Ptr{Int}, Ptr{Int}, Int, Bool), dst, src, len, Val(false))
271+
memcpy(b, a, 1)
272+
@test b == [42]
273+
274+
const_bool_false = memcpy
275+
ir = sprint(io->code_llvm(io, memcpy, Tuple{Vector{Int}, Vector{Int}, Int}))
276+
@test occursin(r"call void @llvm.memcpy.p0i64.p0i64.i64\(i64\* .+, i64\* .+, i64 .+, i1 false\)", ir)
277+
278+
const_bool_true(dst, src, len) = LLVM.Interop.@typed_ccall("llvm.memcpy.p0.p0.i64", llvmcall, Cvoid, (Ptr{Int}, Ptr{Int}, Int, Bool), dst, src, len, Val(true))
279+
ir = sprint(io->code_llvm(io, const_bool_true, Tuple{Vector{Int}, Vector{Int}, Int}))
280+
@test occursin(r"call void @llvm.memcpy.p0i64.p0i64.i64\(i64\* .+, i64\* .+, i64 .+, i1 true\)", ir)
281+
282+
const_ptrs(len) = LLVM.Interop.@typed_ccall("llvm.memcpy.p0.p0.i64", llvmcall, Cvoid, (Ptr{Int}, Ptr{Int}, Int, Bool), Val(Ptr{Int}(0)), Val(Ptr{Int}(1)), len, Val(true))
283+
ir = sprint(io->code_llvm(io, const_ptrs, Tuple{Int}))
284+
@test occursin(r"call void @llvm.memcpy.p0i64.p0i64.i64\(i64\* null, i64\* inttoptr \(i64 1 to i64\*\), i64 .+, i1 true\)", ir)
285+
286+
const_llvmptrs(len) = LLVM.Interop.@typed_ccall("llvm.memcpy.p0.p0.i64", llvmcall, Cvoid, (LLVMPtr{Int,0}, LLVMPtr{Int,0}, Int, Bool), Val(LLVMPtr{Int,0}(0)), Val(LLVMPtr{Int,0}(1)), len, Val(true))
287+
ir = sprint(io->code_llvm(io, const_llvmptrs, Tuple{Int}))
288+
@test occursin(r"call void @llvm.memcpy.p0i64.p0i64.i64\(i64\* null, i64\* inttoptr \(i64 1 to i64\*\), i64 .+, i1 true\)", ir)
289+
290+
const_int(dst, src) = LLVM.Interop.@typed_ccall("llvm.memcpy.p0.p0.i64", llvmcall, Cvoid, (Ptr{Int}, Ptr{Int}, Int, Bool), dst, src, Val(999), Val(false))
291+
ir = sprint(io->code_llvm(io, const_int, Tuple{Vector{Int}, Vector{Int}}))
292+
@test occursin(r"call void @llvm.memcpy.p0i64.p0i64.i64\(i64\* .+, i64\* .+, i64 999, i1 false\)", ir)
293+
end
257294
end
258295

259296
end

0 commit comments

Comments
 (0)