
Commit b757fb9

[MoE training] Assert expert weights are column-major; preserve subclass with transpose (#2663)
* assert B is col-major
* preserve subclass with transpose
Parent: 6bb2baf

File tree: 2 files changed, +2 −4 lines changed


torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 1 addition & 4 deletions

@@ -95,10 +95,7 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"

         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        if not _is_column_major(B_t):
-            # FSDP will complain if B_t (weights) is not contiguous, we can't require B_t to be column-major.
-            # TODO: figure out better solution than transposing for each forward pass.
-            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
+        assert _is_column_major(B_t), "B must be column-major"

         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K) or (B, M, K)
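For context on what this hunk changes: the removed fallback silently re-laid out B_t into column-major storage on every forward pass; the new assert pushes that cost onto the caller instead. A minimal sketch of how a stride-based layout check and the removed transpose-contiguous-transpose round trip behave (the helper name and stride test here are assumptions, not necessarily torchao's exact _is_column_major):

import torch

def is_column_major(x: torch.Tensor) -> bool:
    # Column-major in the last two dims: elements within a column are
    # contiguous, i.e. the stride of the second-to-last dim is 1.
    return x.stride(-2) == 1

w = torch.randn(64, 32)  # row-major: strides (32, 1)
assert not is_column_major(w)

# The pattern removed by this commit: transpose -> contiguous -> transpose
# keeps the logical shape (64, 32) but re-lays out storage column-major,
# giving strides (1, 64).
w_cm = w.transpose(-2, -1).contiguous().transpose(-2, -1)
assert is_column_major(w_cm)
assert torch.equal(w, w_cm)

With the fallback gone, expert weights must already be materialized column-major once up front rather than re-copied inside the hot path of each scaled grouped GEMM.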

torchao/prototype/moe_training/tensor.py

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@
     torch.ops.aten._pin_memory.default,
     torch.ops.aten.split.Tensor,
     torch.ops.aten.clone.default,
+    torch.ops.aten.transpose.int,
 }

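This set is the allowlist of ATen ops for which the MoE training tensor subclass in this file re-wraps outputs; adding aten.transpose.int is what lets a transposed expert weight stay a subclass instance instead of decaying to a plain tensor. A minimal sketch of the dispatch pattern, assuming a wrapper subclass: the names WrapperTensor and _data are illustrative, not torchao's implementation.

import torch
from torch.utils._pytree import tree_map_only

_ops_to_preserve_subclass = {
    torch.ops.aten.detach.default,
    torch.ops.aten.clone.default,
    torch.ops.aten.transpose.int,  # the overload this commit adds
}

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Run the op on the plain inner tensors.
        args = tree_map_only(WrapperTensor, lambda t: t._data, args)
        kwargs = tree_map_only(WrapperTensor, lambda t: t._data, kwargs)
        out = func(*args, **kwargs)
        # Only allowlisted ops keep the subclass on their outputs;
        # everything else falls back to plain tensors.
        if func in _ops_to_preserve_subclass:
            return tree_map_only(torch.Tensor, WrapperTensor, out)
        return out

w = WrapperTensor(torch.randn(16, 8))
w_t = w.transpose(-2, -1)  # dispatches aten.transpose.int
assert isinstance(w_t, WrapperTensor)

Without transpose.int in the set, w.transpose(-2, -1) would return a plain torch.Tensor, and the subclass behavior would be lost after the transpose.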
