Skip to content

Commit 4069021

Browse files
authored
Merge pull request #817
support cooperative groups
2 parents 5e5b7b7 + a277330 commit 4069021

File tree

4 files changed

+44
-4
lines changed

4 files changed

+44
-4
lines changed

src/device/gcn/synchronization.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,7 @@ and returns non-zero if and only if predicate evaluates to non-zero for any of t
6969
end
7070

7171
@inline __not(x::Cint)::Cint = ifelse(iszero(x), one(x), zero(x))
72+
73+
@inline function sync_grid()::Cvoid
74+
ccall("extern __ockl_grid_sync", llvmcall, Cvoid, ())
75+
end

src/highlevel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ rocconvert(arg) = adapt(Runtime.Adaptor(), arg)
9494

9595
const MACRO_KWARGS = [:launch]
9696
const COMPILER_KWARGS = [:name, :unsafe_fp_atomics]
97-
const LAUNCH_KWARGS = [:gridsize, :groupsize, :shmem, :stream]
97+
const LAUNCH_KWARGS = [:gridsize, :groupsize, :shmem, :stream, :cooperative]
9898

9999
"""
100100
@roc [kwargs...] func(args...)

src/runtime/hip-execution.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,20 @@ function launch(
117117
fun::HIP.HIPFunction, args::Vararg{Any, N};
118118
gridsize = 1, groupsize = 1,
119119
shmem::Integer = 0, stream::HIP.HIPStream,
120+
cooperative = false,
120121
) where N
121122
gd = gridsize isa ROCDim3 ? gridsize : ROCDim3(gridsize)
122123
bd = groupsize isa ROCDim3 ? groupsize : ROCDim3(groupsize)
123124
pack_arguments(args...) do kernel_params
124-
HIP.hipModuleLaunchKernel(
125-
fun, gd.x, gd.y, gd.z, bd.x, bd.y, bd.z,
126-
shmem, stream, kernel_params, C_NULL)
125+
if cooperative
126+
HIP.hipModuleLaunchCooperativeKernel(
127+
fun, gd.x, gd.y, gd.z, bd.x, bd.y, bd.z,
128+
shmem, stream, kernel_params)
129+
else
130+
HIP.hipModuleLaunchKernel(
131+
fun, gd.x, gd.y, gd.z, bd.x, bd.y, bd.z,
132+
shmem, stream, kernel_params, C_NULL)
133+
end
127134
end
128135

129136
AMDGPU.LAUNCH_BLOCKING[] && AMDGPU.synchronize(stream)

test/device/launch.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,32 @@ end
130130
# @test occ1 != occ2
131131
# end
132132
end
133+
134+
if !iszero(AMDGPU.HIP.properties(AMDGPU.device()).cooperativeLaunch)
135+
@testset "Cooperative Groups" begin
136+
function test_kernel!(x)
137+
block_row, block_col = workgroupIdx().x, workgroupIdx().y
138+
139+
for diag in 2:(gridGroupDim().x + gridGroupDim().y)
140+
if block_row + block_col == diag
141+
if diag == 2
142+
x[block_col, block_row] = 1
143+
else
144+
x[block_col, block_row] = 0
145+
for I in CartesianIndices(x)
146+
i, j = Tuple(I)
147+
i + j == diag - 1 || continue
148+
x[block_col, block_row] += x[I]
149+
end
150+
end
151+
end
152+
AMDGPU.Device.sync_grid()
153+
end
154+
end
155+
156+
n = 4
157+
x = ROCArray{Int}(undef, n, n)
158+
@roc groupsize = (1, 1, 1) gridsize = (n, n, 1) cooperative = true test_kernel!(x)
159+
@test Array(x) == [1 1 2 6; 1 2 6 24; 2 6 24 72; 6 24 72 144]
160+
end
161+
end

0 commit comments

Comments
 (0)