Skip to content

Commit 79e5e6a

Browse files
Test both simd shuffle intrinsics. (#553)
* Docstring fixup * Test both up and down shuffles
1 parent 2bf7e24 commit 79e5e6a

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

src/device/intrinsics/simd.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ end
112112
@doc """
113113
simd_shuffle_down(data::T, delta::Integer)
114114
115-
Return data from the thread whose SIMD lane ID is the sum of callers SIMD lane ID and delta.
115+
Return `data` from the thread whose SIMD lane ID is the sum of caller's SIMD lane ID and `delta`.
116116
117-
The value for delta must be the same for all threads in the SIMD-group. This function
118-
doesnt modify the upper delta lanes of data because it doesnt wrap values around
117+
The value for `delta` must be the same for all threads in the SIMD-group. This function
118+
doesn't modify the upper `delta` lanes of `data` because it doesn't wrap values around
119119
the SIMD-group.
120120
121121
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
@@ -125,11 +125,11 @@ simd_shuffle_down
125125
@doc """
126126
simd_shuffle_up(data::T, delta::Integer)
127127
128-
Return data from the thread whose SIMD lane ID is the difference from the callers SIMD
129-
lane ID minus delta.
128+
Return `data` from the thread whose SIMD lane ID is the difference from the caller's SIMD
129+
lane ID minus `delta`.
130130
131-
The value of delta must be the same for all threads in a SIMD-group. This function doesnt
132-
modify the lower delta lanes of data because it doesnt wrap values around the SIMD-group.
131+
The value of `delta` must be the same for all threads in a SIMD-group. This function doesn't
132+
modify the lower `delta` lanes of `data` because it doesn't wrap values around the SIMD-group.
133133
134134
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135135
"""

test/device/intrinsics.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ end
625625

626626
@testset "simd intrinsics" begin
627627

628-
@testset "shuffle($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
628+
@testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f,res_idx) in [(simd_shuffle_down, 1), (simd_shuffle_up, 32)]
629629
function kernel(a::MtlDeviceVector{T}, b::MtlDeviceVector{T}) where T
630630
idx = thread_position_in_grid_1d()
631631
idx_in_simd = thread_index_in_simdgroup()
@@ -638,11 +638,11 @@ end
638638
if simd_idx == 1
639639
value = temp[idx_in_simd]
640640

641-
value = value + simd_shuffle_down(value, 16)
642-
value = value + simd_shuffle_down(value, 8)
643-
value = value + simd_shuffle_down(value, 4)
644-
value = value + simd_shuffle_down(value, 2)
645-
value = value + simd_shuffle_down(value, 1)
641+
value = value + f(value, 16)
642+
value = value + f(value, 8)
643+
value = value + f(value, 4)
644+
value = value + f(value, 2)
645+
value = value + f(value, 1)
646646

647647
b[idx] = value
648648
end
@@ -656,7 +656,7 @@ end
656656

657657
rand!(a, (1:4))
658658
Metal.@sync @metal threads=32 kernel(dev_a, dev_b)
659-
@test sum(a) b[1]
659+
@test sum(a) b[res_idx]
660660
end
661661

662662
@testset "matrix functions" begin

0 commit comments

Comments
 (0)