
Commit 11e6179

Merge pull request #230 from JuliaGPU/tb/1d
Remove non-1D indexing
2 parents ffdad12 + 8e039d5

File tree

8 files changed: +44 -101 lines changed


src/array.jl (19 additions & 25 deletions)

@@ -154,25 +154,24 @@ struct JLBackend <: AbstractGPUBackend end
 
 GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
 
-mutable struct JLState{N}
-    blockdim::NTuple{N, Int}
-    griddim::NTuple{N, Int}
+mutable struct JLState
+    blockdim::Int
+    griddim::Int
 
-    blockidx::NTuple{N, Int}
-    threadidx::NTuple{N, Int}
+    blockidx::Int
+    threadidx::Int
     localmem_counter::Int
     localmems::Vector{Vector{Array}}
 end
 
-function JLState(threads::NTuple{N}, blockdim::NTuple{N}) where N
-    idx = ntuple(i-> 1, Val(N))
+function JLState(threads::Int, blockdim::Int)
     blockcount = prod(blockdim)
     lmems = [Vector{Array}() for i in 1:blockcount]
-    JLState{N}(threads, blockdim, idx, idx, 0, lmems)
+    JLState(threads, blockdim, 1, 1, 0, lmems)
 end
 
-function JLState(state::JLState{N}, threadidx::NTuple{N}) where N
-    JLState{N}(
+function JLState(state::JLState, threadidx::Int)
+    JLState(
         state.blockdim,
         state.griddim,
         state.blockidx,
@@ -187,17 +186,15 @@ to_device(state, x::Tuple) = to_device.(Ref(state), x)
 to_device(state, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(state, x[]))
 to_device(state, x) = x
 
-function GPUArrays._gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
+function GPUArrays._gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{Int, Int})
     blocks, threads = blocks_threads
-    idx = ntuple(i-> 1, length(blocks))
-    blockdim = blocks
-    state = JLState(threads, blockdim)
+    state = JLState(threads, blocks)
     device_args = to_device.(Ref(state), args)
-    tasks = Array{Task}(undef, threads...)
-    for blockidx in CartesianIndices(blockdim)
-        state.blockidx = blockidx.I
-        for threadidx in CartesianIndices(threads)
-            thread_state = JLState(state, threadidx.I)
+    tasks = Array{Task}(undef, threads)
+    for blockidx in 1:blocks
+        state.blockidx = blockidx
+        for threadidx in 1:threads
+            thread_state = JLState(state, threadidx)
             tasks[threadidx] = @async @allowscalar f(thread_state, device_args...)
             # TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading
             # (this would require a different synchronization mechanism)
@@ -246,7 +243,7 @@ end
 
 function GPUArrays.LocalMemory(state::JLState, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
     state.localmem_counter += 1
-    lmems = state.localmems[blockidx_x(state)]
+    lmems = state.localmems[blockidx(state)]
 
     # first invocation in block
    data = if length(lmems) < state.localmem_counter
@@ -272,11 +269,8 @@ Base.size(x::JLDeviceArray) = x.dims
 @inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(A.data, index)
 @inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(A.data, x, index)
 
-for (i, sym) in enumerate((:x, :y, :z))
-    for f in (:blockidx, :blockdim, :threadidx, :griddim)
-        fname = Symbol(string(f, '_', sym))
-        @eval GPUArrays.$fname(state::JLState) = Int(state.$f[$i])
-    end
+for f in (:blockidx, :blockdim, :threadidx, :griddim)
+    @eval GPUArrays.$f(state::JLState) = state.$f
 end

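For reference, a minimal standalone sketch (plain Julia with assumed names, not the package's API) of the 1D launch loop that `_gpu_call` performs after this change: every (block, thread) pair gets its own state, and a kernel recovers its global position as `(blockidx - 1) * blockdim + threadidx`.

# Minimal sketch of the 1D launch loop; `cpu_launch` is illustrative only.
function cpu_launch(kernel, blocks::Int, threads::Int, args...)
    for block in 1:blocks, thread in 1:threads
        # per-invocation "state": just the four 1D quantities a kernel may query
        state = (blockidx = block, threadidx = thread,
                 blockdim = threads, griddim = blocks)
        kernel(state, args...)
    end
end

out = zeros(Int, 8)
cpu_launch(2, 4, out) do state, out
    i = (state.blockidx - 1) * state.blockdim + state.threadidx
    out[i] = i   # out becomes [1, 2, 3, 4, 5, 6, 7, 8]
end
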
src/device/indexing.jl (5 additions & 9 deletions)

@@ -4,12 +4,9 @@ export global_size, synchronize_threads, linear_index
 
 
 # thread indexing functions
-for sym in (:x, :y, :z)
-    for f in (:blockidx, :blockdim, :threadidx, :griddim)
-        fname = Symbol(string(f, '_', sym))
-        @eval $fname(state)::Int = error("Not implemented") # COV_EXCL_LINE
-        @eval export $fname
-    end
+for f in (:blockidx, :blockdim, :threadidx, :griddim)
+    @eval $f(state)::Int = error("Not implemented") # COV_EXCL_LINE
+    @eval export $f
 end
 
 """
@@ -18,8 +15,7 @@ end
 Global size == blockdim * griddim == total number of kernel execution
 """
 @inline function global_size(state)
-    # TODO nd version
-    griddim_x(state) * blockdim_x(state)
+    griddim(state) * blockdim(state)
 end
 
 """
@@ -29,7 +25,7 @@ linear index corresponding to each kernel launch (in OpenCL equal to get_global_
 
 """
 @inline function linear_index(state)
-    (blockidx_x(state) - 1) * blockdim_x(state) + threadidx_x(state)
+    (blockidx(state) - 1) * blockdim(state) + threadidx(state)
 end
 
 """

src/host/execution.jl (6 additions & 8 deletions)

@@ -27,14 +27,12 @@ function gpu_call(kernel, A::AbstractArray, args::Tuple, configuration = length(
     thread_blocks = if isa(configuration, Integer)
         thread_blocks_heuristic(configuration)
     elseif isa(configuration, ITuple)
-        # if a single integer ntuple, we assume it to configure the blocks
-        configuration, ntuple(x-> x == 1 ? 256 : 1, length(configuration))
+        @assert length(configuration) == 1
+        configuration[1], 1
     elseif isa(configuration, Tuple{ITuple, ITuple})
-        # 2 dim tuple of ints == blocks + threads per block
-        if any(x-> length(x) > 3 || length(x) < 1, configuration)
-            error("blocks & threads must be 1-3 dimensional. Found: $configuration")
-        end
-        map(x-> Int.(x), configuration) # make sure it all has the same int type
+        @assert length(configuration[1]) == 1
+        @assert length(configuration[2]) == 1
+        configuration[1][1], configuration[2][1]
     else
         error("""Please launch a gpu kernel with a valid configuration.
                 Found: $configurations
@@ -65,5 +63,5 @@ function thread_blocks_heuristic(len::Integer)
     # TODO better threads default
     threads = clamp(len, 1, 256)
     blocks = max(ceil(Int, len / threads), 1)
-    (blocks,), (threads,)
+    (blocks, threads)
 end
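A hedged usage sketch of the configurations `gpu_call` accepts after this change; `AT` stands in for whatever GPU array type is in use, as in the test suite, and everything beyond what the diff above shows is an assumption.

# Usage sketch only.
x = AT(zeros(Int, 1000))

# Integer / default configuration: thread_blocks_heuristic(1000) clamps the
# thread count to 256 and rounds blocks up, giving (blocks, threads) == (4, 256),
# so the kernel must guard against the 24 excess invocations.
gpu_call(x, (x,)) do state, x
    i = linear_index(state)
    i <= length(x) && (x[i] = i)
    return
end

# Explicit, now strictly 1D configuration: ((blocks,), (threads,)).
gpu_call(x, (x,), ((4,), (250,))) do state, x
    x[linear_index(state)] = blockidx(state)
    return
end
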

src/host/linalg.jl (4 additions & 49 deletions)

@@ -114,56 +114,11 @@ end
 
 ## high-level functionality
 
-function transpose_blocks!(
-        state, odata::AbstractArray{T}, idata, ::Val{SHMEM}, ::Val{TDIM}, ::Val{BLOCK_ROWS}, ::Val{NROW}
-    ) where {T, SHMEM, TDIM, BLOCK_ROWS, NROW}
-
-    tile = @LocalMemory(state, T, SHMEM)
-    bidx_x = blockidx_x(state) - 1
-    bidx_y = blockidx_y(state) - 1
-    tidx_x = threadidx_x(state) - 1
-    tidx_y = threadidx_y(state) - 1
-
-    x = bidx_x * TDIM + tidx_x + 1
-    y = bidx_y * TDIM + tidx_y + 1
-    dims = size(idata)
-
-    (x <= dims[2] && (y + (BLOCK_ROWS * 3)) <= dims[1]) || return
-
-    for j = 0:3
-        j0 = j * BLOCK_ROWS
-        @inbounds tile[tidx_x + 1, tidx_y + j0 + 1] = idata[y + j0, x]
-    end
-
-    synchronize_threads(state)
-    for j = 0:3
-        j0 = j * BLOCK_ROWS
-        @inbounds odata[x, y + j0] = tile[tidx_x + 1, tidx_y + j0 + 1]
-    end
-
-    return
-end
-
 function LinearAlgebra.transpose!(At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
-    if size(A, 1) == size(A, 2) && all(x-> x % 32 == 0, size(A))
-        outsize = size(At)
-        TDIM = 32; BLOCK_ROWS = 8
-        nrows = TDIM ÷ BLOCK_ROWS
-        shmemdim = (TDIM, (TDIM + 1))
-        static_params = map(x-> Val(x), (shmemdim, TDIM, BLOCK_ROWS, nrows))
-        args = (At, A, static_params...)
-
-        griddim = ceil.(Int, size(A) ./ (TDIM, TDIM))
-        blockdim = (TDIM, BLOCK_ROWS)
-        # optimized version for 32x & square dimensions
-        gpu_call(transpose_blocks!, At, args, (griddim, blockdim))
-    else
-        # simple fallback
-        gpu_call(At, (At, A)) do state, At, A
-            idx = @cartesianidx A state
-            @inbounds At[idx[2], idx[1]] = A[idx[1], idx[2]]
-            return
-        end
+    gpu_call(At, (At, A)) do state, At, A
+        idx = @cartesianidx A state
+        @inbounds At[idx[2], idx[1]] = A[idx[1], idx[2]]
+        return
     end
     At
 end
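A brief usage sketch of the remaining generic path; `JLArray` (the CPU reference array from src/array.jl above) is assumed here purely for illustration.

using LinearAlgebra

# Any shape works now; the removed tiled kernel only covered 32-aligned square matrices.
A  = JLArray(rand(Float32, 3, 5))
At = JLArray(zeros(Float32, 5, 3))
transpose!(At, A)
Array(At) == permutedims(Array(A))   # expected to be true
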

src/host/mapreduce.jl (3 additions & 3 deletions)

@@ -142,11 +142,11 @@ for i = 0:10
         global_index += global_size(state)
     end
     # Perform parallel reduction
-    local_index = threadidx_x(state) - 1
+    local_index = threadidx(state) - 1
     @inbounds tmp_local[local_index + 1] = acc
     synchronize_threads(state)
 
-    offset = blockdim_x(state) ÷ 2
+    offset = blockdim(state) ÷ 2
     @inbounds while offset > 0
         if (local_index < offset)
             other = tmp_local[local_index + offset + 1]
@@ -157,7 +157,7 @@ for i = 0:10
         offset = offset ÷ 2
     end
     if local_index == 0
-        @inbounds result[blockidx_x(state)] = tmp_local[1]
+        @inbounds result[blockidx(state)] = tmp_local[1]
     end
     return
 end
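To see what the offset-halving loop above computes, here is a plain-Julia sketch (no GPU state, sequential where the device is parallel) of the same block-level tree reduction: after log2(n) passes, slot 1 holds the block's sum.

# Sketch of the per-block reduction; on the device the inner loop runs as
# parallel threads with synchronize_threads between offset steps.
function block_reduce!(tmp_local::Vector{T}) where T
    offset = length(tmp_local) ÷ 2
    while offset > 0
        for local_index in 0:offset-1
            other = tmp_local[local_index + offset + 1]
            mine  = tmp_local[local_index + 1]
            tmp_local[local_index + 1] = mine + other
        end
        offset ÷= 2
    end
    tmp_local[1]
end

block_reduce!(Float32.(1:8))   # == 36.0f0
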

src/host/random.jl (1 addition & 1 deletion)

@@ -29,7 +29,7 @@ function next_rand(::Type{FT}, state::NTuple{4, T}) where {FT, T <: Unsigned}
 end
 
 function gpu_rand(::Type{T}, state, randstate::AbstractVector{NTuple{4, UInt32}}) where T
-    threadid = GPUArrays.threadidx_x(state)
+    threadid = GPUArrays.threadidx(state)
     stateful_rand = next_rand(T, randstate[threadid])
     randstate[threadid] = stateful_rand[1]
     return stateful_rand[2]
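The per-thread indexing pattern is unchanged apart from the rename; a minimal plain-Julia sketch of it follows (the stub RNG below is a stand-in, not the package's generator).

# Stand-in for next_rand: advances a 4-word state and draws a Float32 in [0, 1).
next_rand_stub(::Type{Float32}, st::NTuple{4, UInt32}) =
    (st .+ UInt32(1), Float32(st[1]) / typemax(UInt32))

randstate = [ntuple(_ -> rand(UInt32), 4) for _ in 1:4]   # one slot per thread
threadid = 2                                              # threadidx(state) on the device
newstate, value = next_rand_stub(Float32, randstate[threadid])
randstate[threadid] = newstate   # each thread only ever touches its own slot
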

test/testsuite/base.jl (2 additions & 2 deletions)

@@ -146,8 +146,8 @@ function test_base(AT)
 
     @testset "heuristics" begin
         blocks, threads = thread_blocks_heuristic(0)
-        @test blocks == (1,)
-        @test threads == (1,)
+        @test blocks == 1
+        @test threads == 1
     end
 end
 end

test/testsuite/gpuinterface.jl (4 additions & 4 deletions)

@@ -16,25 +16,25 @@ function test_gpuinterface(AT)
     @test all(x-> x == 2, Array(x))
     configuration = ((N ÷ 2,), (2,))
     gpu_call(x, (x,), configuration) do state, x
-        x[linear_index(state)] = threadidx_x(state)
+        x[linear_index(state)] = threadidx(state)
         return
     end
     @test Array(x) == [1,2,1,2,1,2,1,2,1,2]
 
     gpu_call(x, (x,), configuration) do state, x
-        x[linear_index(state)] = blockidx_x(state)
+        x[linear_index(state)] = blockidx(state)
         return
     end
     @test Array(x) == [1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
     x2 = AT([0])
     gpu_call(x, (x2,), configuration) do state, x
-        x[1] = blockdim_x(state)
+        x[1] = blockdim(state)
         return
     end
     @test Array(x2) == [2]
 
     gpu_call(x, (x2,), configuration) do state, x
-        x[1] = griddim_x(state)
+        x[1] = griddim(state)
         return
     end
     @test Array(x2) == [5]
