Skip to content

Commit 4bec614

Browse files
vchuravymaleadt
andauthored
Support passing symbols as arguments (#2624)
Co-authored-by: Tim Besard <[email protected]>
1 parent 7bee37c commit 4bec614

File tree

4 files changed

+30
-15
lines changed

4 files changed

+30
-15
lines changed

lib/cudadrv/execution.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,21 @@
22

33
export cudacall
44

5+
# In contrast to `Base.RefValue` we just need a container for both pass-by-ref (Symbol),
6+
# and pass-by-value (immutable structs).
7+
mutable struct ArgBox{T}
8+
const val::T
9+
end
10+
11+
function Base.unsafe_convert(P::Union{Type{Ptr{T}}, Type{Ptr{Cvoid}}}, b::ArgBox{T})::P where {T}
12+
# TODO: What to do if T is not a leaftype (compare case 3 for RefValue)
13+
return pointer_from_objref(b)
14+
end
515

616
## device
717

818
# pack arguments in a buffer that CUDA expects
919
@inline @generated function pack_arguments(f::Function, args...)
10-
for arg in args
11-
isbitstype(arg) || throw(ArgumentError("Arguments to kernel should be bitstype."))
12-
end
13-
1420
ex = quote end
1521

1622
# If f has N parameters, then kernelParams needs to be an array of N pointers.
@@ -21,7 +27,7 @@ export cudacall
2127
arg_refs = Vector{Symbol}(undef, length(args))
2228
for i in 1:length(args)
2329
arg_refs[i] = gensym()
24-
push!(ex.args, :($(arg_refs[i]) = Base.RefValue(args[$i])))
30+
push!(ex.args, :($(arg_refs[i]) = $ArgBox(args[$i])))
2531
end
2632

2733
# generate an array with pointers

src/compiler/compilation.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,11 @@ end
242242
CompilerConfig(target, params; kernel, name, always_inline)
243243
end
244244

245+
# a version of `sizeof` that returns the size of the argument we'll pass.
246+
# for example, it supports Symbols where `sizeof(Symbol)` would fail.
247+
argsize(x::Any) = sizeof(x)
248+
argsize(::Type{Symbol}) = sizeof(Ptr{Cvoid})
249+
245250
# compile to executable machine code
246251
function compile(@nospecialize(job::CompilerJob))
247252
# lower to PTX
@@ -281,7 +286,7 @@ function compile(@nospecialize(job::CompilerJob))
281286
argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt
282287
!isghosttype(dt) && !Core.Compiler.isconstType(dt)
283288
end
284-
param_usage = sum(sizeof, argtypes)
289+
param_usage = sum(argsize, argtypes)
285290
param_limit = 4096
286291
if cap >= v"7.0" && ptx >= v"8.1"
287292
param_limit = 32764

src/compiler/execution.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -259,15 +259,6 @@ end
259259
call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
260260
call_args = Union{Expr,Symbol}[x[1] for x in zip(argexprs, to_pass) if x[2]]
261261

262-
# replace non-isbits arguments (they should be unused, or compilation would have failed)
263-
# alternatively, make it possible to `launch` with non-isbits arguments.
264-
for (i,dt) in enumerate(call_t)
265-
if !isbitstype(dt)
266-
call_t[i] = Ptr{Any}
267-
call_args[i] = :C_NULL
268-
end
269-
end
270-
271262
# add the kernel state, passing an instance with a unique seed
272263
pushfirst!(call_t, KernelState)
273264
pushfirst!(call_args, :(KernelState(kernel.state.exception_info, make_seed(kernel))))

test/core/execution.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,19 @@ end
626626
@test_throws "Kernel invocation uses too much parameter memory" @cuda kernel(ntuple(_->UInt64(1), 2^13))
627627
end
628628

629+
@testset "symbols" begin
630+
function pass_symbol(x, name)
631+
i = name == :var ? 1 : 2
632+
x[i] = true
633+
return nothing
634+
end
635+
x = CuArray([false, false])
636+
@cuda pass_symbol(x, :var)
637+
@test Array(x) == [true, false]
638+
@cuda pass_symbol(x, :not_var)
639+
@test Array(x) == [true, true]
640+
end
641+
629642
end
630643

631644
############################################################################################

0 commit comments

Comments
 (0)