@@ -1549,7 +1549,7 @@ def gemm_fp8_nt_groupwise(
1549
1549
b : torch .Tensor ,
1550
1550
a_scale : torch .Tensor ,
1551
1551
b_scale : torch .Tensor ,
1552
- scale_major_mode : Literal ["MN" , "K" ] = "MN" ,
1552
+ scale_major_mode : Optional [ Literal ["MN" , "K" ]] = None ,
1553
1553
mma_sm : int = 1 ,
1554
1554
scale_granularity_mnk : Tuple [int , int , int ] = (1 , 128 , 128 ),
1555
1555
out : Optional [torch .Tensor ] = None ,
@@ -1571,12 +1571,20 @@ def gemm_fp8_nt_groupwise(
1571
1571
Column-major input tensor shape (n, k), fp8 e4m3 or fp8 e5m2.
1572
1572
1573
1573
a_scale: torch.Tensor
1574
- Column-major scale tensor for a, shape ``(m, k // block_size)`` if scale_major_mode is ``K``
1575
- or shape ``(k // block_size, m)`` if scale_major_mode is ``MN``
1574
+ if the backend is ``cutlass``:
1575
+ Column-major scale tensor for a, shape ``(m, k // block_size)`` if scale_major_mode is ``K``
1576
+ or shape ``(k // block_size, m)`` if scale_major_mode is ``MN``
1577
+ if the backend is ``trtllm``:
1578
+ scale_major_mode should be None; the scale tensor should have shape ``(m, k // block_size)``,
1579
+ contiguous on the first dimension
1576
1580
1577
1581
b_scale: torch.Tensor
1578
- Row-major scale tensor for b, shape ``(n // block_size, k // block_size)`` if scale_major_k is ``K``
1579
- or shape ``(k // block_size, n // block_size)`` if scale_major_mode is ``MN``
1582
+ if the backend is ``cutlass``:
1583
+ Row-major scale tensor for b, shape ``(n // block_size, k // block_size)`` if scale_major_mode is ``K``
1584
+ or shape ``(k // block_size, n // block_size)`` if scale_major_mode is ``MN``
1585
+ if the backend is ``trtllm``:
1586
+ scale_major_mode should be None; the scale tensor should have shape ``(k // block_size, n // block_size)``,
1587
+ contiguous on the first dimension
1580
1588
1581
1589
scale_granularity_mnk: Tuple[int, int, int]
1582
1590
The granularity of the scale tensor, (m_granularity, n_granularity, k_granularity).
@@ -1642,6 +1650,7 @@ def gemm_fp8_nt_groupwise(
1642
1650
)
1643
1651
1644
1652
if backend == "cutlass" :
1653
+ assert scale_major_mode is not None
1645
1654
get_gemm_sm100_module ().gemm_fp8_nt_groupwise .default (
1646
1655
workspace_buffer ,
1647
1656
a ,
@@ -1655,15 +1664,14 @@ def gemm_fp8_nt_groupwise(
1655
1664
)
1656
1665
elif backend == "trtllm" :
1657
1666
assert scale_granularity_mnk == (1 , 128 , 128 )
1658
- assert scale_major_mode == "MN"
1659
1667
assert a .shape [1 ] >= 256
1660
1668
# mma_sm is ignored
1661
1669
get_trtllm_gemm_module ().trtllm_gemm (
1662
1670
workspace_buffer ,
1663
1671
a ,
1664
1672
b ,
1665
- a_scale . t () ,
1666
- b_scale . t (). contiguous (). t () ,
1673
+ a_scale ,
1674
+ b_scale ,
1667
1675
None ,
1668
1676
out ,
1669
1677
False ,
0 commit comments