Skip to content

Commit 45aba0e

Browse files
committed
Decrease block size from 256 to 64
1M has register pressure that is so high that even a small decrease in occupancy has a big impact:
1 parent 21f1ae3 commit 45aba0e

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

ext/cuda/data_layouts_copyto.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ indicating if the column is padded (true for 63, false for 64).
3030
function knl_copyto_VIJFH_64!(dest, src, ::Val{P}) where {P}
3131
# P is a boolean, indicating if the column is padded
3232
P && threadIdx().x == 64 && return nothing
33-
I = CartesianIndex(threadIdx().y, blockIdx().x, 1, threadIdx().x, blockIdx().y)
33+
I = CartesianIndex(blockIdx().x, blockIdx().y, 1, threadIdx().x, blockIdx().z)
3434
@inbounds dest[I] = src[I]
3535
return nothing
3636
end
@@ -133,8 +133,8 @@ function Base.copyto!(
133133
auto_launch!(
134134
knl_copyto_VIJFH_64!,
135135
args;
136-
threads_s = (64, Ni, 1),
137-
blocks_s = (Nj, Nh, 1),
136+
threads_s = (64, 1, 1),
137+
blocks_s = (Ni, Nj, Nh),
138138
)
139139
return dest
140140
end
@@ -150,8 +150,8 @@ function Base.copyto!(
150150
auto_launch!(
151151
knl_copyto_VIJFH_64!,
152152
args;
153-
threads_s = (64, Ni, 1),
154-
blocks_s = (Nj, Nh, 1),
153+
threads_s = (64, 1, 1),
154+
blocks_s = (Ni, Nj, Nh),
155155
)
156156
return dest
157157
end

ext/cuda/operators_finite_difference.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ function Base.copyto!(
9292
auto_launch!(
9393
copyto_stencil_kernel_64!,
9494
args;
95-
threads_s = (64, Ni, 1),
96-
blocks_s = (Nj, Nh, 1),
95+
threads_s = (64, 1, 1),
96+
blocks_s = (Ni, Nj, Nh),
9797
)
9898
return out
9999
end
@@ -161,10 +161,10 @@ function copyto_stencil_kernel_64!(
161161
@inbounds begin
162162
# P is a boolean, indicating if the column is padded
163163
P && threadIdx().x == 64 && return nothing
164-
i = threadIdx().y
165-
j = blockIdx().x
164+
i = blockIdx().x
165+
j = blockIdx().y
166166
v = threadIdx().x
167-
h = blockIdx().y
167+
h = blockIdx().z
168168
hidx = (i, j, h)
169169
(li, lw, rw, ri) = bds
170170
idx = v - 1 + li

0 commit comments

Comments
 (0)