625625
626626@testset  " simd intrinsics" begin 
627627
628- @testset  " shuffle ($typ )" for  typ in  [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
628+ @testset  " $f ($typ )" for  typ in  [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f,res_idx)  in  [(simd_shuffle_down,  1 ), (simd_shuffle_up,  32 ) ]
629629    function  kernel (a:: MtlDeviceVector{T} , b:: MtlDeviceVector{T} ) where  T
630630        idx =  thread_position_in_grid_1d ()
631631        idx_in_simd =  thread_index_in_simdgroup ()
@@ -638,11 +638,11 @@ end
638638        if  simd_idx ==  1 
639639            value =  temp[idx_in_simd]
640640
641-             value =  value +  simd_shuffle_down (value, 16 )
642-             value =  value +  simd_shuffle_down (value,  8 )
643-             value =  value +  simd_shuffle_down (value,  4 )
644-             value =  value +  simd_shuffle_down (value,  2 )
645-             value =  value +  simd_shuffle_down (value,  1 )
641+             value =  value +  f (value, 16 )
642+             value =  value +  f (value,  8 )
643+             value =  value +  f (value,  4 )
644+             value =  value +  f (value,  2 )
645+             value =  value +  f (value,  1 )
646646
647647            b[idx] =  value
648648        end 
656656
657657    rand! (a, (1 : 4 ))
658658    Metal. @sync  @metal  threads= 32  kernel (dev_a, dev_b)
659-     @test  sum (a) ≈  b[1 ]
659+     @test  sum (a) ≈  b[res_idx ]
660660end 
661661
662662@testset  " matrix functions" begin 
0 commit comments