test/prototype/moe_training (1 file changed, +8 -0 lines changed)
```diff
@@ -20,5 +20,8 @@
 # this test requires torchtitan
 try:
+    from torchtitan.experiments.llama4.infra.expert_parallel import (
+        set_token_group_alignment_size_m,
+    )
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
     from torchtitan.experiments.llama4.model.moe import MoE
 except ImportError:
@@ -36,6 +39,11 @@
 )
 @pytest.mark.parametrize("compile", [False, True])
 def test_moe_float8_training(target_fqns: list[str], compile: bool):
+    # Set token group alignment size to 16. This is required so that
+    # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
+    # has the contraction dim be divisible by 16. 16 byte alignment is required
+    # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
+    set_token_group_alignment_size_m(16)
     model_args = TransformerModelArgs(
         moe_enabled=True,
         num_experts=8,
```
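The arithmetic in the new comment can be sanity-checked with a short sketch. The helpers below (`alignment_elems`, `round_up`) are illustrative only and are not part of torchtitan or torchao; they just show why a 16-byte alignment requirement on the stride-1 dim translates to token groups aligned to multiples of 16 elements in fp8 (1 byte per element), versus 8 elements in bf16 (2 bytes per element).

```python
# Illustrative only: these helpers are not part of torchtitan or torchao.
ALIGNMENT_BYTES = 16  # byte alignment required for the stride-1 dim of each grouped gemm operand


def alignment_elems(dtype_bytes: int, alignment_bytes: int = ALIGNMENT_BYTES) -> int:
    """Minimum token group multiple so the contraction dim meets the byte alignment."""
    return alignment_bytes // dtype_bytes


def round_up(n: int, multiple: int) -> int:
    """Round a token group size up to the next multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


# fp8 is 1 byte per element: 16 bytes / 1 byte = 16 elements, matching
# set_token_group_alignment_size_m(16) in the test above.
assert alignment_elems(dtype_bytes=1) == 16
# bf16 is 2 bytes per element, so the same requirement only needs multiples of 8.
assert alignment_elems(dtype_bytes=2) == 8

# Uneven per-expert token group sizes rounded up to the fp8 alignment multiple.
group_sizes = [5, 17, 32]
print([round_up(g, alignment_elems(dtype_bytes=1)) for g in group_sizes])  # [16, 32, 32]
```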