@@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
77    return  (VecElement {Int64} (origin[1 ]- 1 ), VecElement {Int64} (origin[2 ]- 1 ))
88end 
99
10- for  (jltype, suffix) in  ((:Float16 , " f16" :Float32 , " f32" 
10+ for  (jltype, suffix) in  ((:Float16 , " f16" :Float32 , " f32" , ( :BFloat16 ,  " bf16 " ) )
1111    for  as in  (AS. Device, AS. ThreadGroup)
1212        @eval  begin 
1313            @device_function  simdgroup_load (
5555    simdgroup_load(data::MtlDeviceArray{T}, matrix_origin=(1, 1)) 
5656
5757Loads data from device or threadgroup memory into an 8x8 SIMD-group matrix 
58- and returns it. `T` must be either `Float16`  or `Float32 `. 
58+ and returns it. `T` must be either `Float16`, `Float32`, or `BFloat16`.
5959
6060# Arguments 
6161- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the source memory to load from. 
@@ -65,7 +65,7 @@ and returns it. `T` must be either `Float16` or `Float32`.
6565    simdgroup_store(src, dest::MtlDeviceArray{T}, matrix_origin=(1, 1)) 
6666
6767Stores data from an 8x8 SIMD-group matrix into device or threadgroup memory. 
68- `T` must be either `Float16` or  `Float32`. 
68+ `T` must be either `Float16`, `Float32`, or `BFloat16`.
6969
7070# Arguments 
7171- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the destination memory to store to. 
@@ -88,6 +88,7 @@ Returns `a * b + c`.
8888
8989simd_shuffle_map =  ((Float32, " f32" 
9090                    (Float16, " f16" 
91+                     (BFloat16, " bf16" 
9192                    (Int32,   " s.i32" 
9293                    (UInt32,  " u.i32" 
9394                    (Int16,   " s.i16" 
@@ -118,7 +119,7 @@ The value for delta must be the same for all threads in the SIMD-group. This fun
118119doesn’t modify the upper delta lanes of data because it doesn’t wrap values around 
119120the SIMD-group. 
120121
121- T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8 
122+ T must be one of the following: Float32, Float16, BFloat16,  Int32, UInt32, Int16, UInt16, Int8, or UInt8 
122123""" 
123124simd_shuffle_down
124125
@@ -131,6 +132,6 @@ lane ID minus delta.
131132The value of delta must be the same for all threads in a SIMD-group. This function doesn’t 
132133modify the lower delta lanes of data because it doesn’t wrap values around the SIMD-group. 
133134
134- T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8 
135+ T must be one of the following: Float32, Float16, BFloat16,  Int32, UInt32, Int16, UInt16, Int8, or UInt8 
135136""" 
136137simd_shuffle_up
0 commit comments