-
Notifications
You must be signed in to change notification settings - Fork 44
implement device-side RNG #365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #365 +/- ##
==========================================
+ Coverage 79.01% 80.19% +1.17%
==========================================
Files 12 12
Lines 672 722 +50
==========================================
+ Hits 531 579 +48
- Misses 141 143 +2
☔ View full report in Codecov by Sentry.
🚀 New features to boost your workflow:
|
17892c2
to
aca6d1c
Compare
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/lib/cl/kernel.jl b/lib/cl/kernel.jl
index d024313..5ad48ff 100644
--- a/lib/cl/kernel.jl
+++ b/lib/cl/kernel.jl
@@ -119,8 +119,10 @@ function local_size_for_sub_group_count(ki::KernelSubGroupInfo, sub_group_count:
d = getfield(ki, :device)
input_value = Ref{Csize_t}(sub_group_count)
result = Ref{NTuple{3, Csize_t}}()
- clGetKernelSubGroupInfo(k, d, CL_KERNEL_LOCAL_SIZE_FOR_SUB_GROUP_COUNT,
- sizeof(Csize_t), input_value, sizeof(NTuple{3, Csize_t}), result, C_NULL)
+ clGetKernelSubGroupInfo(
+ k, d, CL_KERNEL_LOCAL_SIZE_FOR_SUB_GROUP_COUNT,
+ sizeof(Csize_t), input_value, sizeof(NTuple{3, Csize_t}), result, C_NULL
+ )
return Int.(result[])
end
@@ -135,7 +137,7 @@ function Base.getproperty(ki::KernelSubGroupInfo, s::Symbol)
return result[]
end
- if s == :max_sub_group_size
+ return if s == :max_sub_group_size
Int(get(CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, Csize_t))
elseif s == :sub_group_count
Int(get(CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE, Csize_t))
@@ -157,8 +159,9 @@ end
## kernel calling
function enqueue_kernel(k::Kernel, global_work_size, local_work_size=nothing;
- global_work_offset=nothing, wait_on::Vector{Event}=Event[],
- rng_state=false, nargs=nothing)
+ global_work_offset = nothing, wait_on::Vector{Event} = Event[],
+ rng_state = false, nargs = nothing
+ )
max_work_dim = device().max_work_item_dims
work_dim = length(global_work_size)
if work_dim > max_work_dim
@@ -252,7 +255,7 @@ function call(
k::Kernel, args...; global_size = (1,), local_size = nothing,
global_work_offset = nothing, wait_on::Vector{Event} = Event[],
indirect_memory::Vector{AbstractMemory} = AbstractMemory[],
- rng_state=false,
+ rng_state = false,
)
set_args!(k, args...)
if !isempty(indirect_memory)
@@ -306,7 +309,7 @@ function call(
clSetKernelExecInfo(k, CL_KERNEL_EXEC_INFO_USM_PTRS_INTEL, sizeof(usm_pointers), usm_pointers)
end
end
- enqueue_kernel(k, global_size, local_size; global_work_offset, wait_on, rng_state, nargs=length(args))
+ return enqueue_kernel(k, global_size, local_size; global_work_offset, wait_on, rng_state, nargs = length(args))
end
# From `julia/base/reflection.jl`, adjusted to add specialization on `t`.
diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl
index 1ab715d..1a252fb 100644
--- a/src/compiler/compilation.jl
+++ b/src/compiler/compilation.jl
@@ -20,11 +20,15 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState
-function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
- mod::LLVM.Module, entry::LLVM.Function)
- entry = invoke(GPUCompiler.finish_module!,
- Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
- job, mod, entry)
+function GPUCompiler.finish_module!(
+ @nospecialize(job::OpenCLCompilerJob),
+ mod::LLVM.Module, entry::LLVM.Function
+ )
+ entry = invoke(
+ GPUCompiler.finish_module!,
+ Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
+ job, mod, entry
+ )
# if this kernel uses our RNG, we should prime the shared state.
# XXX: these transformations should really happen at the Julia IR level...
@@ -36,7 +40,7 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
# create a deferred compilation job for `initialize_rng_state`
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
- cfg = CompilerConfig(job.config; kernel=false, name=nothing)
+ cfg = CompilerConfig(job.config; kernel = false, name = nothing)
job = CompilerJob(src, cfg, job.world)
id = length(GPUCompiler.deferred_codegen_jobs) + 1
GPUCompiler.deferred_codegen_jobs[id] = job
@@ -44,7 +48,7 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
# generate IR for calls to `deferred_codegen` and the resulting function pointer
top_bb = first(blocks(entry))
bb = BasicBlock(top_bb, "initialize_rng")
- @dispose builder=IRBuilder() begin
+ @dispose builder = IRBuilder() begin
position!(builder, bb)
subprogram = LLVM.subprogram(entry)
if subprogram !== nothing
@@ -166,5 +170,5 @@ function link(@nospecialize(job::CompilerJob), compiled)
cl.Program(; source)
end
cl.build!(prog)
- (; kernel=cl.Kernel(prog, compiled.entry), compiled.device_rng)
+ return (; kernel = cl.Kernel(prog, compiled.entry), compiled.device_rng)
end
diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl
index 68885b8..6ffc207 100644
--- a/src/compiler/execution.jl
+++ b/src/compiler/execution.jl
@@ -196,7 +196,7 @@ function clfunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
kernel = get(_kernel_instances, h, nothing)
if kernel === nothing
# create the kernel state object
- kernel = HostKernel{F,tt}(f, linked.kernel, linked.device_rng)
+ kernel = HostKernel{F, tt}(f, linked.kernel, linked.device_rng)
_kernel_instances[h] = kernel
end
return kernel::HostKernel{F,tt}
diff --git a/src/device/random.jl b/src/device/random.jl
index 396b34b..6bbb1d0 100644
--- a/src/device/random.jl
+++ b/src/device/random.jl
@@ -25,7 +25,7 @@ end
function initialize_rng_state()
subgroup_id = get_sub_group_id()
@inbounds global_random_keys()[subgroup_id] = kernel_state().random_seed
- @inbounds global_random_counters()[subgroup_id] = 0
+ return @inbounds global_random_counters()[subgroup_id] = 0
end
# generators
@@ -41,7 +41,7 @@ struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64} end
@inline function Base.getproperty(rng::Philox2x32, field::Symbol)
subgroup_id = get_sub_group_id()
- if field === :key
+ return if field === :key
@inbounds global_random_keys()[subgroup_id]
elseif field === :ctr1
@inbounds global_random_counters()[subgroup_id]
@@ -69,7 +69,7 @@ end
Seed the on-device Philox2x32 generator with an UInt32 number.
Should be called by at least one thread per warp.
"""
-function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer=UInt32(0))
+function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer = UInt32(0))
rng.key = seed % UInt32
rng.ctr1 = counter
return
@@ -99,25 +99,57 @@ end
Generate a byte of random data using the on-device Tausworthe generator.
"""
-function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
+function Random.rand(rng::Philox2x32{R}, ::Type{UInt64}) where {R}
ctr1, ctr2, key = rng.ctr1, rng.ctr2, rng.key
- if R > 0 ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 1 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 2 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 3 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 4 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 5 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 6 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 7 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 8 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 9 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 10 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 11 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 12 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 13 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 14 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
- if R > 15 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+ if R > 0
+ ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 1
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 2
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 3
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 4
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 5
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 6
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 7
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 8
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 9
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 10
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 11
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 12
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 13
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 14
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
+ if R > 15
+ key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+ end
# update the warp counter
# NOTE: this performs the same update on every thread in the warp, but each warp writes
@@ -131,16 +163,15 @@ function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
end
-
# normally distributed random numbers using Ziggurat algorithm
#
# copied from Base because we don't support its global tables
# a hacky method of exposing constant tables as constant GPU memory
function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
- @dispose ctx=Context() begin
+ return @dispose ctx = Context() begin
T_val = convert(LLVMType, T)
- T_ptr = convert(LLVMType, LLVMPtr{T,AS.UniformConstant})
+ T_ptr = convert(LLVMType, LLVMPtr{T, AS.UniformConstant})
# define function and get LLVM module
llvm_f, _ = create_function(T_ptr)
@@ -156,7 +187,7 @@ function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
alignment!(gv, 16)
# generate IR
- @dispose builder=IRBuilder() begin
+ @dispose builder = IRBuilder() begin
entry = BasicBlock(llvm_f, "entry")
position!(builder, entry)
@@ -167,17 +198,17 @@ function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
ret!(builder, untyped_ptr)
end
- call_function(llvm_f, LLVMPtr{T,AS.UniformConstant})
+ call_function(llvm_f, LLVMPtr{T, AS.UniformConstant})
end
end
for var in [:ki, :wi, :fi, :ke, :we, :fe]
val = getfield(Random, var)
gpu_var = Symbol("gpu_$var")
- arr_typ = :(CLDeviceArray{$(eltype(val)),$(ndims(val)),AS.UniformConstant})
+ arr_typ = :(CLDeviceArray{$(eltype(val)), $(ndims(val)), AS.UniformConstant})
@eval @inline @generated function $gpu_var()
ptr = emit_constant_array($(QuoteNode(var)), $val)
- Expr(:call, $arr_typ, $(size(val)), ptr)
+ return Expr(:call, $arr_typ, $(size(val)), ptr)
end
end
@@ -190,17 +221,17 @@ end
r &= 0x000fffffffffffff
rabs = Int64(r >> 1) # One bit for the sign
idx = rabs & 0xFF
- x = ifelse(r % Bool, -rabs, rabs)*gpu_wi()[idx+1]
- rabs < gpu_ki()[idx+1] && return x # 99.3% of the time we return here 1st try
+ x = ifelse(r % Bool, -rabs, rabs) * gpu_wi()[idx + 1]
+ rabs < gpu_ki()[idx + 1] && return x # 99.3% of the time we return here 1st try
# TODO: This code could be outlined once LLVM supports LDS access in recursively-called functions
@inbounds if idx == 0
while true
- xx = -Random.ziggurat_nor_inv_r*log(Random.rand(rng))
+ xx = -Random.ziggurat_nor_inv_r * log(Random.rand(rng))
yy = -log(Random.rand(rng))
- yy+yy > xx*xx &&
- return (rabs >> 8) % Bool ? -Random.ziggurat_nor_r-xx : Random.ziggurat_nor_r+xx
+ yy + yy > xx * xx &&
+ return (rabs >> 8) % Bool ? -Random.ziggurat_nor_r - xx : Random.ziggurat_nor_r + xx
end
- elseif (gpu_fi()[idx] - gpu_fi()[idx+1])*Random.rand(rng) + gpu_fi()[idx+1] < exp(-0.5*x*x)
+ elseif (gpu_fi()[idx] - gpu_fi()[idx + 1]) * Random.rand(rng) + gpu_fi()[idx + 1] < exp(-0.5 * x * x)
return x # return from the triangular area
else
@goto retry
@@ -220,12 +251,12 @@ end
@inbounds begin
ri &= 0x000fffffffffffff
idx = ri & 0xFF
- x = ri*gpu_we()[idx+1]
- ri < gpu_ke()[idx+1] && return x # 98.9% of the time we return here 1st try
+ x = ri * gpu_we()[idx + 1]
+ ri < gpu_ke()[idx + 1] && return x # 98.9% of the time we return here 1st try
# TODO: This code could be outlined once LLVM supports LDS access in recursively-called functions
@inbounds if idx == 0
return Random.ziggurat_exp_r - log(Random.rand(rng))
- elseif (gpu_fe()[idx] - gpu_fe()[idx+1])*Random.rand(rng) + gpu_fe()[idx+1] < exp(-x)
+ elseif (gpu_fe()[idx] - gpu_fe()[idx + 1]) * Random.rand(rng) + gpu_fe()[idx + 1] < exp(-x)
return x # return from the triangular area
else
@goto retry
diff --git a/src/device/runtime.jl b/src/device/runtime.jl
index ef2f608..8e4f2d6 100644
--- a/src/device/runtime.jl
+++ b/src/device/runtime.jl
@@ -42,7 +42,7 @@ end
# run-time equivalent
function additional_arg_value(state, name)
- @dispose ctx=Context() begin
+ return @dispose ctx = Context() begin
T_state = convert(LLVMType, state)
# create function
@@ -54,7 +54,7 @@ function additional_arg_value(state, name)
state_intr_ft = function_type(state_intr)
# generate IR
- @dispose builder=IRBuilder() begin
+ @dispose builder = IRBuilder() begin
entry = BasicBlock(llvm_f, "entry")
position!(builder, entry)
diff --git a/test/device/random.jl b/test/device/random.jl
index 7a40df5..37833f7 100644
--- a/test/device/random.jl
+++ b/test/device/random.jl
@@ -3,7 +3,7 @@ using Random
const n = 256
function apply_seed(seed)
- if seed === missing
+ return if seed === missing
# should result in different numbers across launches
Random.seed!()
# XXX: this currently doesn't work, because of the definition in Base,
@@ -33,8 +33,8 @@ eltypes = [filter(x -> !(x <: Complex), GPUArraysTestSuite.supported_eltypes(CLA
a = OpenCL.zeros(T, n)
b = OpenCL.zeros(T, n)
- @opencl global_size=n local_size=n kernel(a, seed)
- @opencl global_size=n local_size=n kernel(b, seed)
+ @opencl global_size = n local_size = n kernel(a, seed)
+ @opencl global_size = n local_size = n kernel(b, seed)
if seed === nothing || seed === missing
@test Array(a) != Array(b)
@@ -56,7 +56,7 @@ eltypes = [filter(x -> !(x <: Complex), GPUArraysTestSuite.supported_eltypes(CLA
a = OpenCL.zeros(T, n)
b = OpenCL.zeros(T, n)
- @opencl global_size=n local_size=n kernel(a, b, seed)
+ @opencl global_size = n local_size = n kernel(a, b, seed)
@test Array(a) != Array(b)
end
@@ -67,7 +67,7 @@ eltypes = [filter(x -> !(x <: Complex), GPUArraysTestSuite.supported_eltypes(CLA
function kernel(A::AbstractArray{T}, seed) where {T}
apply_seed(seed)
id = get_local_id(1) * get_local_id(2) * get_local_id(3) *
- get_group_id(1) * get_group_id(2) * get_group_id(3)
+ get_group_id(1) * get_group_id(2) * get_group_id(3)
if 1 <= id <= length(A)
A[id] = rand(T)
end
@@ -75,10 +75,10 @@ eltypes = [filter(x -> !(x <: Complex), GPUArraysTestSuite.supported_eltypes(CLA
end
tx, ty, tz, bx, by, bz = [dim == active_dim ? 3 : 1 for dim in 1:6]
- gx, gy, gz = tx*bx, ty*by, tz*bz
+ gx, gy, gz = tx * bx, ty * by, tz * bz
a = OpenCL.zeros(T, 3)
- @opencl local_size=(tx, ty, tz) global_size=(gx, gy, gz) kernel(a, seed)
+ @opencl local_size = (tx, ty, tz) global_size = (gx, gy, gz) kernel(a, seed)
# NOTE: we don't just generate two numbers and compare them, instead generating a
# couple more and checking they're not all the same, in order to avoid
@@ -99,8 +99,8 @@ end
a = OpenCL.zeros(T, n)
b = OpenCL.zeros(T, n)
- @opencl global_size=n local_size=n kernel(a, seed)
- @opencl global_size=n local_size=n kernel(b, seed)
+ @opencl global_size = n local_size = n kernel(a, seed)
+ @opencl global_size = n local_size = n kernel(b, seed)
if seed === nothing || seed === missing
@test Array(a) != Array(b)
@@ -126,8 +126,8 @@ end
a = OpenCL.zeros(T, n)
b = OpenCL.zeros(T, n)
- @opencl global_size=n local_size=n kernel(a, seed)
- @opencl global_size=n local_size=n kernel(b, seed)
+ @opencl global_size = n local_size = n kernel(a, seed)
+ @opencl global_size = n local_size = n kernel(b, seed)
if seed === nothing || seed === missing
@test Array(a) != Array(b)
diff --git a/test/setup.jl b/test/setup.jl
index 90337d3..22e5f43 100644
--- a/test/setup.jl
+++ b/test/setup.jl
@@ -91,7 +91,7 @@ function runtests(f, name, platform_filter)
# some tests require native execution capabilities
requires_il = name in ["atomics", "execution", "intrinsics", "kernelabstractions"] ||
- startswith(name, "gpuarrays/") || startswith(name, "device/")
+ startswith(name, "gpuarrays/") || startswith(name, "device/")
ex = quote
GC.gc(true) |
I don't like the additional argument passing. Isn't this because we currently only support dynamically-sized local memory? IIUC it should be possible to define it entirely in the module, like so:
|
The issue is that to know how much memory we are going to need, we need to query the number of subgroups inside a workgroup. I don't think we can do that for local memory declared inside a module. |
What about doing that as part of the deferred compilation hook in […]? If you want, I can give this a try later this week or next week. |
I'm still not 100% convinced, as you would probably want to query |
Do we need to query the kernel-specific subgroup count? We only want to minimize over-estimating the amount of memory needed. In Metal.jl we also use a safe over-estimation based on hardware capabilities. |
This is what I get for the pocl CPU backend:
(KernelSubGroupInfo(k, device(), lsize)).sub_group_count = 1
(KernelSubGroupInfo(k, device(), Csize_t[])).max_num_sub_groups = 128
(KernelSubGroupInfo(k, device(), Csize_t[])).compile_num_sub_groups = 0
(device()).max_num_sub_groups = 128
So if we don't calculate this for each kernel launch, we would overallocate around 1 kB of local memory for each work group. Probably not a deal breaker on CPU, but if we can avoid it? Looks like we also can't use |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So if we don't calculate this for each kernel launch, we would overallocate around 1 kB of local memory for each work group.
I guess, I just don't like the implementation, hacking in additional arguments in random places. Probably fine for now, but I hope we find something better.
Pushed something that uses a function attribute instead of an API call. |
partially by making use of JuliaGPU/GPUCompiler.jl#727
Co-authored-by: Tim Besard <[email protected]>
I believe this is ready now, ok to merge? |
Requires
JuliaGPU/GPUCompiler.jl#717, JuliaGPU/GPUCompiler.jl#718