@@ -354,8 +354,8 @@ def __init__(
         # Call the parent class's __init__ method to initialize the base attributes
         super().__init__(tb_args, extra_args)
 
-        # Enable CUDA graphs for this operator
-        self.use_cuda_graphs = True
+        # With `--latency-measure-mode profiler`, we no longer need CUDA graphs for accurate latency measurement.
+        # self.use_cuda_graphs = True
 
         # Enable fp8_fast_accum by default. The cutlass kernel does not support configuring
         # this parameter as of now. By default it is true, but there will be correctness issues
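(Aside on the hunk above, not part of the diff.) With `--latency-measure-mode profiler`, latency is presumably taken from profiler traces rather than CUDA graph replay, so graph capture is no longer required; dropping it also clears the way for the eager baseline enabled further down, which host-synchronizes (`Tensor.item()`) and therefore cannot run under graph capture. A minimal sketch of that failure mode, assuming a CUDA device and using only stock `torch.cuda.CUDAGraph` capture APIs:

import torch

def eager_ref(x: torch.Tensor) -> torch.Tensor:
    # Tensor.item() forces a device-to-host sync, which is not allowed during capture.
    n = int(x.sum().item())
    return x * n

if torch.cuda.is_available():
    x = torch.ones(4, device="cuda")
    g = torch.cuda.CUDAGraph()
    try:
        with torch.cuda.graph(g):
            eager_ref(x)
    except RuntimeError as err:
        # Typically "CUDA error: operation not permitted when stream is capturing",
        # the same error quoted in the TODO removed later in this diff.
        print(err)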
@@ -419,7 +419,6 @@ def _triton(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
     @register_benchmark(
         enabled=HAS_CUTLASS_OR_CK,
         label="ck" if torch.version.hip else "cutlass",
-        baseline=True,
     )
     def _cutlass_or_ck(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
         """
@@ -576,7 +575,7 @@ def get_input_iter(self) -> Generator:
             # Yield the quantized tensors and their corresponding scales
             yield group_A, group_B, m_sizes, a_scale, b_scale
 
-    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
+    def accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         """
         Check if the output of a function matches the output of a baseline function.
         Args:
@@ -646,50 +645,40 @@ def _plot(density, provider):
         # Run the plot and save it to the specified path
         _plot.run(show_plots=True, print_data=True, save_path=save_path)
 
-    """
-    # # TODO: Fix this, RuntimeError: CUDA error: operation not permitted when stream is capturing
     @register_benchmark(baseline=True)
-    def _torch(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
+    def eager_fp8_gemm_rowwise_grouped(
+        self, group_A, group_B, m_sizes, a_scale, b_scale
+    ) -> Callable:
         def torch_perf_fn(group_A, group_B, m_sizes, a_scale, b_scale):
             group_size = len(m_sizes)
             xq, wq = group_A, group_B
             m, k = xq.size()
-            gn, k = wq.size()
+            gn, _ = wq.size()
             n = gn // group_size
 
             expected_result = torch.zeros(
                 m, n, dtype=torch.bfloat16, device=self.device
             )
-            m_offsets, _ = torch.sort(
-                torch.randint(
-                    low=0,
-                    high=m,
-                    size=[group_size],
-                    device=self.device,
-                    dtype=torch.int32,
-                )
-            )
-            m_offsets[group_size - 1] = m
 
-            # Running baseline with quantization to exclude quantization error from the test as it has nothing to do with the correctness of the kernel implementation.
+            # m_sizes holds the row count for each group; cumulative offsets mark
+            # the starting row for every group in the flattened input tensor.
+            m_starts = cumulative_sum_with_initial_offset(m_sizes)
+
             for g in range(group_size):
-                m_start = 0 if g == 0 else m_offsets[g - 1]
-                m_end = m_offsets[g]
+                m_start = int(m_starts[g].item())
+                m_end = m_start + int(m_sizes[g].item())
                 n_start = g * n
-                n_end = (g + 1) * n
+                n_end = n_start + n
 
+                # Dequantize the FP8 tiles with their row-wise scales before performing
+                # the per-group matmul against the matching weight slice.
                 expected_result[m_start:m_end, :] = (
                     group_A[m_start:m_end, :].to(torch.float32)
                     @ group_B[n_start:n_end, :].to(torch.float32).T
                     * a_scale[m_start:m_end][:, None]
                     * b_scale[n_start:n_end][None, :]
                 ).to(torch.bfloat16)
 
-            # for a, b in zip(group_A, group_B):
-            #     a_fp16 = a.to(torch.float16)
-            #     b_fp16 = b.to(torch.float16)
-            #     out.append(torch.matmul(a_fp16, b_fp16))
             return expected_result
 
         return lambda: torch_perf_fn(group_A, group_B, m_sizes, a_scale, b_scale)
-    """