Skip to content

Commit 9dc65e1

Browse files
authored
[TUTORIAL] Fix tflops metric for tma fp8 kernels (triton-lang#6372)
Currently the metadata function assumes all tma kernels are running fp16 which causes the tflops8 metric to not be populated.
1 parent 9beabc2 commit 9dc65e1

File tree

1 file changed

+0
-2
lines changed

1 file changed

+0
-2
lines changed

python/tutorials/09-persistent-matmul.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ def _matmul_launch_metadata(grid, kernel, args):
5252
ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
5353
if "c_ptr" in args:
5454
bytes_per_elem = args["c_ptr"].element_size()
55-
elif "c_desc_ptr" in args:
56-
bytes_per_elem = 2
5755
else:
5856
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
5957
ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K

0 commit comments

Comments
 (0)