Skip to content

Commit ff71015

Browse files
committed
use cartesian iteration to support blocking
1 parent 550a693 commit ff71015

File tree

7 files changed

+142
-206
lines changed

7 files changed

+142
-206
lines changed

src/KernelAbstractions.jl

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ function async_copy! end
7676
"""
7777
groupsize()
7878
79-
Query the workgroupsize on the device.
79+
Query the workgroupsize on the device. This function returns
80+
a tuple corresponding to kernel configuration. In order to get
81+
the total size you can use `prod(groupsize())`.
8082
"""
8183
function groupsize end
8284

@@ -131,10 +133,6 @@ macro index(locale, args...)
131133
indexkind = :Linear
132134
end
133135

134-
if indexkind === :Cartesian && locale === :Local
135-
error("@index(Local, Cartesian) is not implemented yet")
136-
end
137-
138136
index_function = Symbol(:__index_, locale, :_, indexkind)
139137
Expr(:call, GlobalRef(KernelAbstractions, index_function), map(esc, args)...)
140138
end
@@ -189,14 +187,7 @@ end
189187
workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
190188
ndrange(::Kernel{D, WorkgroupSize, NDRange}) where {D, WorkgroupSize,NDRange} = NDRange
191189

192-
"""
193-
partition(kernel, ndrange)
194-
195-
Splits the maximum size of the iteration space by the workgroupsize.
196-
Returns the number of workgroups necessary and whether the last workgroup
197-
needs to perform dynamic bounds-checking.
198-
"""
199-
@inline function partition(kernel::Kernel, ndrange, workgroupsize)
190+
function partition(kernel, ndrange, workgroupsize)
200191
static_ndrange = KernelAbstractions.ndrange(kernel)
201192
static_workgroupsize = KernelAbstractions.workgroupsize(kernel)
202193

