
Commit e3c30d7: "dogfood" (1 parent: 664a5a1)

6 files changed: 48 additions, 43 deletions


src/Metal.jl (1 addition, 0 deletions)

@@ -12,6 +12,7 @@ using ExprTools: splitdef, combinedef
 using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
 import ObjectiveC: is_macos, darwin_version, macos_version
 import KernelAbstractions
+import KernelAbstractions: KernelIntrinsics as KI
 using ScopedValues

 include("version.jl")

src/accumulate.jl (11 additions, 11 deletions)

@@ -2,13 +2,13 @@
 function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArray,
                       Rdim, Rpre, Rpost, Rother, neutral, init,
                       ::Val{maxthreads}, ::Val{inclusive}=Val(true)) where {T, maxthreads, inclusive}
-    threads = threads_per_threadgroup_3d().x
-    thread = thread_position_in_threadgroup_3d().x
+    threads = get_local_size().x
+    thread = get_local_id().x

-    temp = MtlThreadGroupArray(T, (Int32(2) * maxthreads,))
+    temp = KI.localmemory(T, (Int32(2) * maxthreads,))

-    i = (threadgroup_position_in_grid_3d().x - Int32(1)) * threads_per_threadgroup_3d().x + thread_position_in_threadgroup_3d().x
-    j = (threadgroup_position_in_grid_3d().z - Int32(1)) * threadgroups_per_grid_3d().y + threadgroup_position_in_grid_3d().y
+    i = (get_group_id().x - Int32(1)) * get_local_size().x + get_local_id().x
+    j = (get_group_id().z - Int32(1)) * get_num_groups().y + get_group_id().y

     if j > length(Rother)
         return
@@ -29,7 +29,7 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
     offset = one(thread)
     d = threads >> 0x1
     while d > zero(d)
-        threadgroup_barrier(MemoryFlagThreadGroup)
+        KI.barrier()
         @inbounds if thread <= d
             ai = offset * (thread << 0x1 - 0x1)
             bi = offset * (thread << 0x1)
@@ -46,7 +46,7 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
     d = one(thread)
     while d < threads
         offset >>= 0x1
-        threadgroup_barrier(MemoryFlagThreadGroup)
+        KI.barrier()
         @inbounds if thread <= d
             ai = offset * (thread << 0x1 - 0x1)
             bi = offset * (thread << 0x1)
@@ -58,7 +58,7 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
         d <<= 0x1
     end

-    threadgroup_barrier(MemoryFlagThreadGroup)
+    KI.barrier()

     @inbounds if i <= length(Rdim)
         val = if inclusive
@@ -76,10 +76,10 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
 end

 function aggregate_partial_scan(op::Function, output::AbstractArray, aggregates::AbstractArray, Rdim, Rpre, Rpost, Rother, init)
-    block = threadgroup_position_in_grid_3d().x
+    block = get_group_id().x

-    i = (threadgroup_position_in_grid_3d().x - Int32(1)) * threads_per_threadgroup_3d().x + thread_position_in_threadgroup_3d().x
-    j = (threadgroup_position_in_grid_3d().z - Int32(1)) * threadgroups_per_grid_3d().y + threadgroup_position_in_grid_3d().y
+    i = (get_group_id().x - Int32(1)) * get_local_size().x + get_local_id().x
+    j = (get_group_id().z - Int32(1)) * get_num_groups().y + get_group_id().y

     @inbounds if i <= length(Rdim) && j <= length(Rother)
         I = Rother[j]
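
The scan kernel above is a two-pass (up-sweep/down-sweep) shared-memory scan, and every level of the tree is separated by KI.barrier() because threads consume partial sums written by their neighbours. As a minimal illustration of that barrier discipline, here is a naive single-group inclusive sum scan in the same KI style; a sketch assuming at most 1024 items per group and a numeric element type, not the work-efficient algorithm partial_scan implements:

function naive_inclusive_scan!(out, x)
    t = KI.get_local_id().x
    n = KI.get_local_size().x
    temp = KI.localmemory(eltype(out), (1024,))   # static size: assumed upper bound on n
    @inbounds temp[t] = x[t]
    d = one(t)
    while d < n
        KI.barrier()   # every thread must have finished writing the previous level
        other = t > d ? (@inbounds temp[t - d]) : zero(eltype(out))
        KI.barrier()   # every thread must have read before anyone overwrites
        @inbounds temp[t] += other
        d <<= 0x1
    end
    KI.barrier()
    @inbounds out[t] = temp[t]
    return
end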

src/broadcast.jl (6 additions, 6 deletions)

@@ -66,8 +66,8 @@ end
 if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
     ## COV_EXCL_START
     function broadcast_cartesian_static(dest, bc, Is)
-        i = thread_position_in_grid().x
-        stride = threads_per_grid().x
+        i = KI.get_global_id().x
+        stride = KI.get_global_size().x
         while 1 <= i <= length(dest)
             I = @inbounds Is[i]
             @inbounds dest[I] = bc[I]
@@ -91,8 +91,8 @@ end
    (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
     ## COV_EXCL_START
     function broadcast_linear(dest, bc)
-        i = thread_position_in_grid().x
-        stride = threads_per_grid().x
+        i = KI.get_global_id().x
+        stride = KI.get_global_size().x
         while 1 <= i <= length(dest)
             @inbounds dest[i] = bc[i]
             i += stride
@@ -150,8 +150,8 @@ end
 else
     ## COV_EXCL_START
     function broadcast_cartesian(dest, bc)
-        i = thread_position_in_grid().x
-        stride = threads_per_grid().x
+        i = KI.get_global_id().x
+        stride = KI.get_global_size().x
         while 1 <= i <= length(dest)
             I = @inbounds CartesianIndices(dest)[i]
             @inbounds dest[I] = bc[I]
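
All three broadcast kernels use the same grid-stride idiom: each item starts at its global index and advances by the total number of items in the launch, so any launch size covers the whole array. A minimal standalone sketch of the pattern in the KI style used here (the kernel name and arguments are illustrative):

function gridstride_copy!(dest, src)
    i = KI.get_global_id().x         # this item's 1-based position across the whole grid
    stride = KI.get_global_size().x  # total number of items in the launch
    while 1 <= i <= length(dest)
        @inbounds dest[i] = src[i]
        i += stride                  # jump ahead by one full grid
    end
    return
end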

src/device/random.jl (3 additions, 3 deletions)

@@ -88,9 +88,9 @@ end
     elseif field === :ctr1
         @inbounds global_random_counters()[simdgroupId]
     elseif field === :ctr2
-        globalId = thread_position_in_grid().x +
-                   (thread_position_in_grid().y - 1i32) * threads_per_grid().x +
-                   (thread_position_in_grid().z - 1i32) * threads_per_grid().x * threads_per_grid().y
+        globalId = KI.get_global_id().x +
+                   (KI.get_global_id().y - 1i32) * KI.get_global_size().x +
+                   (KI.get_global_id().z - 1i32) * KI.get_global_size().x * KI.get_global_size().y
         globalId % UInt32
     end::UInt32
 end
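
The :ctr2 branch flattens the 3D global position into a single linear id, x fastest, then y, then z, the same order Julia uses to linearize CartesianIndex. A host-side sketch of the arithmetic with hypothetical values (gid/gsz stand in for the intrinsic results):

# 1-based (x, y, z) position and grid extents; numbers are made up for illustration.
gid = (x = 3, y = 2, z = 1)
gsz = (x = 8, y = 4, z = 2)
globalId = gid.x +
           (gid.y - 1) * gsz.x +
           (gid.z - 1) * gsz.x * gsz.y
@assert globalId == 11   # 3 + 1*8 + 0*32
globalId % UInt32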

src/indexing.jl (1 addition, 1 deletion)

@@ -33,7 +33,7 @@ function Base.findall(bools::WrappedMtlArray{Bool})

     if n > 0
         function kernel(ys::MtlDeviceArray, bools, indices)
-            i = (threadgroup_position_in_grid().x - Int32(1)) * threads_per_threadgroup().x + thread_position_in_threadgroup().x
+            i = (KI.get_group_id().x - Int32(1)) * KI.get_local_size().x + KI.get_local_id().x

             @inbounds if i <= length(bools) && bools[i]
                 i′ = CartesianIndices(bools)[i]
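
The replacement line computes the 1-based global x index by hand from the group and local ids. Under the usual 1-based KI conventions this should coincide with the direct intrinsic; an assumption noted for orientation, since the commit keeps the explicit form:

# Inside a kernel, both expressions are expected to agree (assumption):
i_manual = (KI.get_group_id().x - Int32(1)) * KI.get_local_size().x + KI.get_local_id().x
i_direct = KI.get_global_id().x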

src/mapreduce.jl (26 additions, 22 deletions)

@@ -20,7 +20,7 @@ end
 @inline function reduce_group(op, val::T, neutral, shuffle::Val{true}, ::Val{maxthreads}) where {T, maxthreads}
     # shared mem for partial sums
     assume(threads_per_simdgroup() == 32)
-    shared = MtlThreadGroupArray(T, 32)
+    shared = KI.localmemory(T, 32)

     wid = simdgroup_index_in_threadgroup()
     lane = thread_index_in_simdgroup()
@@ -34,10 +34,10 @@ end
     end

     # wait for all partial reductions
-    threadgroup_barrier(MemoryFlagThreadGroup)
+    KI.barrier()

     # read from shared memory only if that warp existed
-    val = if thread_index_in_threadgroup() <= fld1(threads_per_threadgroup().x, 32)
+    val = if KI.get_local_id().x <= fld1(KI.get_local_size().x, 32)
         @inbounds shared[lane]
     else
         neutral
@@ -52,17 +52,17 @@ end

 # Reduce a value across a group, using local memory for communication
 @inline function reduce_group(op, val::T, neutral, shuffle::Val{false}, ::Val{maxthreads}) where {T, maxthreads}
-    threads = threads_per_threadgroup().x
-    thread = thread_position_in_threadgroup().x
+    threads = KI.get_local_size().x
+    thread = KI.get_local_id().x

     # local mem for a complete reduction
-    shared = MtlThreadGroupArray(T, (maxthreads,))
+    shared = KI.localmemory(T, (maxthreads,))
     @inbounds shared[thread] = val

     # perform a reduction
     d = 1
     while d < threads
-        threadgroup_barrier(MemoryFlagThreadGroup)
+        KI.barrier()
         index = 2 * d * (thread-1) + 1
         @inbounds if index <= threads
             other_val = if index + d <= threads
@@ -94,9 +94,9 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
                                   ::Val{Rother}, ::Val{Rlen}, ::Val{grain}, shuffle, R, As...) where {Rreduce, Rother, Rlen, grain}
     # decompose the 1D hardware indices into separate ones for reduction (across items
     # and possibly groups if it doesn't fit) and other elements (remaining groups)
-    localIdx_reduce = thread_position_in_threadgroup().x
-    localDim_reduce = threads_per_threadgroup().x * grain
-    groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid().x, Rlen)
+    localIdx_reduce = KI.get_local_id().x
+    localDim_reduce = KI.get_local_size().x * grain
+    groupIdx_reduce, groupIdx_other = fldmod1(KI.get_group_id().x, Rlen)

     # group-based indexing into the values outside of the reduction dimension
     # (that means we can safely synchronize items within this group)
@@ -141,7 +141,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
 end

 function serial_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
-    grid_idx = thread_position_in_grid().x
+    grid_idx = KI.get_global_id().x

     @inbounds if grid_idx <= length(Rother)
         Iother = Rother[grid_idx]
@@ -166,11 +166,12 @@ end

 ## COV_EXCL_STOP

-serial_mapreduce_threshold(dev) = dev.maxThreadsPerThreadgroup.width * num_gpu_cores()
+serial_mapreduce_threshold(dev) = KI.max_work_group_size(MetalBackend()) * KI.multiprocessor_count(MetalBackend())

 function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
                                  A::Union{AbstractArray,Broadcast.Broadcasted};
                                  init=nothing) where {F, OP, T}
+    backend = MetalBackend()
     Base.check_reducedims(R, A)
     length(A) == 0 && return R # isempty(::Broadcasted) iterates

@@ -195,10 +196,10 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},

     # If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
     if length(Rother) >= serial_mapreduce_threshold(device(R))
-        kernel = @metal launch=false serial_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
-        threads = min(length(Rother), kernel.pipeline.maxTotalThreadsPerThreadgroup)
+        kernel = KI.KIKernel(backend, serial_mapreduce_kernel, f, op, init, Val(Rreduce), Val(Rother), R, A)
+        threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items=length(Rother))
         groups = cld(length(Rother), threads)
-        kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; threads, groups)
+        kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; numworkgroups=groups, workgroupsize=threads)
         return R
     end

@@ -223,17 +224,17 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # we might not be able to launch all those threads to reduce each slice in one go.
     # that's why each threads also loops across their inputs, processing multiple values
     # so that we can span the entire reduction dimension using a single item group.
-    kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
+    kernel = KI.KIKernel(backend, partial_mapreduce_device, f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                                                           Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)

     # how many threads do we want?
     #
     # threads in a group work together to reduce values across the reduction dimensions;
     # we want as many as possible to improve algorithm efficiency and execution occupancy.
-    wanted_threads = shuffle ? nextwarp(kernel.pipeline, length(Rreduce)) : length(Rreduce)
+    wanted_threads = shuffle ? nextwarp(kernel.kern.pipeline, length(Rreduce)) : length(Rreduce)
     function compute_threads(max_threads)
         if wanted_threads > max_threads
-            shuffle ? prevwarp(kernel.pipeline, max_threads) : max_threads
+            shuffle ? prevwarp(kernel.kern.pipeline, max_threads) : max_threads
         else
             wanted_threads
         end
@@ -243,7 +244,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # kernel above may be greater than the maxTotalThreadsPerThreadgroup of the eventually launched
     # kernel below, causing errors
     # reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
-    reduce_threads = compute_threads(512)
+    reduce_threads = compute_threads(KI.kernel_max_work_group_size(backend, kernel))

     # how many groups should we launch?
     #
@@ -262,7 +263,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         # we can cover the dimensions to reduce using a single group
         kernel(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
-               threads, groups)
+               numworkgroups=groups, workgroupsize=threads)
     else
         # we need multiple steps to cover all values to reduce
         partial = similar(R, (size(R)..., reduce_groups))
@@ -273,9 +274,12 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         end
         # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
         # might not match the original output container (e.g. if that was a view).
-        @metal threads groups partial_mapreduce_device(
+        KI.KIKernel(backend, partial_mapreduce_device,
             f, op, init, Val(threads), Val(Rreduce), Val(Rother),
-            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)
+            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)(
+            f, op, init, Val(threads), Val(Rreduce), Val(Rother),
+            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A;
+            numworkgroups=groups, workgroupsize=threads)

         GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
     end
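
Taken together, the mapreduce.jl changes converge on one launch pattern: construct a KI.KIKernel without launching, ask the backend how large a work group the compiled kernel supports, then call the kernel object with numworkgroups/workgroupsize keywords. A condensed sketch of that flow, using the calls exactly as they appear above (my_kernel, args, and n are illustrative placeholders):

backend = MetalBackend()

# compile the kernel up front, without launching it
kernel = KI.KIKernel(backend, my_kernel, args...)

# pick a group size the compiled kernel supports, capped by the problem size
threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items=n)
groups = cld(n, threads)

# launch: group count and group size are passed as keywords
kernel(args...; numworkgroups=groups, workgroupsize=threads)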
