Skip to content

Commit 7d85de0

Browse files
committed
[Kernels][GPU] Add BF16 tensor core test
Add a basic BF16 tensor core matmul test that runs only when the GPU supports tensor cores; otherwise the test is skipped with a message.
1 parent a5958e9 commit 7d85de0

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

max/kernels/test/gpu/layout/test_matmul.mojo

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# ===----------------------------------------------------------------------=== #
1313

1414
from sys import has_nvidia_gpu_accelerator
15-
from sys.info import _has_gpu_fp32_tensor_cores
15+
from sys.info import _has_gpu_fp32_tensor_cores, _has_gpu_tensor_cores
1616

1717
from benchmark import Bench
1818
from buffer.dimlist import DimList
@@ -226,4 +226,29 @@ def main():
226226
else:
227227
print("Skipping float32 tensor core test on GPU (not supported)")
228228

229+
var test_tc_bf16 = test_matmul[
230+
DType.bfloat16, a_layout, b_layout, c_layout, True
231+
](m, ctx)
232+
233+
alias k_tc_bf16 = run_gemm_kernel_tc[
234+
DType.bfloat16,
235+
a_layout,
236+
b_layout,
237+
c_layout,
238+
64,
239+
64,
240+
32,
241+
32,
242+
32,
243+
MMA_M,
244+
MMA_N,
245+
MMA_K,
246+
]
247+
248+
@parameter
249+
if _has_gpu_tensor_cores():
250+
test_tc_bf16.run_test[k_tc_bf16](m)
251+
else:
252+
print("Skipping BF16 tensor core test on GPU (not supported)")
253+
229254
m.dump_report()

0 commit comments

Comments
 (0)