37 | 37 | )
38 | 38 | from modelopt.torch.utils.distributed import ParallelState
39 | 39 |
40 |    | -from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer
   | 40 | +from ..nn import QuantModuleRegistry, TensorQuantizer
41 | 41 | from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear
42 | 42 | from ..qtensor import QTensorWrapper
43 | 43 | from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
@@ -501,7 +501,6 @@ def _setup(self):
501 | 501 | self.parallel_state = ParallelState(
502 | 502 |     data_parallel_group,
503 | 503 |     mcore_parallel.get_tensor_model_parallel_group(),
504 |     | -    mcore_parallel.get_context_parallel_group(),
505 | 504 |     mcore_parallel.get_expert_model_parallel_group(),
506 | 505 |     expert_tensor_parallel_group,
507 | 506 | )
@@ -544,70 +543,13 @@ def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args):
544 | 543 | ]
545 | 544 |
546 | 545 | def modelopt_post_restore(self, prefix: str = ""):
547 |     | -    """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks.
548 |     | -
549 |     | -    ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their
550 |     | -    shape before saving. However this is not enough for MCore/distributed frameworks since the tensor parallelism
551 |     | -    could change between saving and restoring. If the tensor parallelism changes, the shape of the quantizer
552 |     | -    states also changes. So we need to re-calculate the quantizer states.
553 |     | -    """
554 |     | -    from modelopt.torch.quantization.model_calib import max_calibrate
555 |     | -
556 |     | -    def _check_unsupported_states(quantizer: TensorQuantizer):
557 |     | -        for k in quantizer.state_dict():
558 |     | -            if k not in ["_amax", "_pre_quant_scale"]:
559 |     | -                warnings.warn(
560 |     | -                    f"Restore of {k} for {prefix} is not supported. The restore of this layer might be "
561 |     | -                    f"incorrect. Please implement a custom restore for {k}."
562 |     | -                )
563 |     | -
564 |     | -    def _has_state(quantizer, name):
565 |     | -        # Handling for SequentialQuantizer
566 |     | -        quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer
567 |     | -        return hasattr(quantizer, name)
568 |     | -
569 |     | -    # weights for TEGroupedLinear are stored in weight0, weight1, etc.
570 |     | -    if self.weight0 is None:
571 |     | -        return
572 |     | -    for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]:
573 |     | -        _check_unsupported_states(
574 |     | -            quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0]
575 |     | -        )
576 |     | -    if _has_state(self.weight_quantizer, "_amax"):
577 |     | -        self.weight_quantizer.reset_amax()
578 |     | -        for i in range(self.num_gemms):
579 |     | -            weight = getattr(self, f"weight{i}")
580 |     | -            assert weight is not None, "weight is None"
581 |     | -
582 |     | -            max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False)
583 |     | -    if _has_state(self.input_quantizer, "_pre_quant_scale"):
584 |     | -        if hasattr(self.input_quantizer, "_pre_quant_scale"):
585 |     | -            delattr(self.input_quantizer, "_pre_quant_scale")
586 |     | -        pqs = torch.zeros(
587 |     | -            (weight.shape[1]), device=weight.device, dtype=self.original_weight_dtype
588 |     | -        )
589 |     | -        self.input_quantizer.register_buffer("_pre_quant_scale", pqs)
590 |     | -
591 |     | -    if _has_state(self.input_quantizer, "_amax"):
592 |     | -        self.input_quantizer.reset_amax()
593 |     | -        dummy_input = torch.ones(
594 |     | -            (1, 1, self.weight0.shape[1]),
595 |     | -            device=self.weight0.device,
596 |     | -            dtype=self.original_weight_dtype,
597 |     | -        )
598 |     | -        max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False)
599 |     | -    if _has_state(self.output_quantizer, "_amax"):
600 |     | -        self.output_quantizer.reset_amax()
601 |     | -        dummy_input = torch.ones(
602 |     | -            (1, 1, self.weight0.shape[0]),
603 |     | -            device=self.weight0.device,
604 |     | -            dtype=self.original_weight_dtype,
605 |     | -        )
606 |     | -        max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False)
607 |     | -    # If there are any other states, lets move them to the correct device
608 |     | -
609 |     | -    self.weight = None
    | 546 | +    # GroupedMLP stores the weights as weight0, weight1, etc. post_restore initializes the
    | 547 | +    # quantizer states and uses self.weight to infer their shape, dtype, and device, so
    | 548 | +    # temporarily alias self.weight0 as self.weight before calling it.
    | 549 | +    self.weight = self.weight0
610 | 550 |     super().modelopt_post_restore(prefix=prefix)
    | 551 | +    # Revert self.weight to None afterwards so the temporary alias is not used in the forward pass.
    | 552 | +    self.weight = None
611 | 553 |
612 | 554 | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
613 | 555 |     # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
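For reference, here is a minimal, self-contained sketch of the aliasing pattern the new `modelopt_post_restore` relies on: point `self.weight` at `weight0` so the generic post-restore path can infer shape/dtype/device, then clear the alias. The class names `_BaseQuantLinear` and `_GroupedQuantLinear` and the `amax` bookkeeping below are made-up stand-ins, not the modelopt implementation; only the alias/restore/clear sequence mirrors the diff.

```python
import torch
import torch.nn as nn


class _BaseQuantLinear(nn.Module):
    """Stand-in for a quantized linear whose post-restore sizes state from ``self.weight``."""

    def __init__(self):
        super().__init__()
        self.weight = None  # regular linears hold their weight here
        self.amax = None

    def modelopt_post_restore(self, prefix: str = ""):
        # Rebuild per-output-channel amax from the (possibly re-sharded) weight.
        if self.weight is None:
            return
        self.amax = self.weight.detach().abs().amax(dim=1)


class _GroupedQuantLinear(_BaseQuantLinear):
    """Grouped variant: weights live in weight0, weight1, ...; ``self.weight`` stays None."""

    def __init__(self, num_gemms: int, out_features: int, in_features: int):
        super().__init__()
        self.num_gemms = num_gemms
        for i in range(num_gemms):
            setattr(self, f"weight{i}", nn.Parameter(torch.randn(out_features, in_features)))

    def modelopt_post_restore(self, prefix: str = ""):
        # Temporarily alias the first grouped weight so the base-class logic can
        # infer shape/dtype/device, then drop the alias again.
        self.weight = self.weight0
        super().modelopt_post_restore(prefix=prefix)
        self.weight = None


layer = _GroupedQuantLinear(num_gemms=2, out_features=8, in_features=4)
layer.modelopt_post_restore()
print(layer.weight is None, layer.amax.shape)  # -> True torch.Size([8])
```

In this sketch the base class only reads `self.weight` during post-restore, so the temporary alias is enough to reuse the common restore path without duplicating the grouped-weight calibration logic that the old override carried.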