Commit 8e039d5

Remove support for non-1D launch configurations.
It's purely a convenience feature anyway, doesn't generalize, and can be emulated without performance loss.
1 parent 20538cd commit 8e039d5

7 files changed: +40 -52 lines

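The emulation the commit message refers to is simple: a kernel that wants Cartesian coordinates can recover them from the 1D linear index. A minimal sketch, assuming a hypothetical 2D kernel (kernel2d, out, and dims are illustrative, not part of this commit):

    # Hypothetical kernel: recover 2D coordinates from a 1D launch.
    # linear_index(state) is the 1-based linear index GPUArrays provides.
    function kernel2d(state, out, dims::Tuple{Int, Int})
        i = linear_index(state)
        i > prod(dims) && return          # guard against padded launches
        ci = CartesianIndices(dims)[i]    # linear -> (x, y), column-major
        out[ci] = ci[1] + ci[2]           # use 2D coordinates as before
        return
    end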

src/array.jl (+19 -25)

@@ -154,25 +154,24 @@ struct JLBackend <: AbstractGPUBackend end
 
 GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
 
-mutable struct JLState{N}
-    blockdim::NTuple{N, Int}
-    griddim::NTuple{N, Int}
+mutable struct JLState
+    blockdim::Int
+    griddim::Int
 
-    blockidx::NTuple{N, Int}
-    threadidx::NTuple{N, Int}
+    blockidx::Int
+    threadidx::Int
     localmem_counter::Int
     localmems::Vector{Vector{Array}}
 end
 
-function JLState(threads::NTuple{N}, blockdim::NTuple{N}) where N
-    idx = ntuple(i-> 1, Val(N))
+function JLState(threads::Int, blockdim::Int)
     blockcount = prod(blockdim)
     lmems = [Vector{Array}() for i in 1:blockcount]
-    JLState{N}(threads, blockdim, idx, idx, 0, lmems)
+    JLState(threads, blockdim, 1, 1, 0, lmems)
 end
 
-function JLState(state::JLState{N}, threadidx::NTuple{N}) where N
-    JLState{N}(
+function JLState(state::JLState, threadidx::Int)
+    JLState(
         state.blockdim,
         state.griddim,
         state.blockidx,
@@ -187,17 +186,15 @@ to_device(state, x::Tuple) = to_device.(Ref(state), x)
 to_device(state, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(state, x[]))
 to_device(state, x) = x
 
-function GPUArrays._gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
+function GPUArrays._gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{Int, Int})
     blocks, threads = blocks_threads
-    idx = ntuple(i-> 1, length(blocks))
-    blockdim = blocks
-    state = JLState(threads, blockdim)
+    state = JLState(threads, blocks)
     device_args = to_device.(Ref(state), args)
-    tasks = Array{Task}(undef, threads...)
-    for blockidx in CartesianIndices(blockdim)
-        state.blockidx = blockidx.I
-        for threadidx in CartesianIndices(threads)
-            thread_state = JLState(state, threadidx.I)
+    tasks = Array{Task}(undef, threads)
+    for blockidx in 1:blocks
+        state.blockidx = blockidx
+        for threadidx in 1:threads
+            thread_state = JLState(state, threadidx)
             tasks[threadidx] = @async @allowscalar f(thread_state, device_args...)
             # TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading
             # (this would require a different synchronization mechanism)
@@ -246,7 +243,7 @@ end
 
 function GPUArrays.LocalMemory(state::JLState, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
     state.localmem_counter += 1
-    lmems = state.localmems[blockidx_x(state)]
+    lmems = state.localmems[blockidx(state)]
 
     # first invocation in block
     data = if length(lmems) < state.localmem_counter
@@ -272,11 +269,8 @@ Base.size(x::JLDeviceArray) = x.dims
 @inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(A.data, index)
 @inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(A.data, x, index)
 
-for (i, sym) in enumerate((:x, :y, :z))
-    for f in (:blockidx, :blockdim, :threadidx, :griddim)
-        fname = Symbol(string(f, '_', sym))
-        @eval GPUArrays.$fname(state::JLState) = Int(state.$f[$i])
-    end
+for f in (:blockidx, :blockdim, :threadidx, :griddim)
+    @eval GPUArrays.$f(state::JLState) = state.$f
 end
 
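For intuition, the launch loop above now plays out as: blocks run one after another, and each thread within a block becomes a cooperatively scheduled task. A condensed sketch of that control flow, with the JLState bookkeeping stripped away (here f takes the block and thread indices directly, unlike the real code):

    # Condensed sketch of the 1D launch loop; state handling omitted.
    function emulate_launch(f, args::Tuple, blocks::Int, threads::Int)
        for blockidx in 1:blocks
            tasks = Vector{Task}(undef, threads)
            for threadidx in 1:threads
                # each "GPU thread" is a Julia task (@async, not OS threads)
                tasks[threadidx] = @async f(blockidx, threadidx, args...)
            end
            foreach(wait, tasks)  # the block finishes before the next starts
        end
    end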

src/device/indexing.jl (+5 -9)

@@ -4,12 +4,9 @@ export global_size, synchronize_threads, linear_index
 
 
 # thread indexing functions
-for sym in (:x, :y, :z)
-    for f in (:blockidx, :blockdim, :threadidx, :griddim)
-        fname = Symbol(string(f, '_', sym))
-        @eval $fname(state)::Int = error("Not implemented") # COV_EXCL_LINE
-        @eval export $fname
-    end
+for f in (:blockidx, :blockdim, :threadidx, :griddim)
+    @eval $f(state)::Int = error("Not implemented") # COV_EXCL_LINE
+    @eval export $f
 end
 
 """
@@ -18,8 +15,7 @@ end
 Global size == blockdim * griddim == total number of kernel execution
 """
 @inline function global_size(state)
-    # TODO nd version
-    griddim_x(state) * blockdim_x(state)
+    griddim(state) * blockdim(state)
 end
 
 """
@@ -29,7 +25,7 @@ linear index corresponding to each kernel launch (in OpenCL equal to get_global_
 
 
 @inline function linear_index(state)
-    (blockidx_x(state) - 1) * blockdim_x(state) + threadidx_x(state)
+    (blockidx(state) - 1) * blockdim(state) + threadidx(state)
 end
 
 """
src/host/execution.jl (+6 -8)

@@ -27,14 +27,12 @@ function gpu_call(kernel, A::AbstractArray, args::Tuple, configuration = length(
     thread_blocks = if isa(configuration, Integer)
         thread_blocks_heuristic(configuration)
     elseif isa(configuration, ITuple)
-        # if a single integer ntuple, we assume it to configure the blocks
-        configuration, ntuple(x-> x == 1 ? 256 : 1, length(configuration))
+        @assert length(configuration) == 1
+        configuration[1], 1
     elseif isa(configuration, Tuple{ITuple, ITuple})
-        # 2 dim tuple of ints == blocks + threads per block
-        if any(x-> length(x) > 3 || length(x) < 1, configuration)
-            error("blocks & threads must be 1-3 dimensional. Found: $configuration")
-        end
-        map(x-> Int.(x), configuration) # make sure it all has the same int type
+        @assert length(configuration[1]) == 1
+        @assert length(configuration[2]) == 1
+        configuration[1][1], configuration[2][1]
     else
         error("""Please launch a gpu kernel with a valid configuration.
               Found: $configurations
@@ -65,5 +63,5 @@ function thread_blocks_heuristic(len::Integer)
     # TODO better threads default
     threads = clamp(len, 1, 256)
     blocks = max(ceil(Int, len / threads), 1)
-    (blocks,), (threads,)
+    (blocks, threads)
 end
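
After this change gpu_call accepts three configuration shapes, all effectively 1D. A sketch of the accepted forms (kernel and A are placeholders):

    gpu_call(kernel, A, (A,))                  # default: length(A), heuristic split
    gpu_call(kernel, A, (A,), 1024)            # Integer: heuristic split of 1024
    gpu_call(kernel, A, (A,), (4,))            # 1-tuple: 4 blocks, 1 thread each
    gpu_call(kernel, A, (A,), ((4,), (256,)))  # (blocks,), (threads,): 4 x 256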

src/host/mapreduce.jl (+3 -3)

@@ -142,11 +142,11 @@ for i = 0:10
         global_index += global_size(state)
     end
     # Perform parallel reduction
-    local_index = threadidx_x(state) - 1
+    local_index = threadidx(state) - 1
     @inbounds tmp_local[local_index + 1] = acc
     synchronize_threads(state)
 
-    offset = blockdim_x(state) ÷ 2
+    offset = blockdim(state) ÷ 2
     @inbounds while offset > 0
         if (local_index < offset)
             other = tmp_local[local_index + offset + 1]
@@ -157,7 +157,7 @@ for i = 0:10
         offset = offset ÷ 2
     end
     if local_index == 0
-        @inbounds result[blockidx_x(state)] = tmp_local[1]
+        @inbounds result[blockidx(state)] = tmp_local[1]
     end
     return
 end
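
The loop above is the standard shared-memory tree reduction: each halving of offset folds the upper half of the block-local buffer into the lower half, so a block of n threads finishes in log2(n) rounds. A single-threaded replay of the same schedule, as a sanity check (assumes a power-of-two length; op stands in for the reduction operator):

    # Single-threaded replay of the tree-reduction schedule above.
    function tree_reduce(op, tmp::Vector)
        offset = length(tmp) ÷ 2
        while offset > 0
            for i in 1:offset
                tmp[i] = op(tmp[i], tmp[i + offset])  # fold upper half in
            end
            offset ÷= 2
        end
        return tmp[1]  # what the kernel stores to result[blockidx(state)]
    end

    tree_reduce(+, Float64[1, 2, 3, 4, 5, 6, 7, 8])  # == 36.0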

src/host/random.jl (+1 -1)

@@ -29,7 +29,7 @@ function next_rand(::Type{FT}, state::NTuple{4, T}) where {FT, T <: Unsigned}
 end
 
 function gpu_rand(::Type{T}, state, randstate::AbstractVector{NTuple{4, UInt32}}) where T
-    threadid = GPUArrays.threadidx_x(state)
+    threadid = GPUArrays.threadidx(state)
     stateful_rand = next_rand(T, randstate[threadid])
     randstate[threadid] = stateful_rand[1]
     return stateful_rand[2]
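
gpu_rand keys the stateful generator off the in-block thread index, so every thread reads and advances its own NTuple{4, UInt32} slot. A sketch of the host-side state this implies (the setup shown is an assumption for illustration, not code from this commit):

    # Assumed setup: one RNG state per thread in a block.
    threads = 256
    randstate = [ntuple(_ -> rand(UInt32), 4) for _ in 1:threads]
    # Inside a kernel, each thread then draws from its own slot:
    #   r = GPUArrays.gpu_rand(Float32, state, randstate)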

test/testsuite/base.jl (+2 -2)

@@ -146,8 +146,8 @@ function test_base(AT)
 
         @testset "heuristics" begin
             blocks, threads = thread_blocks_heuristic(0)
-            @test blocks == (1,)
-            @test threads == (1,)
+            @test blocks == 1
+            @test threads == 1
         end
     end
 end
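
The heuristic under test clamps threads to at most 256 and rounds blocks up, so even a zero-length launch gets one block of one thread. A few values that follow directly from the formula in src/host/execution.jl:

    thread_blocks_heuristic(0)     # (1, 1): clamp keeps at least one thread
    thread_blocks_heuristic(100)   # (1, 100)
    thread_blocks_heuristic(1000)  # (4, 256): threads cap at 256, blocks round up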

test/testsuite/gpuinterface.jl (+4 -4)

@@ -16,25 +16,25 @@ function test_gpuinterface(AT)
     @test all(x-> x == 2, Array(x))
     configuration = ((N ÷ 2,), (2,))
     gpu_call(x, (x,), configuration) do state, x
-        x[linear_index(state)] = threadidx_x(state)
+        x[linear_index(state)] = threadidx(state)
         return
     end
     @test Array(x) == [1,2,1,2,1,2,1,2,1,2]
 
     gpu_call(x, (x,), configuration) do state, x
-        x[linear_index(state)] = blockidx_x(state)
+        x[linear_index(state)] = blockidx(state)
         return
     end
     @test Array(x) == [1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
     x2 = AT([0])
     gpu_call(x, (x2,), configuration) do state, x
-        x[1] = blockdim_x(state)
+        x[1] = blockdim(state)
         return
     end
     @test Array(x2) == [2]
 
     gpu_call(x, (x2,), configuration) do state, x
-        x[1] = griddim_x(state)
+        x[1] = griddim(state)
         return
     end
     @test Array(x2) == [5]
