Commit 7403bec

[Kernels][GPU] Add BF16 FMA matmul test
Add BF16 FMA (Fused Multiply-Add) test to test_matmul.mojo that uses scalar/vector FMA operations instead of tensor cores (enable_tc=False). Update _has_gpu_bf16_fma() helper to include AMD RDNA GPUs. While RDNA3 hardware supports BF16 via v_wmma_* instructions, LLVM cannot lower these intrinsics yet for FMA operations.
1 parent 7d85de0 commit 7403bec
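The RDNA fallback described in the commit message (promote BF16 inputs to FP32, do the fused multiply-add in FP32, round the result back to BF16) can be modeled numerically. The sketch below is illustrative Python, not the Mojo/LLVM code path; the helper names (`f32_to_bf16_bits`, `bf16_fma_emulated`) are hypothetical, and NaN/Inf special cases are ignored.

```python
import struct

def f32_to_bf16_bits(x: float) -> int:
    # Round an FP32 value to BF16 (round-to-nearest-even) and return the 16 raw bits.
    # BF16 is the top 16 bits of FP32: same exponent range, 7 mantissa bits.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties-to-even
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_bits_to_f32(b: int) -> float:
    # Widen 16 BF16 bits to FP32 by zero-padding the low mantissa bits (exact).
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

def bf16_fma_emulated(a: float, b: float, c: float) -> float:
    # Emulated BF16 fma: round inputs to BF16, compute a*b + c in FP32,
    # then round the FP32 result back to BF16.
    af = bf16_bits_to_f32(f32_to_bf16_bits(a))
    bf = bf16_bits_to_f32(f32_to_bf16_bits(b))
    cf = bf16_bits_to_f32(f32_to_bf16_bits(c))
    return bf16_bits_to_f32(f32_to_bf16_bits(af * bf + cf))

print(bf16_fma_emulated(1.5, 2.0, 0.5))    # 3.5 (all operands and result exact in BF16)
print(bf16_fma_emulated(1.0, 1.0, 256.0))  # 256.0 (257 is not representable; ties to 256)
```

The second call shows the output rounding step in action: the FP32 intermediate is 257.0, but BF16's 7 mantissa bits make the spacing 2 at that magnitude, so the result rounds back down to 256.0.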

File tree

2 files changed: +38 −16 lines changed


max/kernels/test/gpu/layout/test_matmul.mojo
Lines changed: 22 additions & 1 deletion

@@ -12,7 +12,11 @@
 # ===----------------------------------------------------------------------=== #
 
 from sys import has_nvidia_gpu_accelerator
-from sys.info import _has_gpu_fp32_tensor_cores, _has_gpu_tensor_cores
+from sys.info import (
+    _has_gpu_bf16_fma,
+    _has_gpu_fp32_tensor_cores,
+    _has_gpu_tensor_cores,
+)
 
 from benchmark import Bench
 from buffer.dimlist import DimList
@@ -220,6 +224,23 @@ def main():
     test.run_test[k5](m)
     test.run_test[k6](m)
 
+    @parameter
+    if _has_gpu_bf16_fma():
+        var test_bf16_fma = test_matmul[
+            DType.bfloat16, a_layout, b_layout, c_layout, False
+        ](m, ctx)
+
+        alias k1_bf16 = run_gemm_kernel_1[
+            DType.bfloat16, a_layout, b_layout, c_layout, 32, 32
+        ]
+
+        test_bf16_fma.run_test[k1_bf16](m)
+    else:
+        print(
+            "Skipping BF16 FMA test (requires FP32 accumulation on this"
+            " GPU)"
+        )
+
     @parameter
     if _has_gpu_fp32_tensor_cores():
         test_tc.run_test[k_tc](m)

mojo/stdlib/stdlib/sys/info.mojo
Lines changed: 16 additions & 15 deletions

@@ -775,33 +775,34 @@ fn _has_gpu_fp32_tensor_cores() -> Bool:
 
 @always_inline("nodebug")
 fn _has_gpu_bf16_fma() -> Bool:
-    """Returns True if the GPU supports BF16 outputs with FMA operations.
+    """Returns True if the GPU supports BF16 FMA operations.
 
-    This checks whether the GPU can perform BF16 × BF16 → BF16 operations
-    using scalar/vector FMA instructions (not tensor cores).
+    This checks whether the GPU can perform BF16 × BF16 operations using
+    scalar/vector FMA instructions (not tensor cores). On some platforms,
+    this may use FP32 emulation internally.
 
     Returns True for:
-    - NVIDIA GPUs (all architectures support BF16 FMA)
-    - AMD CDNA GPUs with MFMA (MI300X, MI355X)
+    - NVIDIA GPUs (all architectures support native BF16 FMA)
+    - AMD CDNA GPUs with MFMA (MI300X, MI355X - native BF16 support)
+    - AMD RDNA GPUs (RDNA3+ - emulated via FP32 accumulation)
     - Apple GPUs (M-series support BF16 operations)
 
-    Returns False for:
-    - AMD RDNA GPUs - these require FP32 accumulation for BF16 FMA.
-      BF16 outputs are only supported via WMMA (tensor cores), which
-      LLVM cannot lower yet. For FMA operations, RDNA requires
-      BF16 inputs with FP32 outputs.
+    Implementation notes:
+    - RDNA3 hardware supports BF16 via v_wmma_* instructions, but LLVM
+      cannot lower these intrinsics yet. For FMA operations, the compiler
+      automatically promotes BF16 to FP32, performs FP32 computation, then
+      converts back to BF16. This emulation provides correct results with
+      some performance overhead.
+    - CDNA uses native v_mfma_* instructions for BF16.
 
     Note:
        This is specifically for FMA (non-tensor-core) operations.
        For tensor core BF16 support, use _has_gpu_tensor_cores().
 
     Returns:
-        True if the GPU supports BF16 output with FMA operations.
+        True if the GPU supports BF16 FMA operations (native or emulated).
     """
-    # NVIDIA: All GPUs support BF16 FMA
-    # AMD: Only CDNA (MFMA) supports BF16 outputs; RDNA requires FP32 accumulation
-    # Apple: M-series GPUs support BF16 operations
-    return is_nvidia_gpu() or _has_amd_tensor_cores() or is_apple_gpu()
+    return is_nvidia_gpu() or has_amd_gpu_accelerator() or is_apple_gpu()
 
 
 @always_inline("nodebug")
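The docstring's emphasis on FP32 accumulation can be seen numerically: if every partial sum of a reduction is rounded back to BF16 (7 mantissa bits), the running sum stalls once it outgrows the addend's precision, while an FP32 accumulator stays exact. The Python model below is a sketch under that assumption; the helper name `to_bf16` is hypothetical.

```python
import struct

def to_bf16(x: float) -> float:
    # Round an FP32 value to the nearest BF16 (round-to-nearest-even) and widen back.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = bits + 0x7FFF + ((bits >> 16) & 1)  # ties-to-even rounding bias
    return struct.unpack("<f", struct.pack("<I", (bits >> 16) << 16))[0]

# Sum 1.0 five hundred times, rounding each partial sum to BF16.
acc_bf16 = 0.0
for _ in range(500):
    acc_bf16 = to_bf16(acc_bf16 + 1.0)

# The same reduction with an FP32 accumulator is exact at this magnitude.
acc_f32 = float(sum([1.0] * 500))

print(acc_bf16)  # 256.0: once the sum reaches 256, adding 1.0 rounds back down
print(acc_f32)   # 500.0
```

Above 256 the spacing between adjacent BF16 values is 2, so 256 + 1 = 257 ties back to 256 and the accumulator never advances; this is the precision cliff that makes a wider accumulator the safer default for BF16 reductions.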
