Skip to content

Commit 47aefc3

Browse files
Add shuffle and fill intrinsics (#555)
* Add and test shuffle and fill intrinsics
1 parent 4906e89 commit 47aefc3

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

src/device/intrinsics/simd.jl

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate,
2-
simd_shuffle_down, simd_shuffle_up
2+
simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up
33

44
using Core: LLVMPtr
55

@@ -104,6 +104,14 @@ for (jltype, suffix) in simd_shuffle_map
104104
@device_function simd_shuffle_up(data::$jltype, delta::Integer) =
105105
ccall($"extern air.simd_shuffle_up.$suffix",
106106
llvmcall, $jltype, ($jltype, Int16), data, delta)
107+
108+
@device_function simd_shuffle_and_fill_down(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
109+
ccall($"extern air.simd_shuffle_and_fill_down.$suffix",
110+
llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
111+
112+
@device_function simd_shuffle_and_fill_up(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
113+
ccall($"extern air.simd_shuffle_and_fill_up.$suffix",
114+
llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
107115
end
108116
end
109117

@@ -134,3 +142,39 @@ modify the lower `delta` lanes of `data` because it doesn't wrap values around t
134142
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135143
"""
136144
simd_shuffle_up
145+
146+
@doc """
147+
simd_shuffle_and_fill_down(data::T, filling_data::T, delta::Integer, [modulo::Integer])
148+
149+
Returns `data` or `filling_data` for each vector from the thread whose SIMD lane ID is the
150+
difference from the caller's SIMD lane ID minus `delta`.
151+
152+
If the difference is negative, the operation copies values from the upper `delta` lanes of
153+
`filling_data` to the lower `delta` lanes of `data`.
154+
155+
The value of `delta` needs to be the same for all threads in a SIMD-group.
156+
157+
The `modulo` parameter defines the vector width that splits the SIMD-group into separate vectors
158+
and must be 2, 4, 8, 16, or 32.
159+
160+
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
161+
"""
162+
simd_shuffle_and_fill_down
163+
164+
@doc """
165+
simd_shuffle_and_fill_up(data::T, filling_data::T, delta::Integer, [modulo::Integer])
166+
167+
Returns `data` or `filling_data` for each vector from the thread whose SIMD lane ID is the
168+
sum of the caller's SIMD lane ID and `delta`.
169+
170+
If the sum is greater than `modulo`, the function copies values from the lower `delta` lanes of
171+
`filling_data` into the upper `delta` lanes of `data`.
172+
173+
The value of `delta` needs to be the same for all threads in a SIMD-group.
174+
175+
The `modulo` parameter defines the vector width that splits the SIMD-group into separate vectors
176+
and must be 2, 4, 8, 16, or 32.
177+
178+
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
179+
"""
180+
simd_shuffle_and_fill_up

test/device/intrinsics/simd.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,44 @@
3434
Metal.@sync @metal threads=32 kernel(dev_a, dev_b)
3535
@test sum(a) b[res_idx]
3636
end
37+
@testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f,nshift) in [(simd_shuffle_and_fill_down, -4), (simd_shuffle_and_fill_up, 2)]
38+
function kernel_mod(data::MtlDeviceVector{T}, filling_data::MtlDeviceVector{T}, modulo) where T
39+
idx = thread_position_in_grid_1d()
40+
idx_in_simd = thread_index_in_simdgroup() #simd_lane_id
41+
simd_idx = simdgroup_index_in_threadgroup() #simd_group_id
42+
43+
temp_data = MtlThreadGroupArray(T, 16)
44+
temp_data[idx] = data[idx]
45+
temp_filling_data = MtlThreadGroupArray(T, 16)
46+
temp_filling_data[idx] = filling_data[idx]
47+
simdgroup_barrier(Metal.MemoryFlagThreadGroup)
48+
49+
if simd_idx == 1
50+
dat_value = temp_data[idx_in_simd]
51+
dat_fil_value = temp_filling_data[idx_in_simd]
52+
53+
value = f(dat_value, dat_fil_value, abs(nshift), modulo)
3754

55+
data[idx] = value
56+
end
57+
return
58+
end
59+
60+
N = 16
61+
midN = N ÷ 2
62+
63+
data = Array{typ}(1:N)
64+
mtldata = MtlArray(data)
65+
mtlfilling = MtlArray(data)
66+
67+
Metal.@sync @metal threads=N kernel_mod(mtldata, mtlfilling, N)
68+
@test Array(mtldata) == circshift(data, nshift)
69+
70+
mtlfilling2 = MtlArray(data)
71+
72+
Metal.@sync @metal threads=N kernel_mod(mtlfilling2, mtlfilling, midN)
73+
@test Array(mtlfilling2) == [circshift(data[1:midN], nshift); circshift(data[midN+1:end], nshift)]
74+
end
3875
@testset "matrix functions" begin
3976
@testset "load_store($typ)" for typ in [Float16, Float32]
4077
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},

0 commit comments

Comments
 (0)