
Commit 51a59dc

bors[bot] and vchuravy authored

Merge #244

244: Remove metadata from cassette context r=vchuravy a=vchuravy

bors try

Co-authored-by: Valentin Churavy <[email protected]>
2 parents eb1e356 + 54d6811 commit 51a59dc
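
In brief: the Cassette context no longer carries the per-launch CompilerMetadata. Each backend now defines a single constant, metadata-free context (CUDACTX, ROCCTX) that is reused across launches, and the metadata is threaded through as an explicit argument of the overdubbed kernel. A self-contained toy model of the new convention (the names ToyCtx, TOYCTX, Meta, and kernelbody are illustrative, not KernelAbstractions API; only Cassette is required, no GPU):

using Cassette

Cassette.@context ToyCtx

# One constant, metadata-free context shared by every "launch".
const TOYCTX = Cassette.disablehooks(ToyCtx())

# Placeholder that kernel bodies call; the backend overdubs it and reads the
# per-launch metadata from the explicit `ctx` argument.
function __index_Global_Linear end

struct Meta
    ndrange::Int
end

@inline Cassette.overdub(::ToyCtx, ::typeof(__index_Global_Linear), ctx::Meta) =
    ctx.ndrange

kernelbody(ctx) = __index_Global_Linear(ctx) + 1

# Launch: the shared context comes first, the metadata travels as an argument.
@assert Cassette.overdub(TOYCTX, kernelbody, Meta(41)) == 42

Because the context itself never changes, it can be a const; only the ordinary ctx argument varies per launch.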

File tree: 16 files changed, +136 -116 lines

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "KernelAbstractions"
 uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 authors = ["Valentin Churavy <[email protected]>"]
-version = "0.6.1"
+version = "0.7.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

examples/performance.jl

Lines changed: 8 additions & 8 deletions
@@ -30,8 +30,8 @@ end
     I, J = @index(Global, NTuple)
     i, j = @index(Local, NTuple)

-    N = @uniform groupsize()[1]
-    M = @uniform groupsize()[2]
+    N = @uniform @groupsize()[1]
+    M = @uniform @groupsize()[2]

     # +1 to avoid bank conflicts on shared memory
     tile = @localmem eltype(output) (N+BANK, M)
@@ -48,8 +48,8 @@ end
     gi, gj = @index(Group, NTuple)
     i, j = @index(Local, NTuple)

-    N = @uniform groupsize()[1]
-    M = @uniform groupsize()[2]
+    N = @uniform @groupsize()[1]
+    M = @uniform @groupsize()[2]

     # +1 to avoid bank conflicts on shared memory
     tile = @localmem eltype(output) (N+BANK, M)
@@ -77,8 +77,8 @@ end
     gi, gj = @index(Group, NTuple)
     i, j = @index(Local, NTuple)

-    TILE_DIM = @uniform groupsize()[1]
-    BLOCK_ROWS = @uniform groupsize()[2]
+    TILE_DIM = @uniform @groupsize()[1]
+    BLOCK_ROWS = @uniform @groupsize()[2]

     # +1 to avoid bank conflicts on shared memory
     tile = @localmem eltype(output) (TILE_DIM+BANK, TILE_DIM)
@@ -103,8 +103,8 @@ end
     gi, gj = @index(Group, NTuple)
     i, j = @index(Local, NTuple)

-    TILE_DIM = @uniform groupsize()[1]
-    BLOCK_ROWS = @uniform groupsize()[2]
+    TILE_DIM = @uniform @groupsize()[1]
+    BLOCK_ROWS = @uniform @groupsize()[2]

     # +1 to avoid bank conflicts on shared memory
     tile = @localmem eltype(output) (TILE_DIM+BANK, TILE_DIM)
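
Kernels must now write @groupsize() instead of calling groupsize() directly, since the underlying function needs the hidden __ctx__ argument (see the macro groupsize added in src/KernelAbstractions.jl below). A minimal CPU-only usage sketch, assuming KernelAbstractions 0.7; the kernel scale_by_groupsize! is illustrative, not part of this repository:

using KernelAbstractions

@kernel function scale_by_groupsize!(a)
    i = @index(Global, Linear)
    # @groupsize() replaces the old groupsize() call; it expands to
    # groupsize(__ctx__) inside the kernel body.
    N = @uniform prod(@groupsize())
    a[i] = a[i] * N
end

a = ones(16)
kernel! = scale_by_groupsize!(CPU(), 4)   # static workgroup size of 4
wait(kernel!(a, ndrange=length(a)))
@assert all(==(4.0), a)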

examples/performant_matmul.jl

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ const TILE_DIM = 32
     gi, gj = @index(Group, NTuple)
     i, j = @index(Local, NTuple)

-    TILE_DIM = @uniform groupsize()[1]
+    TILE_DIM = @uniform @groupsize()[1]

     # +1 to avoid bank conflicts on shared memory
     tile1 = @localmem eltype(output) (TILE_DIM+BANK, TILE_DIM)

lib/CUDAKernels/Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "CUDAKernels"
 uuid = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
 authors = ["Valentin Churavy <[email protected]>"]
-version = "0.2.1"
+version = "0.3.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -15,7 +15,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Adapt = "3.0"
 CUDA = "3.0"
 Cassette = "0.3.3"
-KernelAbstractions = "0.6"
+KernelAbstractions = "0.7"
 SpecialFunctions = "0.10, 1.0"
 StaticArrays = "0.12, 1.0"
 julia = "1.6"

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 25 additions & 23 deletions
@@ -191,7 +191,7 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(
     ndrange, workgroupsize, iterspace, dynamic = launch_config(obj, ndrange, workgroupsize)
     # this might not be the final context, since we may tune the workgroupsize
     ctx = mkcontext(obj, ndrange, iterspace)
-    kernel = CUDA.@cuda launch=false name=String(nameof(obj.f)) Cassette.overdub(ctx, obj.f, args...)
+    kernel = CUDA.@cuda launch=false name=String(nameof(obj.f)) Cassette.overdub(CUDACTX, obj.f, ctx, args...)

     # figure out the optimal workgroupsize automatically
     if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
@@ -220,7 +220,7 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(

     # Launch kernel
     event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
-    kernel(ctx, obj.f, args...; threads=threads, blocks=nblocks, stream=stream)
+    kernel(CUDACTX, obj.f, ctx, args...; threads=threads, blocks=nblocks, stream=stream)

     CUDA.record(event, stream)
     return CudaEvent(event)
@@ -232,41 +232,43 @@ import KernelAbstractions: CompilerMetadata, CompilerPass, DynamicCheck, LinearI
 import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
 import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds

+const CUDACTX = Cassette.disablehooks(CUDACtx(pass = CompilerPass))
+KernelAbstractions.cassette(::Kernel{CUDADevice}) = CUDACTX
+
 function mkcontext(kernel::Kernel{CUDADevice}, _ndrange, iterspace)
-    metadata = CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
-    Cassette.disablehooks(CUDACtx(pass = CompilerPass, metadata=metadata))
+    CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Local_Linear))
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Local_Linear), ctx)
     return CUDA.threadIdx().x
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Group_Linear))
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Group_Linear), ctx)
     return CUDA.blockIdx().x
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Global_Linear))
-    I = @inbounds expand(__iterspace(ctx.metadata), CUDA.blockIdx().x, CUDA.threadIdx().x)
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Global_Linear), ctx)
+    I = @inbounds expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
     # TODO: This is unfortunate, can we get the linear index cheaper
-    @inbounds LinearIndices(__ndrange(ctx.metadata))[I]
+    @inbounds LinearIndices(__ndrange(ctx))[I]
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Local_Cartesian))
-    @inbounds workitems(__iterspace(ctx.metadata))[CUDA.threadIdx().x]
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Local_Cartesian), ctx)
+    @inbounds workitems(__iterspace(ctx))[CUDA.threadIdx().x]
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Group_Cartesian))
-    @inbounds blocks(__iterspace(ctx.metadata))[CUDA.blockIdx().x]
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Group_Cartesian), ctx)
+    @inbounds blocks(__iterspace(ctx))[CUDA.blockIdx().x]
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__index_Global_Cartesian))
-    return @inbounds expand(__iterspace(ctx.metadata), CUDA.blockIdx().x, CUDA.threadIdx().x)
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Global_Cartesian), ctx)
+    return @inbounds expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__validindex))
-    if __dynamic_checkbounds(ctx.metadata)
-        I = @inbounds expand(__iterspace(ctx.metadata), CUDA.blockIdx().x, CUDA.threadIdx().x)
-        return I in __ndrange(ctx.metadata)
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__validindex), ctx)
+    if __dynamic_checkbounds(ctx)
+        I = @inbounds expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
+        return I in __ndrange(ctx)
     else
         return true
     end
@@ -323,7 +325,7 @@ import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize
 # GPU implementation of shared memory
 ###

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
+@inline function Cassette.overdub(::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
     ptr = emit_shmem(Val(Id), T, Val(prod(Dims)))
     CUDA.CuDeviceArray(Dims, ptr)
 end
@@ -333,15 +335,15 @@ end
 # - private memory for each workitem
 ###

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
+@inline function Cassette.overdub(::CUDACtx, ::typeof(Scratchpad), ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
     MArray{__size(Dims), T}(undef)
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__synchronize))
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__synchronize))
     CUDA.sync_threads()
 end

-@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__print), args...)
+@inline function Cassette.overdub(::CUDACtx, ::typeof(__print), args...)
     CUDA._cuprint(args...)
 end
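
With this change, mkcontext returns the CompilerMetadata directly and the launch passes it next to the shared CUDACTX. The user-facing API is unchanged apart from @groupsize; a usage sketch, assuming CUDA.jl 3.x and a CUDA-capable GPU (the kernel mul2! is illustrative):

using CUDA, CUDAKernels, KernelAbstractions

@kernel function mul2!(a)
    i = @index(Global, Linear)
    a[i] *= 2
end

a = CUDA.ones(Float32, 1024)
kernel! = mul2!(CUDADevice(), 256)
# Internally this now launches kernel(CUDACTX, obj.f, ctx, args...).
wait(kernel!(a, ndrange=length(a)))
@assert all(Array(a) .== 2f0)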

lib/ROCKernels/Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "ROCKernels"
 uuid = "7eb9e9f0-4bd3-4c4c-8bef-26bd9629d9b9"
 authors = ["Valentin Churavy <[email protected]>", "Julian P Samaroo <[email protected]>"]
-version = "0.1.0"
+version = "0.2.0"

 [deps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
@@ -15,7 +15,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 AMDGPU = "0.2.5"
 Adapt = "0.4, 1.0, 2.0, 3.0"
 Cassette = "0.3.3"
-KernelAbstractions = "0.6"
+KernelAbstractions = "0.7"
 SpecialFunctions = "0.10, 1.0"
 StaticArrays = "0.12, 1.0"
 julia = "1.6"

lib/ROCKernels/src/ROCKernels.jl

Lines changed: 23 additions & 22 deletions
@@ -204,7 +204,7 @@ function (obj::Kernel{ROCDevice})(args...; ndrange=nothing, dependencies=nothing
     # Launch kernel
     event = AMDGPU.@roc(groupsize=threads, gridsize=nblocks*threads, queue=queue,
                         name=String(nameof(obj.f)), # TODO: maxthreads=maxthreads,
-                        Cassette.overdub(ctx, obj.f, args...))
+                        Cassette.overdub(ROCCTX, obj.f, ctx, args...))

     return ROCEvent(event.event)
 end
@@ -215,45 +215,46 @@ import KernelAbstractions: CompilerMetadata, CompilerPass, DynamicCheck, LinearI
 import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
 import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds

+const ROCCTX = Cassette.disablehooks(ROCCtx(pass = CompilerPass))
+KernelAbstractions.cassette(::Kernel{ROCDevice}) = ROCCTX
+
 function mkcontext(kernel::Kernel{ROCDevice}, _ndrange, iterspace)
     metadata = CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
-    Cassette.disablehooks(ROCCtx(pass = CompilerPass, metadata=metadata))
 end
 function mkcontext(kernel::Kernel{ROCDevice}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
     metadata = CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
-    Cassette.disablehooks(ROCCtx(pass = CompilerPass, metadata=metadata))
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Local_Linear))
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Local_Linear), ctx)
     return AMDGPU.threadIdx().x
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Group_Linear))
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Group_Linear), ctx)
     return AMDGPU.blockIdx().x
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Global_Linear))
-    I = @inbounds expand(__iterspace(ctx.metadata), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Global_Linear), ctx)
+    I = @inbounds expand(__iterspace(ctx), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
     # TODO: This is unfortunate, can we get the linear index cheaper
-    @inbounds LinearIndices(__ndrange(ctx.metadata))[I]
+    @inbounds LinearIndices(__ndrange(ctx))[I]
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Local_Cartesian))
-    @inbounds workitems(__iterspace(ctx.metadata))[AMDGPU.threadIdx().x]
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Local_Cartesian), ctx)
+    @inbounds workitems(__iterspace(ctx))[AMDGPU.threadIdx().x]
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Group_Cartesian))
-    @inbounds blocks(__iterspace(ctx.metadata))[AMDGPU.blockIdx().x]
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Group_Cartesian), ctx)
+    @inbounds blocks(__iterspace(ctx))[AMDGPU.blockIdx().x]
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__index_Global_Cartesian))
-    return @inbounds expand(__iterspace(ctx.metadata), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__index_Global_Cartesian), ctx)
+    return @inbounds expand(__iterspace(ctx), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__validindex))
-    if __dynamic_checkbounds(ctx.metadata)
-        I = @inbounds expand(__iterspace(ctx.metadata), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
-        return I in __ndrange(ctx.metadata)
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__validindex), ctx)
+    if __dynamic_checkbounds(ctx)
+        I = @inbounds expand(__iterspace(ctx), AMDGPU.blockIdx().x, AMDGPU.threadIdx().x)
+        return I in __ndrange(ctx)
     else
         return true
     end
@@ -305,7 +306,7 @@ import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize
 ###
 # GPU implementation of shared memory
 ###
-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
+@inline function Cassette.overdub(::ROCCtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
     ptr = AMDGPU.alloc_special(Val(Id), T, Val(AMDGPU.AS.Local), Val(prod(Dims)))
     AMDGPU.ROCDeviceArray(Dims, ptr)
 end
@@ -315,15 +316,15 @@ end
 # - private memory for each workitem
 ###

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
+@inline function Cassette.overdub(::ROCCtx, ::typeof(Scratchpad), ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
     MArray{__size(Dims), T}(undef)
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__synchronize))
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__synchronize))
     AMDGPU.sync_workgroup()
 end

-@inline function Cassette.overdub(ctx::ROCCtx, ::typeof(__print), args...)
+@inline function Cassette.overdub(::ROCCtx, ::typeof(__print), args...)
     for arg in args
         AMDGPU.@rocprintf("%s", arg)
     end
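
The ROC backend mirrors the CUDA changes with ROCCTX; the only structural difference is that it keeps a second, Dynamic-aware mkcontext method. An equivalent sketch, assuming AMDGPU.jl 0.2.5 and a ROCm-capable GPU (mul2! again illustrative):

using AMDGPU, ROCKernels, KernelAbstractions

@kernel function mul2!(a)
    i = @index(Global, Linear)
    a[i] *= 2
end

a = ROCArray(ones(Float32, 1024))
kernel! = mul2!(ROCDevice(), 256)
# Internally: Cassette.overdub(ROCCTX, obj.f, ctx, args...)
wait(kernel!(a, ndrange=length(a)))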

src/KernelAbstractions.jl

Lines changed: 19 additions & 9 deletions
@@ -1,7 +1,7 @@
 module KernelAbstractions

 export @kernel
-export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print
+export @Const, @localmem, @private, @uniform, @synchronize, @index, @groupsize, @print
 export Device, GPU, CPU, Event, MultiEvent, NoneEvent
 export async_copy!

@@ -111,14 +111,20 @@ function async_copy! end
 ###

 """
-    groupsize()
+    @groupsize()

 Query the workgroupsize on the device. This function returns
 a tuple corresponding to kernel configuration. In order to get
-the total size you can use `prod(groupsize())`.
+the total size you can use `prod(@groupsize())`.
 """
 function groupsize end

+macro groupsize()
+    quote
+        $groupsize($(esc(:__ctx__)))
+    end
+end
+
 """
     @localmem T dims

@@ -150,7 +156,7 @@ macro private(T, dims)
         dims = (dims,)
     end
     quote
-        $Scratchpad($(esc(T)), Val($(esc(dims))))
+        $Scratchpad($(esc(:__ctx__)), $(esc(T)), Val($(esc(dims))))
     end
 end

@@ -297,7 +303,7 @@ macro index(locale, args...)
     end

     index_function = Symbol(:__index_, locale, :_, indexkind)
-    Expr(:call, GlobalRef(KernelAbstractions, index_function), map(esc, args)...)
+    Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
 end

 ###
@@ -312,9 +318,9 @@ function __index_Local_Cartesian end
 function __index_Group_Cartesian end
 function __index_Global_Cartesian end

-__index_Local_NTuple(I...) = Tuple(__index_Local_Cartesian(I...))
-__index_Group_NTuple(I...) = Tuple(__index_Group_Cartesian(I...))
-__index_Global_NTuple(I...) = Tuple(__index_Global_Cartesian(I...))
+__index_Local_NTuple(ctx, I...) = Tuple(__index_Local_Cartesian(ctx, I...))
+__index_Group_NTuple(ctx, I...) = Tuple(__index_Group_Cartesian(ctx, I...))
+__index_Global_NTuple(ctx, I...) = Tuple(__index_Global_Cartesian(ctx, I...))

 struct ConstAdaptor end

@@ -350,6 +356,10 @@ struct Kernel{Device, WorkgroupSize<:_Size, NDRange<:_Size, Fun}
     f::Fun
 end

+function Base.similar(kernel::Kernel{D, WS, ND}, f::F) where {D, WS, ND, F}
+    Kernel{D, WS, ND, F}(f)
+end
+
 workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
 ndrange(::Kernel{D, WorkgroupSize, NDRange}) where {D, WorkgroupSize,NDRange} = NDRange

@@ -429,7 +439,7 @@ include("macros.jl")
 # Backends/Interface
 ###

-function Scratchpad(::Type{T}, ::Val{Dims}) where {T, Dims}
+function Scratchpad(ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
     throw(MethodError(Scratchpad, (T, Val(Dims))))
 end
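Host-side, the macros splice the hidden context in via esc(:__ctx__), which escapes macro hygiene so the expansion refers to a variable named __ctx__ in the caller, i.e. the first parameter of the function that @kernel generates. A stripped-down, runnable model of that trick (toy names, not the real implementation):

# Toy context carrying what CompilerMetadata carries in the real package.
struct ToyCtx
    groupsize::NTuple{2, Int}
end

toy_groupsize(ctx::ToyCtx) = ctx.groupsize

# Mirrors the new `macro groupsize()`: `esc(:__ctx__)` escapes hygiene so the
# expansion picks up the caller's `__ctx__` variable.
macro toy_groupsize()
    quote
        $toy_groupsize($(esc(:__ctx__)))
    end
end

# Stands in for the function @kernel generates: `__ctx__` is its first
# argument, so every macro in the body can reach the launch metadata.
function generated_kernel(__ctx__)
    N = (@toy_groupsize())[1]
    M = (@toy_groupsize())[2]
    return N * M
end

@assert generated_kernel(ToyCtx((8, 4))) == 32

The new Base.similar method fits the same refactor: it rebuilds a Kernel around a different function while preserving the static Device, WorkgroupSize, and NDRange parameters, presumably so backends can swap in a wrapped callable.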