 # limitations under the License.
 # ===----------------------------------------------------------------------=== #
 from math import ceildiv
-from sys.info import simd_width_of
+from sys import has_amd_gpu_accelerator
+from sys.info import _has_gpu_tensor_cores, simd_width_of
 
 import linalg.matmul.vendor.blas as vendor_blas
 from benchmark import Bench, Bencher, BenchId, BenchMetric, ThroughputMeasure
@@ -160,19 +161,46 @@ fn gemm_kernel_1[ |
     var dst = c.tile[BM, BN](bidy, bidx)
 
     # Initialize a register to accumulate the result for this thread.
-    var dst_reg: c.element_type = 0
-
-    # Iterate over the K dimension to compute the dot product.
-    for k in range(b.dim[0]()):
-        # Get the corresponding tiles from matrices A and B.
-        var a_tile = a.tile[BM, 1](bidy, k)
-        var b_tile = b.tile[1, BN](k, bidx)
-
-        # Multiply the elements and accumulate the result.
-        dst_reg += a_tile[row, 0] * b_tile[0, col]
-
-    # Write the final accumulated result to the output matrix.
-    dst[row, col] += dst_reg
+    @parameter
+    if (
+        dtype == DType.bfloat16
+        and has_amd_gpu_accelerator()
+        and not _has_gpu_tensor_cores()
+    ):
+        var dst_reg: Float32 = 0
+
+        # Iterate over the K dimension to compute the dot product.
+        for k in range(b.dim[0]()):
+            # Get the corresponding tiles from matrices A and B.
+            var a_tile = a.tile[BM, 1](bidy, k)
+            var b_tile = b.tile[1, BN](k, bidx)
+
+            # Emulate the BF16 multiply-accumulate: promote both operands
+            # to FP32 and keep the running sum in FP32. Rebind the layout
+            # tensor elements to scalars so plain arithmetic applies.
+            var a_val = rebind[Scalar[DType.float32]](
+                a_tile[row, 0].cast[DType.float32]()
+            )
+            var b_val = rebind[Scalar[DType.float32]](
+                b_tile[0, col].cast[DType.float32]()
+            )
+            dst_reg += a_val * b_val
+
+        # Narrow the FP32 accumulator back to BF16 and write it to the
+        # output matrix.
+        dst[row, col] += dst_reg.cast[dtype]()
+    else:
+        var dst_reg: c.element_type = 0
+
+        # Iterate over the K dimension to compute the dot product.
+        for k in range(b.dim[0]()):
+            # Get the corresponding tiles from matrices A and B.
+            var a_tile = a.tile[BM, 1](bidy, k)
+            var b_tile = b.tile[1, BN](k, bidx)
+
+            # Multiply the elements and accumulate the result.
+            dst_reg += a_tile[row, 0] * b_tile[0, col]
+
+        # Write the final accumulated result to the output matrix.
+        dst[row, col] += dst_reg
 
 
 fn run_gemm_kernel_1[
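The heart of this change is Mojo's `@parameter if`, which resolves the branch at compile time, so only the path matching the target GPU is ever code-generated. The sketch below isolates that dispatch pattern; `emulated_mul` is a hypothetical helper, not part of this commit, and unlike the kernel above it narrows each product straight back to BF16 instead of carrying an FP32 accumulator across the K loop.

from sys import has_amd_gpu_accelerator
from sys.info import _has_gpu_tensor_cores


fn emulated_mul[dtype: DType](a: Scalar[dtype], b: Scalar[dtype]) -> Scalar[dtype]:
    # On AMD GPUs without tensor cores, emulate the BF16 multiply:
    # promote both operands to FP32, multiply, and narrow the product
    # back to BF16. On every other target, multiply natively. The
    # branch disappears at compile time, so there is no runtime cost.
    @parameter
    if (
        dtype == DType.bfloat16
        and has_amd_gpu_accelerator()
        and not _has_gpu_tensor_cores()
    ):
        return (a.cast[DType.float32]() * b.cast[DType.float32]()).cast[dtype]()
    else:
        return a * b

Keeping the accumulator in FP32 across the whole K loop, as the kernel does, is the stronger choice: it avoids rounding the running sum to BF16 on every iteration, not just the individual products.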