Commit 246b142

[moe training] integrate rowwise expert quant kernel (#2698)

1 parent: 143c3a6

File tree: 3 files changed (+8, −13 lines)

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+from torchao.prototype.moe_training.kernels.float8_rowwise import (
+    triton_fp8_rowwise_3d_transpose_rhs as triton_fp8_rowwise_3d_transpose_rhs,
+)
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
 )

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 4 additions & 12 deletions
@@ -15,6 +15,7 @@
 from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
+    triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -142,20 +143,11 @@ def forward(
         # Precompute non-transposed B column-major for backward, to save memory by storing the
         # low precision B tensor instead of the high precision B tensor.
         # In the backward this is needed for grad_A: grad_output @ B.
-        B = B_t.contiguous().transpose(-2, -1)
-
-        # - B shape: (E, N, K)
-        # - B scales must be computed rowwise keeping the outer/final dim, so:
-        # - B_scale shape: (E, 1, K)
-        B_scales = tensor_to_scale(
-            B,
-            torch.float8_e4m3fn,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-2,
+        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+            B_t,
+            output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        B_scaled = B.to(torch.float32) * B_scales
-        B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn)

         # Store what we need for backward.
         ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
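For context, the removed Python path and the new fused Triton kernel implement the same rowwise expert quantization; fusing it presumably avoids materializing the high-precision (E, N, K) copy of B and the intermediate scaled tensor. A minimal standalone sketch of that recipe is below (plain PyTorch, not the repo's code; the helper name and the exact power-of-2 scale rounding are assumptions):

```python
# Reference sketch (illustrative, not the repo's implementation) of the rowwise
# expert quantization that triton_fp8_rowwise_3d_transpose_rhs fuses into one kernel.
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def rowwise_fp8_transpose_rhs_reference(B_t: torch.Tensor):
    """B_t: (E, K, N) high-precision expert weights (transposed RHS)."""
    # Materialize the non-transposed (E, N, K) view in column-major memory layout.
    B = B_t.contiguous().transpose(-2, -1)

    # Rowwise scales keeping the final dim: amax over dim=-2 -> scales of shape (E, 1, K).
    amax = B.abs().amax(dim=-2, keepdim=True).to(torch.float64)
    scales = (FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)).to(torch.float32)

    # Round scales down to a power of 2 (assumed to mirror round_scales_to_power_of_2=True).
    scales = torch.exp2(torch.floor(torch.log2(scales)))

    # Scale, saturate to the fp8 range, and cast.
    B_scaled = B.to(torch.float32) * scales
    B_fp8 = torch.clamp(B_scaled, -FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return B_fp8, scales
```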

torchao/prototype/moe_training/utils.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def torch_to_3d_rowwise_float8_transpose_rhs(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function converts the 3D input tensor to a float8 tensor, with scales computed along logical columns
-    on a per-expert basis.
+    on a per-expert basis. Output will be in column-major memory layout.

     Args:
         x (torch.Tensor): The input tensor to be converted to a float8 tensor. Shape (E, K, N).
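To make the added docstring sentence concrete, here is a small standalone check (not from the repo) of what "column-major memory layout" means for the (E, N, K) output: the stride along dim -2 is 1, as produced by transposing a contiguous (E, K, N) tensor.

```python
# Standalone illustration of column-major layout over the last two dims.
import torch

E, K, N = 2, 8, 16
x = torch.randn(E, K, N)                # contiguous, row-major
col_major = x.transpose(-2, -1)         # logical shape (E, N, K), column-major storage

assert col_major.shape == (E, N, K)
assert col_major.stride(-2) == 1        # rows are unit-stride -> column-major
```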
