import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
+from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
from modelopt.torch.quantization.utils import is_quantized_linear
from modelopt.torch.utils import torch_to
@@ -150,56 +151,35 @@ def forward_loop(model):
    dist.destroy_process_group()


-def data_parallel_test_helper(model, config, dp_group):
+def dp_cp_parallel_test_helper(model, config, group):
    calib_data = model.get_dummy_input().cuda()

    def forward_loop(model):
        model(calib_data)

    model = mtq.quantize(model, config, forward_loop)

-    # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        fc1_amax = model.fc1.input_quantizer.amax.clone()
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
-        assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
-        fc2_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
-        assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
-
-    # Weight quantizer amax
-    fc1_amax = model.fc1.weight_quantizer.amax.clone()
-    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
-    assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax)
-    fc2_amax = model.fc2.weight_quantizer.amax.clone()
-    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
-    assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax)
-
-
-def context_parallel_test_helper(model, config, cp_group):
-    calib_data = model.get_dummy_input().cuda()
-
-    def forward_loop(model):
-        model(calib_data)
-
-    model = mtq.quantize(model, config, forward_loop)
+    def reduce_amax(quantizer):
+        amax = quantizer.amax.clone()
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
+        assert torch.allclose(amax, quantizer.amax)

    # Input quantizer amax
    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        fc1_amax = model.fc1.input_quantizer.amax.clone()
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
-        assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
-        fc2_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
-        assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
+        reduce_amax(model.fc1.input_quantizer)
+        reduce_amax(model.fc2.input_quantizer)

    # Weight quantizer amax
-    fc1_weight_amax = model.fc1.weight_quantizer.amax.clone()
-    dist.all_reduce(fc1_weight_amax, op=dist.ReduceOp.MAX, group=cp_group)
-    assert torch.allclose(fc1_weight_amax, model.fc1.weight_quantizer.amax)
-    fc2_weight_amax = model.fc2.weight_quantizer.amax.clone()
-    dist.all_reduce(fc2_weight_amax, op=dist.ReduceOp.MAX, group=cp_group)
-    assert torch.allclose(fc2_weight_amax, model.fc2.weight_quantizer.amax)
+    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+        for quantizer in model.fc1.weight_quantizer:
+            reduce_amax(quantizer)
+    else:
+        reduce_amax(model.fc1.weight_quantizer)
+    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+        for quantizer in model.fc2.weight_quantizer:
+            reduce_amax(quantizer)
+    else:
+        reduce_amax(model.fc2.weight_quantizer)


def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
@@ -212,29 +192,29 @@ def forward_loop(model):

    model = mtq.quantize(model, config, forward_loop)

+    def reduce_amax(quantizer):
+        amax = quantizer.amax.clone()
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
+        assert torch.allclose(amax, quantizer.amax)
+
    # Input quantizer amax
    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        fc1_amax = model.fc1.input_quantizer.amax.clone()
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
-        assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
-        fc2_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
-        assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
-
-    fc1_amax = model.fc1.weight_quantizer.amax.clone()
-    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group)
-    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
-    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
-    assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax)
-    fc2_amax = model.fc2.weight_quantizer.amax.clone()
-    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group)
-    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
-    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
-    assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax)
+        reduce_amax(model.fc1.input_quantizer)
+        reduce_amax(model.fc2.input_quantizer)
+
+    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+        for quantizer in model.fc1.weight_quantizer:
+            reduce_amax(quantizer)
+    else:
+        reduce_amax(model.fc1.weight_quantizer)
+
+    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+        for quantizer in model.fc2.weight_quantizer:
+            reduce_amax(quantizer)
+    else:
+        reduce_amax(model.fc2.weight_quantizer)


def auto_quantize_helper(model):
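For reference, a minimal sketch of how the reworked `dp_cp_parallel_test_helper` might be driven from a spawned multi-rank test is below, assuming the helper above is in scope. The toy model, port, and config choice are illustrative assumptions rather than the repo's actual test harness, and the amax assertion only holds when calibration synchronizes amax across the process group the way the repo's test models are wired to do.

```python
# Hypothetical launcher sketch -- not the repo's test harness.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

import modelopt.torch.quantization as mtq


class ToyModel(nn.Module):
    """Stand-in for the test model: two quantizable linears plus a dummy-input helper."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 16)

    def forward(self, x):
        return self.fc2(self.fc1(x))

    def get_dummy_input(self):
        return torch.randn(4, 16)


def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = ToyModel().cuda()
    # Each rank calibrates on its own random batch; the helper then all-reduces the
    # collected amax values and asserts they already match the group-wide maximum.
    dp_cp_parallel_test_helper(model, mtq.INT8_DEFAULT_CFG, dist.group.WORLD)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)  # requires 2 CUDA devices
```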