
Commit ca12f0d

[Kernels][GPU] Add BF16 FMA emulation using FP32 for RDNA3
Add FP32 emulation for BF16 FMA operations in gemm_kernel_1 for AMD GPUs without tensor core support. Today that means RDNA3: although RDNA3 has BF16 support, it can only use it through WMMA, and LLVM does not yet lower the v_wmma_* intrinsics, so until then RDNA3 effectively lacks BF16 support. This emulation could later be leveraged by other GPUs in a similar predicament.

The emulation strategy:
1. Promote BF16 operands to FP32 using cast operations
2. Perform FP32 multiply-accumulate operations
3. Convert the final FP32 result back to BF16
1 parent f622c74 commit ca12f0d
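
As a rough illustration of the three steps above, outside the kernel itself, a standalone Mojo sketch of the cast-multiply-cast pattern could look like the following. The function and variable names (emulated_bf16_fma, a, b, acc) are hypothetical and not part of the patch:

# Minimal sketch of the BF16 FMA emulation pattern (hypothetical example).
fn emulated_bf16_fma(a: BFloat16, b: BFloat16, acc: Float32) -> Float32:
    # 1. Promote the BF16 operands to FP32.
    var a_f32 = a.cast[DType.float32]()
    var b_f32 = b.cast[DType.float32]()
    # 2. Multiply-accumulate in FP32.
    return acc + a_f32 * b_f32

fn main():
    var acc: Float32 = 0
    acc = emulated_bf16_fma(1.5, 2.0, acc)
    # 3. Convert the final FP32 result back to BF16 only once, at the end.
    var result: BFloat16 = acc.cast[DType.bfloat16]()
    print(result)  # 3.0

Keeping the accumulator in FP32 across the whole K loop and converting back a single time at the end is also what the patched kernel does below.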

File tree

1 file changed: +42 -14 lines

max/kernels/test/gpu/layout/matmul_kernels.mojo

Lines changed: 42 additions & 14 deletions
@@ -11,7 +11,8 @@
 # limitations under the License.
 # ===----------------------------------------------------------------------=== #
 from math import ceildiv
-from sys.info import simd_width_of
+from sys import has_amd_gpu_accelerator
+from sys.info import _has_gpu_tensor_cores, simd_width_of

 import linalg.matmul.vendor.blas as vendor_blas
 from benchmark import Bench, Bencher, BenchId, BenchMetric, ThroughputMeasure
@@ -160,19 +161,46 @@ fn gemm_kernel_1[
     var dst = c.tile[BM, BN](bidy, bidx)

     # Initialize a register to accumulate the result for this thread.
-    var dst_reg: c.element_type = 0
-
-    # Iterate over the K dimension to compute the dot product.
-    for k in range(b.dim[0]()):
-        # Get the corresponding tiles from matrices A and B.
-        var a_tile = a.tile[BM, 1](bidy, k)
-        var b_tile = b.tile[1, BN](k, bidx)
-
-        # Multiply the elements and accumulate the result.
-        dst_reg += a_tile[row, 0] * b_tile[0, col]
-
-    # Write the final accumulated result to the output matrix.
-    dst[row, col] += dst_reg
+    @parameter
+    if (
+        dtype == DType.bfloat16
+        and has_amd_gpu_accelerator()
+        and not _has_gpu_tensor_cores()
+    ):
+        var dst_reg: Float32 = 0
+
+        # Iterate over the K dimension to compute the dot product.
+        for k in range(b.dim[0]()):
+            # Get the corresponding tiles from matrices A and B.
+            var a_tile = a.tile[BM, 1](bidy, k)
+            var b_tile = b.tile[1, BN](k, bidx)
+
+            # Emulate BF16 FMA: promote to FP32, compute, convert back
+            # Rebind layout tensor elements to scalars for arithmetic
+            var a_val = rebind[Scalar[DType.float32]](
+                a_tile[row, 0].cast[DType.float32]()
+            )
+            var b_val = rebind[Scalar[DType.float32]](
+                b_tile[0, col].cast[DType.float32]()
+            )
+            dst_reg += a_val * b_val
+
+        # Convert FP32 result back to BF16 and write to output
+        dst[row, col] += dst_reg.cast[dtype]()
+    else:
+        var dst_reg: c.element_type = 0
+
+        # Iterate over the K dimension to compute the dot product.
+        for k in range(b.dim[0]()):
+            # Get the corresponding tiles from matrices A and B.
+            var a_tile = a.tile[BM, 1](bidy, k)
+            var b_tile = b.tile[1, BN](k, bidx)
+
+            # Multiply the elements and accumulate the result.
+            dst_reg += a_tile[row, 0] * b_tile[0, col]
+
+        # Write the final accumulated result to the output matrix.
+        dst[row, col] += dst_reg


 fn run_gemm_kernel_1[
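
A note on the dispatch: the patch selects the emulated path with a compile-time @parameter if, so the extra casts are only compiled in when targeting BF16 on an AMD GPU without tensor cores, and the native path is left untouched otherwise. A minimal sketch of that pattern follows; it is hypothetical, keeps only the has_amd_gpu_accelerator() check for brevity (the real kernel also requires not _has_gpu_tensor_cores()), and the helper name fma_with_fallback is illustrative:

from sys import has_amd_gpu_accelerator

# Hypothetical helper showing compile-time selection of the emulated path.
fn fma_with_fallback[dtype: DType](
    a: Scalar[dtype], b: Scalar[dtype], acc: Scalar[dtype]
) -> Scalar[dtype]:
    @parameter
    if dtype == DType.bfloat16 and has_amd_gpu_accelerator():
        # Emulated path: widen to FP32, multiply-accumulate, narrow back.
        var acc_f32 = acc.cast[DType.float32]()
        acc_f32 += a.cast[DType.float32]() * b.cast[DType.float32]()
        return acc_f32.cast[dtype]()
    else:
        # Native path: use the element type's own arithmetic.
        return acc + a * b

fn main():
    print(fma_with_fallback[DType.float32](2.0, 3.0, 1.0))  # 7.0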