@@ -208,42 +199,49 @@ needs to perform dynamic bounds-checking.
208199
You created a dynamically sized kernel, but forgot to provide runtime
209200
parameters for the kernel. Either provide them statically if known
210201
or dynamically.
211-
NDRange(Static): $(typeof(static_ndrange))
202+
NDRange(Static): $(static_ndrange)
212203
NDRange(Dynamic): $(ndrange)
213-
Workgroupsize(Static): $(typeof(static_workgroupsize))
204+
Workgroupsize(Static): $(static_workgroupsize)
214205
Workgroupsize(Dynamic): $(workgroupsize)
215206
"""
216207
error(errmsg)
217208
end
218209

219-
if ndrange !== nothing && static_ndrange <: StaticSize
220-
if prod(ndrange) != prod(get(static_ndrange))
210+
if static_ndrange <: StaticSize
211+
if ndrange !== nothing && ndrange != get(static_ndrange)
221212
error("Static NDRange and launch NDRange differ")
222213
end
214+
ndrange = get(static_ndrange)
223215
end
224216

225217
if static_workgroupsize <: StaticSize
226-
@assert length(get(static_workgroupsize)) === 1
227-
static_workgroupsize = get(static_workgroupsize)[1]
228-
if workgroupsize !== nothing && workgroupsize != static_workgroupsize
218+
if workgroupsize !== nothing && workgroupsize != get(static_workgroupsize)
229219
error("Static WorkgroupSize and launch WorkgroupSize differ")
230220
end
231-
workgroupsize = static_workgroupsize
221+
workgroupsize = get(static_workgroupsize)
232222
end
223+
233224
@assert workgroupsize !== nothing
225+
@assert ndrange !== nothing
226+
blocks, workgroupsize, dynamic = NDIteration.partition(ndrange, workgroupsize)
234227

235228
if static_ndrange <: StaticSize
236-
maxsize = prod(get(static_ndrange))
237-
else
238-
maxsize = prod(ndrange)
229+
static_blocks = StaticSize{blocks}
230+
blocks = nothing
231+
else
232+
static_blocks = DynamicSize
233+
blocks = CartesianIndices(blocks)
239234
end
240235

241-
nworkgroups = fld1(maxsize, workgroupsize)
242-
dynamic = mod(maxsize, workgroupsize) != 0
243-
244-
dynamic || @assert(nworkgroups * workgroupsize == maxsize)
236+
if static_workgroupsize <: StaticSize
237+
static_workgroupsize = StaticSize{workgroupsize} # we might have padded workgroupsize
238+
workgroupsize = nothing
239+
else
240+
workgroupsize = CartesianIndices(workgroupsize)
241+
end
245242

246-
return nworkgroups, dynamic
243+
iterspace = NDRange{length(ndrange), static_blocks, static_workgroupsize}(blocks, workgroupsize)
244+
return iterspace, dynamic
247245
end
248246

249247
###
@@ -256,10 +254,7 @@ include("compiler.jl")
256254
# Compiler/Frontend
257255
###
258256

259-
@inline function __workitems_iterspace()
260-
return 1:groupsize()
261-
end
262-
257+
function __workitems_iterspace end
263258
function __validindex end
264259

265260
include("macros.jl")

src/backends/cpu.jl

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,86 +14,90 @@ function wait(ev::CPUEvent, progress=nothing)
1414
end
1515

1616
function (obj::Kernel{CPU})(args...; ndrange=nothing, workgroupsize=nothing, dependencies=nothing)
17-
if ndrange isa Int
17+
if ndrange isa Integer
1818
ndrange = (ndrange,)
1919
end
20+
if workgroupsize isa Integer
21+
workgroupsize = (workgroupsize, )
22+
end
2023
if dependencies isa Event
2124
dependencies = (dependencies,)
2225
end
26+
2327
if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
24-
workgroupsize = 1024 # Vectorization, 4x unrolling, minimal grain size
28+
workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size
2529
end
26-
nblocks, dynamic = partition(obj, ndrange, workgroupsize)
30+
iterspace, dynamic = partition(obj, ndrange, workgroupsize)
2731
# partition checked that the ndrange's agreed
2832
if KernelAbstractions.ndrange(obj) <: StaticSize
2933
ndrange = nothing
3034
end
31-
if KernelAbstractions.workgroupsize(obj) <: StaticSize
32-
workgroupsize = nothing
33-
end
34-
t = Threads.@spawn begin
35+
36+
t = __run(obj, ndrange, iterspace, args, dependencies)
37+
return CPUEvent(t)
38+
end
39+
40+
# Inference barrier
41+
function __run(obj, ndrange, iterspace, args, dependencies)
42+
return Threads.@spawn begin
3543
if dependencies !== nothing
3644
Base.sync_end(map(e->e.task, dependencies))
3745
end
3846
@sync begin
39-
for I in 1:(nblocks-1)
40-
let ctx = mkcontext(obj, I, ndrange, workgroupsize)
41-
Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
47+
# TODO: how do we use the information that the iteration space maps perfectly to
48+
# the ndrange without incurring a 2x compilation overhead
49+
# if dynamic
50+
for block in iterspace
51+
let ctx = mkcontextdynamic(obj, block, ndrange, iterspace)
52+
Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
53+
end
4254
end
43-
end
44-
45-
if dynamic
46-
let ctx = mkcontextdynamic(obj, nblocks, ndrange, workgroupsize)
47-
Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
48-
end
49-
else
50-
let ctx = mkcontext(obj, nblocks, ndrange, workgroupsize)
51-
Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
52-
end
53-
end
55+
# else
56+
# for block in iterspace
57+
# let ctx = mkcontext(obj, blocks, ndrange, iterspace)
58+
# Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
59+
# end
60+
# end
61+
# end
5462
end
5563
end
56-
return CPUEvent(t)
5764
end
5865

5966
Cassette.@context CPUCtx
6067

61-
function mkcontext(kernel::Kernel{CPU}, I, _ndrange, _workgroupsize)
62-
metadata = CompilerMetadata{workgroupsize(kernel), ndrange(kernel), false}(I, _ndrange, _workgroupsize)
68+
function mkcontext(kernel::Kernel{CPU}, I, _ndrange, iterspace)
69+
metadata = CompilerMetadata{ndrange(kernel), false}(I, _ndrange, iterspace)
6370
Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
6471
end
6572

66-
function mkcontextdynamic(kernel::Kernel{CPU}, I, _ndrange, _workgroupsize)
67-
metadata = CompilerMetadata{workgroupsize(kernel), ndrange(kernel), true}(I, _ndrange, _workgroupsize)
73+
function mkcontextdynamic(kernel::Kernel{CPU}, I, _ndrange, iterspace)
74+
metadata = CompilerMetadata{ndrange(kernel), true}(I, _ndrange, iterspace)
6875
Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
6976
end
7077

71-
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Local_Linear), idx)
72-
return idx
78+
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Local_Linear), idx::CartesianIndex)
79+
indices = workitems(__iterspace(ctx.metadata))
80+
return @inbounds LinearIndices(indices)[idx]
7381
end
7482

75-
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Linear), idx)
76-
workgroup = __groupindex(ctx.metadata)
77-
(workgroup - 1) * __groupsize(ctx.metadata) + idx
83+
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Linear), idx::CartesianIndex)
84+
I = @inbounds expand(__iterspace(ctx.metadata), __groupindex(ctx.metadata), idx)
85+
@inbounds LinearIndices(__ndrange(ctx.metadata))[I]
7886
end
7987

80-
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Local_Cartesian), idx)
81-
error("@index(Local, Cartesian) is not yet defined")
88+
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Local_Cartesian), idx::CartesianIndex)
89+
return idx
8290
end
8391

84-
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Cartesian), idx)
85-
workgroup = __groupindex(ctx.metadata)
86-
indices = __ndrange(ctx.metadata)
87-
lI = (workgroup - 1) * __groupsize(ctx.metadata) + idx
88-
return @inbounds indices[lI]
92+
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Cartesian), idx::CartesianIndex)
93+
return @inbounds expand(__iterspace(ctx.metadata), __groupindex(ctx.metadata), idx)
8994
end
9095

91-
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__validindex), idx)
96+
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__validindex), idx::CartesianIndex)
9297
# Turns this into a no-op for code where we can turn off checkbounds
9398
if __dynamic_checkbounds(ctx.metadata)
94-
maxidx = prod(size(__ndrange(ctx.metadata)))
95-
valid = idx <= mod1(maxidx, __groupsize(ctx.metadata))
96-
return valid
99+
I = @inbounds expand(__iterspace(ctx.metadata), __groupindex(ctx.metadata), idx)
100+
return I in __ndrange(ctx.metadata)
97101
else
98102
return true
99103
end

src/backends/cuda.jl

Lines changed: 34 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,12 @@ function wait(ev::CudaEvent, progress=nothing)
5959
end
6060

6161
function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, workgroupsize=nothing)
62-
if ndrange isa Int
62+
if ndrange isa Integer
6363
ndrange = (ndrange,)
6464
end
65+
if workgroupsize isa Integer
66+
workgroupsize = (workgroupsize, )
67+
end
6568
if dependencies isa Event
6669
dependencies = (dependencies,)
6770
end
@@ -74,92 +77,63 @@ function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, wor
7477
end
7578
end
7679

77-
event = CuEvent(CUDAdrv.EVENT_DISABLE_TIMING)
78-
79-
# Launch kernel
80-
ctx = mkcontext(obj, ndrange)
81-
args = (ctx, obj.f, args...)
82-
GC.@preserve args begin
83-
kernel_args = map(CUDAnative.cudaconvert, args)
84-
kernel_tt = Tuple{map(Core.Typeof, kernel_args)...}
85-
86-
# If the kernel is statically sized we can tell the compiler about that
87-
if KernelAbstractions.workgroupsize(obj) <: StaticSize
88-
static_workgroupsize = get(KernelAbstractions.workgroupsize(obj))[1]
89-
else
90-
static_workgroupsize = nothing
91-
end
80+
if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
81+
# TODO: allow for NDRange{1, DynamicSize, DynamicSize}(nothing, nothing)
82+
# and actually use CUDAnative autotuning
83+
workgroupsize = (256,)
84+
end
85+
# If the kernel is statically sized we can tell the compiler about that
86+
if KernelAbstractions.workgroupsize(obj) <: StaticSize
87+
maxthreads = prod(get(KernelAbstractions.workgroupsize(obj)))
88+
else
89+
maxthreads = nothing
90+
end
9291

93-
kernel = CUDAnative.cufunction(Cassette.overdub, kernel_tt; name=String(nameof(obj.f)), maxthreads=static_workgroupsize)
92+
iterspace, dynamic = partition(obj, ndrange, workgroupsize)
9493

95-
# Dynamically sized and size not prescribed, use autotuning
96-
if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
97-
workgroupsize = CUDAnative.maxthreads(kernel)
98-
end
94+
nblocks = length(blocks(iterspace))
95+
threads = length(workitems(iterspace))
9996

100-
if workgroupsize === nothing
101-
threads = static_workgroupsize
102-
else
103-
threads = workgroupsize
104-
end
105-
@assert threads !== nothing
97+
ctx = mkcontext(obj, ndrange, iterspace)
98+
# Launch kernel
99+
event = CuEvent(CUDAdrv.EVENT_DISABLE_TIMING)
100+
CUDAnative.@cuda(threads=threads, blocks=nblocks, stream=stream,
101+
name=String(nameof(obj.f)), maxthreads=maxthreads,
102+
Cassette.overdub(ctx, obj.f, args...))
106103

107-
blocks, _ = partition(obj, ndrange, threads)
108-
kernel(kernel_args..., threads=threads, blocks=blocks, stream=stream)
109-
end
110104
CUDAdrv.record(event, stream)
111105
return CudaEvent(event)
112106
end
113107

114108
Cassette.@context CUDACtx
115109

116-
function mkcontext(kernel::Kernel{CUDA}, _ndrange)
117-
metadata = CompilerMetadata{workgroupsize(kernel), ndrange(kernel), true}(_ndrange)
110+
function mkcontext(kernel::Kernel{CUDA}, _ndrange, iterspace)
111+
metadata = CompilerMetadata{ndrange(kernel), true}(_ndrange, iterspace)
118112
Cassette.disablehooks(CUDACtx(pass = CompilerPass, metadata=metadata))
119113
end
120114

121-
122-
@inline function __gpu_groupsize(::CompilerMetadata{WorkgroupSize}) where {WorkgroupSize<:DynamicSize}
123-
CUDAnative.blockDim().x
124-
end
125-
126-
@inline function __gpu_groupsize(cm::CompilerMetadata{WorkgroupSize}) where {WorkgroupSize<:StaticSize}
127-
__groupsize(cm)
128-
end
129-
130115
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Local_Linear))
131-
idx = CUDAnative.threadIdx().x
132-
return idx
116+
return CUDAnative.threadIdx().x
133117
end
134118

135119
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Global_Linear))
136-
idx = CUDAnative.threadIdx().x
137-
workgroup = CUDAnative.blockIdx().x
138-
# XXX: have a verify mode where we check that our static dimensions are right
139-
# e.g. that blockDim().x === __groupsize(ctx.metadata)
140-
return (workgroup - 1) * __gpu_groupsize(ctx.metadata) + idx
120+
I = @inbounds expand(__iterspace(ctx.metadata), CUDAnative.blockIdx().x, CUDAnative.threadIdx().x)
121+
# TODO: This is unfortunate, can we get the linear index cheaper
122+
@inbounds LinearIndices(__ndrange(ctx.metadata))[I]
141123
end
142124

143125
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Local_Cartesian))
144-
error("@index(Local, Cartesian) is not yet defined")
126+
@inbounds workitems(__iterspace(ctx.metadata))[CUDAnative.threadIdx().x]
145127
end
146128

147129
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Global_Cartesian))
148-
idx = CUDAnative.threadIdx().x
149-
workgroup = CUDAnative.blockIdx().x
150-
lI = (workgroup - 1) * __gpu_groupsize(ctx.metadata) + idx
151-
152-
indices = __ndrange(ctx.metadata)
153-
return @inbounds indices[lI]
130+
return @inbounds expand(__iterspace(ctx.metadata), CUDAnative.blockIdx().x, CUDAnative.threadIdx().x)
154131
end
155132

156133
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__validindex))
157134
if __dynamic_checkbounds(ctx.metadata)
158-
idx = CUDAnative.threadIdx().x
159-
workgroup = CUDAnative.blockIdx().x
160-
lI = (workgroup - 1) * __gpu_groupsize(ctx.metadata) + idx
161-
maxidx = prod(size(__ndrange(ctx.metadata)))
162-
return lI <= maxidx
135+
I = @inbounds expand(__iterspace(ctx.metadata), CUDAnative.blockIdx().x, CUDAnative.threadIdx().x)
136+
return I in __ndrange(ctx.metadata)
163137
else
164138
return true
165139
end

0 commit comments

Comments
 (0)