
Commit 6da1def: introduce __ctx__ argument to kernel functions

1 parent: eb1e356
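In brief, as read from the diff below: each backend's Cassette context is hoisted into a shared constant (CUDACTX, ROCCTX) so it no longer carries per-launch state, and the per-launch CompilerMetadata instead travels as an explicit hidden argument, __ctx__, threaded through every kernel function and device-side intrinsic. The one user-visible consequence is that groupsize() becomes the macro @groupsize(), which expands to a call that reads __ctx__.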

13 files changed: +127 -111 lines

examples/performance.jl
Lines changed: 8 additions & 8 deletions

@@ -30,8 +30,8 @@ end
     I, J = @index(Global, NTuple)
     i, j = @index(Local, NTuple)

-    N = @uniform groupsize()[1]
-    M = @uniform groupsize()[2]
+    N = @uniform @groupsize()[1]
+    M = @uniform @groupsize()[2]

     # +1 to avoid bank conflicts on shared memory
     tile = @localmem eltype(output) (N+BANK, M)
@@ -48,8 +48,8 @@ end
     gi, gj = @index(Group, NTuple)
     i, j = @index(Local, NTuple)

-    N = @uniform groupsize()[1]
-    M = @uniform groupsize()[2]
+    N = @uniform @groupsize()[1]
+    M = @uniform @groupsize()[2]

     # +1 to avoid bank conflicts on shared memory
     tile = @localmem eltype(output) (N+BANK, M)
@@ -77,8 +77,8 @@ end
     gi, gj = @index(Group, NTuple)
     i, j = @index(Local, NTuple)

-    TILE_DIM = @uniform groupsize()[1]
-    BLOCK_ROWS = @uniform groupsize()[2]
+    TILE_DIM = @uniform @groupsize()[1]
+    BLOCK_ROWS = @uniform @groupsize()[2]

     # +1 to avoid bank conflicts on shared memory
     tile = @localmem eltype(output) (TILE_DIM+BANK, TILE_DIM)
@@ -103,8 +103,8 @@ end
     gi, gj = @index(Group, NTuple)
     i, j = @index(Local, NTuple)

-    TILE_DIM = @uniform groupsize()[1]
-    BLOCK_ROWS = @uniform groupsize()[2]
+    TILE_DIM = @uniform @groupsize()[1]
+    BLOCK_ROWS = @uniform @groupsize()[2]

     # +1 to avoid bank conflicts on shared memory
     tile = @localmem eltype(output) (TILE_DIM+BANK, TILE_DIM)
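The examples change only in spelling: groupsize() becomes @groupsize(). A minimal kernel sketch against the new spelling (the kernel and the launch lines are illustrative, not from this commit; the launch uses this version's event-based API):

using KernelAbstractions

@kernel function copy_kernel!(dst, @Const(src))
    I = @index(Global, Linear)
    # @groupsize() expands to groupsize(__ctx__), reading the workgroup
    # size from the hidden context argument this commit introduces.
    N = @uniform prod(@groupsize())
    @inbounds dst[I] = src[I]
end

# dst, src = zeros(64), rand(64)
# kernel = copy_kernel!(CPU(), 16)             # backend + workgroup size
# wait(kernel(dst, src, ndrange=length(dst)))  # launch returns an Event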

examples/performant_matmul.jl
Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ const TILE_DIM = 32
     gi, gj = @index(Group, NTuple)
     i, j = @index(Local, NTuple)

-    TILE_DIM = @uniform groupsize()[1]
+    TILE_DIM = @uniform @groupsize()[1]

     # +1 to avoid bank conflicts on shared memory
     tile1 = @localmem eltype(output) (TILE_DIM+BANK, TILE_DIM)

lib/CUDAKernels/src/CUDAKernels.jl
Lines changed: 25 additions & 23 deletions

@@ -191,7 +191,7 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(
     ndrange, workgroupsize, iterspace, dynamic = launch_config(obj, ndrange, workgroupsize)
     # this might not be the final context, since we may tune the workgroupsize
     ctx = mkcontext(obj, ndrange, iterspace)
-    kernel = CUDA.@cuda launch=false name=String(nameof(obj.f)) Cassette.overdub(ctx, obj.f, args...)
+    kernel = CUDA.@cuda launch=false name=String(nameof(obj.f)) Cassette.overdub(CUDACTX, obj.f, ctx, args...)

     # figure out the optimal workgroupsize automatically
     if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
@@ -220,7 +220,7 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(

     # Launch kernel
     event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
-    kernel(ctx, obj.f, args...; threads=threads, blocks=nblocks, stream=stream)
+    kernel(CUDACTX, obj.f, ctx, args...; threads=threads, blocks=nblocks, stream=stream)

     CUDA.record(event, stream)
     return CudaEvent(event)
@@ -232,41 +232,43 @@ import KernelAbstractions: CompilerMetadata, CompilerPass, DynamicCheck, LinearI
 import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
 import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds

+const CUDACTX = Cassette.disablehooks(CUDACtx(pass = CompilerPass))
+KernelAbstractions.cassette(::Kernel{CUDADevice}) = CUDACTX
+
 function mkcontext(kernel::Kernel{CUDADevice}, _ndrange, iterspace)
-    metadata = CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
-    Cassette.disablehooks(CUDACtx(pass = CompilerPass, metadata=metadata))
+    CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Local_Linear))
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Local_Linear), ctx)
     return CUDA.threadIdx().x
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Group_Linear))
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Group_Linear), ctx)
     return CUDA.blockIdx().x
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Global_Linear))
-    I = @inbounds expand(__iterspace(ctx.metadata), CUDA.blockIdx().x, CUDA.threadIdx().x)
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Global_Linear), ctx)
+    I = @inbounds expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
     # TODO: This is unfortunate, can we get the linear index cheaper
-    @inbounds LinearIndices(__ndrange(ctx.metadata))[I]
+    @inbounds LinearIndices(__ndrange(ctx))[I]
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Local_Cartesian))
-    @inbounds workitems(__iterspace(ctx.metadata))[CUDA.threadIdx().x]
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Local_Cartesian), ctx)
+    @inbounds workitems(__iterspace(ctx))[CUDA.threadIdx().x]
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Group_Cartesian))
-    @inbounds blocks(__iterspace(ctx.metadata))[CUDA.blockIdx().x]
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Group_Cartesian), ctx)
+    @inbounds blocks(__iterspace(ctx))[CUDA.blockIdx().x]
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Global_Cartesian))
-    return @inbounds expand(__iterspace(ctx.metadata), CUDA.blockIdx().x, CUDA.threadIdx().x)
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Global_Cartesian), ctx)
+    return @inbounds expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__validindex))
-    if __dynamic_checkbounds(ctx.metadata)
-        I = @inbounds expand(__iterspace(ctx.metadata), CUDA.blockIdx().x, CUDA.threadIdx().x)
-        return I in __ndrange(ctx.metadata)
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__validindex), ctx)
+    if __dynamic_checkbounds(ctx)
+        I = @inbounds expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
+        return I in __ndrange(ctx)
     else
         return true
     end
@@ -323,7 +325,7 @@ import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize
 # GPU implementation of shared memory
 ###

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
+@inline function Cassette.overdub(::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
     ptr = emit_shmem(Val(Id), T, Val(prod(Dims)))
     CUDA.CuDeviceArray(Dims, ptr)
 end
@@ -333,15 +335,15 @@ end
 # - private memory for each workitem
 ###

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
+@inline function Cassette.overdub(::CUDACtx, ::typeof(Scratchpad), ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
     MArray{__size(Dims), T}(undef)
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__synchronize))
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__synchronize))
     CUDA.sync_threads()
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__print), args...)
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__print), args...)
     CUDA._cuprint(args...)
 end
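Reduced to a toy, the new calling convention looks like this: overdub receives a shared, stateless context plus the metadata as an ordinary argument, so a single context instance serves every launch instead of a fresh context being built around each CompilerMetadata. A self-contained sketch runnable on the CPU (ToyCtx, my_index, toy_kernel, and the NamedTuple metadata are hypothetical stand-ins; the Cassette calls themselves are real):

using Cassette

Cassette.@context ToyCtx

# Shared, stateless context, analogous to CUDACTX / ROCCTX.
const TOYCTX = Cassette.disablehooks(ToyCtx())

my_index(ctx) = error("only meaningful under the context")

# Dispatch on the overdubbed *function*; the per-launch metadata now
# arrives as a plain argument instead of via ctx.metadata.
@inline Cassette.overdub(::ToyCtx, ::typeof(my_index), ctx) = ctx.offset + 1

toy_kernel(ctx, x) = x + my_index(ctx)

metadata = (offset = 41,)  # stand-in for CompilerMetadata
@assert Cassette.overdub(TOYCTX, toy_kernel, metadata, 1) == 43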

lib/ROCKernels/src/ROCKernels.jl
Lines changed: 23 additions & 22 deletions

@@ -204,7 +204,7 @@ function (obj::Kernel{ROCDevice})(args...; ndrange=nothing, dependencies=nothing
     # Launch kernel
     event = AMDGPU.@roc(groupsize=threads, gridsize=nblocks*threads, queue=queue,
                         name=String(nameof(obj.f)), # TODO: maxthreads=maxthreads,
-                        Cassette.overdub(ctx, obj.f, args...))
+                        Cassette.overdub(ROCCTX, obj.f, ctx, args...))

     return ROCEvent(event.event)
 end
@@ -215,45 +215,46 @@ import KernelAbstractions: CompilerMetadata, CompilerPass, DynamicCheck, LinearI
 import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
 import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds

+const ROCCTX = Cassette.disablehooks(ROCCtx(pass = CompilerPass))
+KernelAbstractions.cassette(::Kernel{ROCDevice}) = ROCCTX
+
 function mkcontext(kernel::Kernel{ROCDevice}, _ndrange, iterspace)
     metadata = CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
-    Cassette.disablehooks(ROCCtx(pass = CompilerPass, metadata=metadata))
 end
 function mkcontext(kernel::Kernel{ROCDevice}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
     metadata = CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
-    Cassette.disablehooks(ROCCtx(pass = CompilerPass, metadata=metadata))
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Local_Linear))
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Local_Linear), ctx)
     return AMDGPU.threadIdx().x
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Group_Linear))
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Group_Linear), ctx)
     return AMDGPU.blockIdx().x
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Global_Linear))
-    I = @inbounds expand(__iterspace(ctx.metadata), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Global_Linear), ctx)
+    I = @inbounds expand(__iterspace(ctx), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
     # TODO: This is unfortunate, can we get the linear index cheaper
-    @inbounds LinearIndices(__ndrange(ctx.metadata))[I]
+    @inbounds LinearIndices(__ndrange(ctx))[I]
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Local_Cartesian))
-    @inbounds workitems(__iterspace(ctx.metadata))[AMDGPU.threadIdx().x]
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Local_Cartesian), ctx)
+    @inbounds workitems(__iterspace(ctx))[AMDGPU.threadIdx().x]
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Group_Cartesian))
-    @inbounds blocks(__iterspace(ctx.metadata))[AMDGPU.blockIdx().x]
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Group_Cartesian), ctx)
+    @inbounds blocks(__iterspace(ctx))[AMDGPU.blockIdx().x]
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Global_Cartesian))
-    return @inbounds expand(__iterspace(ctx.metadata), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Global_Cartesian), ctx)
+    return @inbounds expand(__iterspace(ctx), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__validindex))
-    if __dynamic_checkbounds(ctx.metadata)
-        I = @inbounds expand(__iterspace(ctx.metadata), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
-        return I in __ndrange(ctx.metadata)
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__validindex), ctx)
+    if __dynamic_checkbounds(ctx)
+        I = @inbounds expand(__iterspace(ctx), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
+        return I in __ndrange(ctx)
     else
         return true
     end
@@ -305,7 +306,7 @@ import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize
 ###
 # GPU implementation of shared memory
 ###
-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
+@inline function Cassette.overdub(::ROCCtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
     ptr = AMDGPU.alloc_special(Val(Id), T, Val(AMDGPU.AS.Local), Val(prod(Dims)))
     AMDGPU.ROCDeviceArray(Dims, ptr)
 end
@@ -315,15 +316,15 @@ end
 # - private memory for each workitem
 ###

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
+@inline function Cassette.overdub(::ROCCtx, ::typeof(Scratchpad), ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
     MArray{__size(Dims), T}(undef)
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__synchronize))
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__synchronize))
     AMDGPU.sync_workgroup()
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__print), args...)
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__print), args...)
     for arg in args
         AMDGPU.@rocprintf("%s", arg)
     end
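The ROCm backend mirrors the CUDA changes one-for-one: a shared ROCCTX constant, mkcontext methods that now return the bare CompilerMetadata, and device-side overdub methods that take the metadata as a trailing ctx argument rather than reading ctx.metadata.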

src/KernelAbstractions.jl
Lines changed: 15 additions & 9 deletions

@@ -1,7 +1,7 @@
 module KernelAbstractions

 export @kernel
-export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print
+export @Const, @localmem, @private, @uniform, @synchronize, @index, @groupsize, @print
 export Device, GPU, CPU, Event, MultiEvent, NoneEvent
 export async_copy!

@@ -111,14 +111,20 @@ function async_copy! end
 ###

 """
-    groupsize()
+    @groupsize()

 Query the workgroupsize on the device. This function returns
 a tuple corresponding to kernel configuration. In order to get
-the total size you can use `prod(groupsize())`.
+the total size you can use `prod(@groupsize())`.
 """
 function groupsize end

+macro groupsize()
+    quote
+        $groupsize($(esc(:__ctx__)))
+    end
+end
+
 """
     @localmem T dims

@@ -150,7 +156,7 @@ macro private(T, dims)
         dims = (dims,)
     end
     quote
-        $Scratchpad($(esc(T)), Val($(esc(dims))))
+        $Scratchpad($(esc(:__ctx__)), $(esc(T)), Val($(esc(dims))))
     end
 end

@@ -297,7 +303,7 @@ macro index(locale, args...)
     end

     index_function = Symbol(:__index_, locale, :_, indexkind)
-    Expr(:call, GlobalRef(KernelAbstractions, index_function), map(esc, args)...)
+    Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
 end

 ###
@@ -312,9 +318,9 @@ function __index_Local_Cartesian end
 function __index_Group_Cartesian end
 function __index_Global_Cartesian end

-__index_Local_NTuple(I...) = Tuple(__index_Local_Cartesian(I...))
-__index_Group_NTuple(I...) = Tuple(__index_Group_Cartesian(I...))
-__index_Global_NTuple(I...) = Tuple(__index_Global_Cartesian(I...))
+__index_Local_NTuple(ctx, I...) = Tuple(__index_Local_Cartesian(ctx, I...))
+__index_Group_NTuple(ctx, I...) = Tuple(__index_Group_Cartesian(ctx, I...))
+__index_Global_NTuple(ctx, I...) = Tuple(__index_Global_Cartesian(ctx, I...))

 struct ConstAdaptor end

@@ -429,7 +435,7 @@ include("macros.jl")
 # Backends/Interface
 ###

-function Scratchpad(::Type{T}, ::Val{Dims}) where {T, Dims}
+function Scratchpad(ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
     throw(MethodError(Scratchpad, (T, Val(Dims))))
 end
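The load-bearing trick in @groupsize, @private, and @index is esc(:__ctx__): each macro expands to a call on a variable the user never writes, and escaping the symbol defeats macro hygiene so it binds to the __ctx__ argument that @kernel adds to the generated kernel function. The same pattern in a standalone sketch (current_ctx and run_with_ctx are hypothetical, not part of the package):

# A macro that reaches for a variable named `ctx` in the caller's scope.
# Without esc(), hygiene would gensym the symbol and the lookup would fail.
macro current_ctx()
    esc(:ctx)
end

run_with_ctx(f) = f((groupsize = (4, 4),))  # NamedTuple stands in for CompilerMetadata

# The macro binds to the `ctx` parameter of the anonymous function:
total = run_with_ctx(ctx -> prod((@current_ctx).groupsize))
@assert total == 16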

src/compiler.jl
Lines changed: 2 additions & 2 deletions

@@ -32,8 +32,8 @@ include("compiler/pass.jl")

 function generate_overdubs(mod, Ctx)
     @eval mod begin
-        @inline Cassette.overdub(ctx::$Ctx, ::typeof(groupsize)) = __groupsize(ctx.metadata)
-        @inline Cassette.overdub(ctx::$Ctx, ::typeof(__workitems_iterspace)) = workitems(__iterspace(ctx.metadata))
+        @inline Cassette.overdub(::$Ctx, ::typeof(groupsize), ctx) = __groupsize(ctx)
+        @inline Cassette.overdub(::$Ctx, ::typeof(__workitems_iterspace), ctx) = workitems(__iterspace(ctx))

         ###
         # Cassette fixes
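Together with the @groupsize macro above, this completes the chain: @groupsize() in a kernel body expands to groupsize(__ctx__); the Cassette pass turns that call into Cassette.overdub(Ctx, groupsize, __ctx__); and this method answers it with __groupsize(ctx), reading the workgroup size from the CompilerMetadata.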
