Commit 6bc8ac2

fp8_gemm_rowwise_grouped: fix and reenable eager baseline (#445)
1 parent 056ddb0 commit 6bc8ac2

File tree

1 file changed (+16, -27 lines)

  • tritonbench/operators/fp8_gemm_rowwise_grouped
tritonbench/operators/fp8_gemm_rowwise_grouped/operator.py

Lines changed: 16 additions & 27 deletions
@@ -354,8 +354,8 @@ def __init__(
         # Call the parent class's __init__ method to initialize the base attributes
         super().__init__(tb_args, extra_args)
 
-        # Enable CUDA graphs for this operator
-        self.use_cuda_graphs = True
+        # With `--latency-measure-mode profiler`, we no longer need CUDA graphs for accurate latency measurement.
+        # self.use_cuda_graphs = True
 
         # Enable fp8_fast_accum by default. The cutlass kernel does not support configuring
         # this parameter as of now. By default it is true, but there will be correctness issues
@@ -419,7 +419,6 @@ def _triton(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
     @register_benchmark(
         enabled=HAS_CUTLASS_OR_CK,
         label="ck" if torch.version.hip else "cutlass",
-        baseline=True,
     )
     def _cutlass_or_ck(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
         """
@@ -576,7 +575,7 @@ def get_input_iter(self) -> Generator:
         # Yield the quantized tensors and their corresponding scales
         yield group_A, group_B, m_sizes, a_scale, b_scale
 
-    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
+    def accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         """
         Check if the output of a function matches the output of a baseline function.
         Args:
@@ -646,50 +645,40 @@ def _plot(density, provider):
         # Run the plot and save it to the specified path
         _plot.run(show_plots=True, print_data=True, save_path=save_path)
 
-    """
-    # # TODO: Fix this, RuntimeError: CUDA error: operation not permitted when stream is capturing
     @register_benchmark(baseline=True)
-    def _torch(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
+    def eager_fp8_gemm_rowwise_grouped(
+        self, group_A, group_B, m_sizes, a_scale, b_scale
+    ) -> Callable:
         def torch_perf_fn(group_A, group_B, m_sizes, a_scale, b_scale):
             group_size = len(m_sizes)
             xq, wq = group_A, group_B
             m, k = xq.size()
-            gn, k = wq.size()
+            gn, _ = wq.size()
             n = gn // group_size
 
             expected_result = torch.zeros(
                 m, n, dtype=torch.bfloat16, device=self.device
             )
-            m_offsets, _ = torch.sort(
-                torch.randint(
-                    low=0,
-                    high=m,
-                    size=[group_size],
-                    device=self.device,
-                    dtype=torch.int32,
-                )
-            )
-            m_offsets[group_size - 1] = m
 
-            # Running baseline with quantization to exclude quantization error from the test as it has nothing to do with the correctness of the kernel implementation.
+            # m_sizes holds the row count for each group; cumulative offsets mark
+            # the starting row for every group in the flattened input tensor.
+            m_starts = cumulative_sum_with_initial_offset(m_sizes)
+
             for g in range(group_size):
-                m_start = 0 if g == 0 else m_offsets[g - 1]
-                m_end = m_offsets[g]
+                m_start = int(m_starts[g].item())
+                m_end = m_start + int(m_sizes[g].item())
                 n_start = g * n
-                n_end = (g + 1) * n
+                n_end = n_start + n
 
+                # Dequantize the FP8 tiles with their row-wise scales before performing
+                # the per-group matmul against the matching weight slice.
                 expected_result[m_start:m_end, :] = (
                     group_A[m_start:m_end, :].to(torch.float32)
                     @ group_B[n_start:n_end, :].to(torch.float32).T
                     * a_scale[m_start:m_end][:, None]
                     * b_scale[n_start:n_end][None, :]
                 ).to(torch.bfloat16)
 
-            # for a, b in zip(group_A, group_B):
-            #     a_fp16 = a.to(torch.float16)
-            #     b_fp16 = b.to(torch.float16)
-            #     out.append(torch.matmul(a_fp16, b_fp16))
             return expected_result
 
         return lambda: torch_perf_fn(group_A, group_B, m_sizes, a_scale, b_scale)
-    """
