Commit 7e1c830

[fix] remove (view) transpose to keep consistent with majorness MN requirement. (#1358)
1 parent d517373 · commit 7e1c830

File tree

2 files changed: +19 −8 lines
flashinfer/gemm.py
tests/test_groupwise_scaled_gemm_fp8.py

flashinfer/gemm.py
Lines changed: 16 additions & 8 deletions

@@ -1549,7 +1549,7 @@ def gemm_fp8_nt_groupwise(
     b: torch.Tensor,
     a_scale: torch.Tensor,
     b_scale: torch.Tensor,
-    scale_major_mode: Literal["MN", "K"] = "MN",
+    scale_major_mode: Optional[Literal["MN", "K"]] = None,
     mma_sm: int = 1,
     scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128),
     out: Optional[torch.Tensor] = None,
@@ -1571,12 +1571,20 @@ def gemm_fp8_nt_groupwise(
         Column-major input tensor shape (n, k), fp8 e4m3 or fp8 e5m2.
 
     a_scale: torch.Tensor
-        Column-major scale tensor for a, shape ``(m, k // block_size)`` if scale_major_mode is ``K``
-        or shape ``(k // block_size, m)`` if scale_major_mode is ``MN``
+        If the backend is ``cutlass``:
+        column-major scale tensor for a, with shape ``(m, k // block_size)`` if scale_major_mode is ``K``
+        or shape ``(k // block_size, m)`` if scale_major_mode is ``MN``.
+        If the backend is ``trtllm``:
+        scale_major_mode should be None; the scale tensor should have shape ``(m, k // block_size)``
+        and be contiguous on the first dimension.
 
     b_scale: torch.Tensor
-        Row-major scale tensor for b, shape ``(n // block_size, k // block_size)`` if scale_major_k is ``K``
-        or shape ``(k // block_size, n // block_size)`` if scale_major_mode is ``MN``
+        If the backend is ``cutlass``:
+        row-major scale tensor for b, with shape ``(n // block_size, k // block_size)`` if scale_major_mode is ``K``
+        or shape ``(k // block_size, n // block_size)`` if scale_major_mode is ``MN``.
+        If the backend is ``trtllm``:
+        scale_major_mode should be None; the scale tensor should have shape ``(k // block_size, n // block_size)``
+        and be contiguous on the first dimension.
 
     scale_granularity_mnk: Tuple[int, int, int]
         The granularity of the scale tensor, (m_granularity, n_granularity, k_granularity).
@@ -1642,6 +1650,7 @@ def gemm_fp8_nt_groupwise(
         )
 
     if backend == "cutlass":
+        assert scale_major_mode is not None
         get_gemm_sm100_module().gemm_fp8_nt_groupwise.default(
             workspace_buffer,
             a,
@@ -1655,15 +1664,14 @@ def gemm_fp8_nt_groupwise(
         )
     elif backend == "trtllm":
         assert scale_granularity_mnk == (1, 128, 128)
-        assert scale_major_mode == "MN"
         assert a.shape[1] >= 256
         # mma_sm is ignored
         get_trtllm_gemm_module().trtllm_gemm(
             workspace_buffer,
             a,
             b,
-            a_scale.t(),
-            b_scale.t().contiguous().t(),
+            a_scale,
+            b_scale,
             None,
             out,
             False,
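
For context, the signature change means cutlass callers must now pass scale_major_mode explicitly ("MN" or "K"), since the default moved from "MN" to None and the cutlass path asserts it is not None. Below is a minimal sketch of calling the trtllm path after this change; the `backend` keyword and the fp32 scale dtype are assumptions read off the diff, not a verified public signature.

# Sketch only: the `backend` kwarg and fp32 scale dtype are assumptions from this diff.
import torch
from flashinfer.gemm import gemm_fp8_nt_groupwise

m, n, k, block = 128, 256, 512, 128  # trtllm path asserts a.shape[1] >= 256

a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)      # row-major (m, k)
b = torch.randn(k, n, device="cuda").to(torch.float8_e4m3fn).t()  # column-major (n, k)

# Per the updated docstring, trtllm expects a_scale of shape (m, k // block) and
# b_scale of shape (k // block, n // block), each contiguous on the first
# dimension; a transpose view of a contiguous tensor has exactly that layout.
a_scale = torch.rand(k // block, m, device="cuda", dtype=torch.float32).t()
b_scale = torch.rand(n // block, k // block, device="cuda", dtype=torch.float32).t()

c = gemm_fp8_nt_groupwise(
    a, b, a_scale, b_scale,
    scale_major_mode=None,  # must stay None for trtllm after this commit
    backend="trtllm",       # assumed kwarg, matching the backend checks above
)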

tests/test_groupwise_scaled_gemm_fp8.py
File mode changed: 100644 → 100755
Lines changed: 3 additions & 0 deletions

@@ -111,6 +111,9 @@ def test_fp8_groupwise_gemm(
     b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode)
     ref_c = einsum(a_dequant, b_dequant, "m k, n k -> m n").to(out_dtype)
 
+    if backend == "trtllm":
+        b_scale = b_scale.t().contiguous()
+
     c = gemm_fp8_nt_groupwise(
         a_fp8,
         b_fp8,
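
The three-line fixup above adapts the reference MN-major b_scale (shape (k // block, n // block), row-major) to the buffer layout the trtllm path now consumes without the removed view transpose. A hypothetical sanity check of that layout identity, using made-up sizes:

# Sketch: .t().contiguous() on an MN-major tensor yields the same bytes as a
# (k // block, n // block) tensor that is contiguous on its first dimension.
import torch

kb, nb = 4, 2  # stand-ins for k // block and n // block
mn_major = torch.arange(kb * nb, dtype=torch.float32).reshape(kb, nb)

fixed = mn_major.t().contiguous()  # what the test now passes for trtllm
col_major_view = fixed.t()         # shape (kb, nb), stride (1, kb)

assert col_major_view.stride() == (1, kb)     # contiguous on the first dimension
assert torch.equal(col_major_view, mn_major)  # same values, new physical layout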
