import argparse
import copy
import os
-import statistics
-from time import perf_counter_ns

import pytest
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.nn import functional as F

+from benchmarks.prototype.moe_training.utils import (
+    bench_fwd_bwd_microseconds,
+    profile_fn,
+)
+
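For readers who don't want to chase the shared utils module: the inline `bench_fn_microseconds` removed further down in this diff gives a good picture of what `bench_fwd_bwd_microseconds` does. A minimal sketch under that assumption (the `use_compile`/`fullgraph` handling is guessed from the new call sites; the real implementation lives in `benchmarks/prototype/moe_training/utils.py`):

```python
# Hypothetical sketch of the shared helper, modeled on the inline
# bench_fn_microseconds that this diff removes; not the actual utils code.
import statistics
from time import perf_counter_ns

import torch
from torch.nn import functional as F


def bench_fwd_bwd_microseconds(model, x, labels, use_compile=False, fullgraph=False):
    if use_compile:
        # assumption: the helper compiles on demand, mirroring the
        # torch.compile(..., fullgraph=False) calls removed below
        model = torch.compile(model, fullgraph=fullgraph)
    times = []
    for _ in range(10):
        start_ns = perf_counter_ns()
        loss = F.mse_loss(model(x), labels)
        loss.backward()
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
        times.append((perf_counter_ns() - start_ns) / 1_000)  # ns -> us
    return statistics.median(times)  # median over 10 fwd+bwd iterations
```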
# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
    )

-def bench_moe_float8_training_fsdp(enable_profile=False):
+def bench_moe_float8_training_fsdp(
+    recipe_name: str, enable_profile: bool, use_compile: bool
+):
    assert torch.cuda.is_available()
+    assert recipe_name in ["fp8_rowwise", "mxfp8"]
+    recipe = MoEScalingType[recipe_name.upper()]

    # setup distributed for fsdp
    setup_distributed()

@@ -62,15 +69,19 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
    init_std = 0.02
    device = torch.device("cuda")

-    # reference bf16 MoE
-    dim, hidden_dim = 5120, 4 * 5120
+    # reference bf16 MoE using llama4 shapes
+    dim, hidden_dim = 5120, 8192
    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)
+    # Token group alignment size must be 16 for fp8 rowwise training,
+    # and 32 for mxfp8 training
+    alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16
+    set_token_group_alignment_size_m(alignment_size)
+
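For context on the two constants: mxfp8 shares one scale per 32-element block, so each expert's token group must land on a 32-token boundary, while fp8 rowwise only needs multiples of 16. A toy illustration of the padding this implies (`pad_group_size` is not a torchao API, just the arithmetic):

```python
# Toy illustration of token-group alignment; not part of torchao.
def pad_group_size(num_tokens: int, alignment: int) -> int:
    # round up to the next multiple of the alignment
    return ((num_tokens + alignment - 1) // alignment) * alignment


assert pad_group_size(100, 16) == 112  # fp8 rowwise alignment
assert pad_group_size(100, 32) == 128  # mxfp8 alignment
```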
    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

@@ -83,15 +94,15 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        return False

    # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
+    config = MoETrainingConfig(scaling_type=recipe)
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
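Note that only the `return False` fallthrough of `moe_module_filter_fn` is visible in this hunk; the rest of the predicate is collapsed. A purely hypothetical filter of this shape, to show how `quantize_` uses it (the `isinstance(mod, MoE)` check is an assumption, not the actual logic):

```python
# Hypothetical sketch; the real moe_module_filter_fn body is collapsed in
# this diff. quantize_ converts exactly the modules this returns True for.
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    if isinstance(mod, MoE):  # assumed predicate: convert only the MoE blocks
        return True
    return False
```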

    # FSDP2
    fully_shard(model)
    fully_shard(ref_model)

    # inputs (llama4 shapes)
-    batch, seq = 1, 8192
+    batch, seq = 1, 16640
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )

@@ -104,70 +115,34 @@ def warmup(model, input):
        loss.backward()
        torch.cuda.synchronize()

-    def bench_fn_microseconds(model, input):
-        labels = torch.ones_like(input)
-        times = []
-        for _ in range(10):
-            start_ns = perf_counter_ns()
-            out = model(input)
-            loss = F.mse_loss(out, labels)
-            loss.backward()
-            torch.cuda.synchronize()
-            end_ns = perf_counter_ns()
-            duration_us = (end_ns - start_ns) / 1000
-            times.append(duration_us)
-        return statistics.median(times)
-
-    def profile_fn(model, input, profile_name="profile"):
-        # Only profile on rank 0
-        if torch.distributed.get_rank() == 0:
-            labels = torch.ones_like(input)
-            wait, warmup, active = 1, 3, 1
-            total_steps = wait + warmup + active
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                schedule=torch.profiler.schedule(
-                    wait=wait, warmup=warmup, active=active, repeat=0
-                ),
-                record_shapes=True,
-                with_stack=True,
-            ) as prof:
-                for _ in range(total_steps):
-                    out = model(input)
-                    loss = F.mse_loss(out, labels)
-                    loss.backward()
-                    prof.step()
-
-            # Save profiler results
-            prof.export_chrome_trace(f"{profile_name}.json")
-            print(f"Saved: {profile_name}.json")
-
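One detail of the helper removed above that carries over conceptually: `torch.profiler.schedule(wait=1, warmup=3, active=1, repeat=0)` skips the first step, treats the next three as warmup, and records exactly one active step, so the exported Chrome trace captures a single steady-state forward/backward pass on rank 0. Presumably the shared `profile_fn` imported at the top keeps this rank-0-only, single-step behavior.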
-    # Compile models
-    ref_model = torch.compile(ref_model, fullgraph=False)
-    model = torch.compile(model, fullgraph=False)
-
-    print("Benchmarking MoE with FSDP2 using bf16 training")
-    warmup(ref_model, ref_x)
-    bf16_us = bench_fn_microseconds(ref_model, ref_x)
-    print(f"bf16 time: {bf16_us} us")
-    if enable_profile:
-        print("Profiling bf16 model")
-        profile_fn(ref_model, ref_x, profile_name="bf16_profile")
+    labels = torch.ones_like(x)

-    # Token group alignment size must be 16 for fp8 rowwise training
-    set_token_group_alignment_size_m(16)
-
-    print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
-    warmup(model, x)
-    fp8_us = bench_fn_microseconds(model, x)
-    print(f"fp8 time: {fp8_us} us")
+    # TODO: bench with fullgraph=True if/when it is supported
+    bf16_us = bench_fwd_bwd_microseconds(
+        ref_model,
+        ref_x,
+        labels=labels,
+        use_compile=use_compile,
+        fullgraph=False,
+    )
+    print(f"BF16 time: {bf16_us} us")
+    if enable_profile:
+        print("Profiling bf16 training")
+        profile_fn(ref_model, ref_x, labels=labels, profile_name="bf16_profile")
+
+    scaled_us = bench_fwd_bwd_microseconds(
+        model,
+        x,
+        labels=labels,
+        use_compile=use_compile,
+        fullgraph=False,
+    )
+    print(f"Scaled time: {scaled_us} us")
    if enable_profile:
-        print("Profiling fp8 model")
-        profile_fn(model, x, profile_name="fp8_profile")
+        print("Profiling quantized training")
+        profile_fn(model, x, labels=labels, profile_name=f"{recipe_name}_profile")

+    print(f"Speedup: {bf16_us / scaled_us:.3f}x")
    dist.destroy_process_group()


@@ -185,5 +160,15 @@ def setup_distributed():
        action="store_true",
        help="Enable PyTorch profiling and save results to file",
    )
+    parser.add_argument("--recipe", type=str, help="[fp8_rowwise, mxfp8]")
+    parser.add_argument(
+        "--compile",
+        action="store_true",
+        help="use torch.compile",
+    )
    args = parser.parse_args()
-    bench_moe_float8_training_fsdp(enable_profile=args.profile)
+    bench_moe_float8_training_fsdp(
+        recipe_name=args.recipe,
+        enable_profile=args.profile,
+        use_compile=args.compile,
+    )
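Since the script calls `setup_distributed()` and shards both models with FSDP2, it is presumably launched via `torchrun`, e.g. `torchrun --nproc_per_node=<num_gpus> <path/to/this/benchmark>.py --recipe mxfp8 --compile --profile` (the script path here is a placeholder).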