|
22 | 22 | import megatron.core.tensor_parallel.layers as megatron_parallel
|
23 | 23 | import megatron.core.transformer.mlp as megatron_mlp
|
24 | 24 | import torch
|
| 25 | +import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear |
| 26 | +from megatron.core.extensions import transformer_engine as megatron_te |
25 | 27 | from megatron.core.parallel_state import get_data_parallel_group
|
26 | 28 | from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
|
27 | 29 | from megatron.core.transformer import MegatronModule
|
|
34 | 36 | )
|
35 | 37 | from modelopt.torch.utils.distributed import ParallelState
|
36 | 38 |
|
37 | | -from ..nn import QuantModuleRegistry, TensorQuantizer |
38 | | -from ..nn.modules.quant_linear import RealQuantLinear |
| 39 | +from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer |
| 40 | +from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear |
39 | 41 | from ..qtensor import QTensorWrapper
|
40 | 42 | from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
|
41 | 43 |
|
@@ -223,10 +225,18 @@ def _setup(self):
|
223 | 225 | data_parallel_group = get_data_parallel_group(with_context_parallel=True)
|
224 | 226 | except AssertionError:
|
225 | 227 | data_parallel_group = get_data_parallel_group()
|
| 228 | + |
| 229 | + try: |
| 230 | + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() |
| 231 | + except AssertionError: |
| 232 | + expert_tensor_parallel_group = None |
| 233 | + |
226 | 234 | self.parallel_state = ParallelState(
|
227 | 235 | data_parallel_group,
|
228 | 236 | mcore_parallel.get_tensor_model_parallel_group(),
|
229 | 237 | mcore_parallel.get_context_parallel_group(),
|
| 238 | + mcore_parallel.get_expert_model_parallel_group(), |
| 239 | + expert_tensor_parallel_group, |
230 | 240 | )
|
231 | 241 | super()._setup()
|
232 | 242 |
|
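
The same `try/except AssertionError` fallback appears again in `_QuantTEGroupedLinear._setup` below. A minimal sketch of the pattern, assuming only that Megatron's `get_*_group()` accessors raise `AssertionError` when the corresponding parallel group was never initialized; `_get_group` is a hypothetical helper, not part of the plugin:

```python
def _get_group(getter, *args, **kwargs):
    """Return the requested Megatron process group, or None if it was never initialized."""
    try:
        return getter(*args, **kwargs)
    except AssertionError:
        return None

# Mirrors the hunk above:
#   expert_tensor_parallel_group = _get_group(mcore_parallel.get_expert_tensor_parallel_group)
# Note that the data-parallel case instead falls back to get_data_parallel_group()
# without context parallelism, not to None.
```
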
@@ -469,3 +479,161 @@ class _RealQuantMegatronRowParallelLinear(
|
469 | 479 |
|
470 | 480 | def forward(self, input, *args, **kwargs):
|
471 | 481 | return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
|
| 482 | + |
| 483 | + |
| 484 | +# Register the public te.pytorch.GroupedLinear class |
| 485 | +@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"}) |
| 486 | +class _QuantTEGroupedLinear(_MegatronParallelLinear): |
| 487 | + def _setup(self): |
| 488 | + data_parallel_group = None |
| 489 | + try: |
| 490 | + data_parallel_group = get_data_parallel_group(with_context_parallel=True) |
| 491 | + except AssertionError: |
| 492 | + data_parallel_group = get_data_parallel_group() |
| 493 | + |
| 494 | + try: |
| 495 | + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() |
| 496 | + except AssertionError: |
| 497 | + expert_tensor_parallel_group = None |
| 498 | + self.parallel_state = ParallelState( |
| 499 | + data_parallel_group, |
| 500 | + mcore_parallel.get_tensor_model_parallel_group(), |
| 501 | + mcore_parallel.get_context_parallel_group(), |
| 502 | + mcore_parallel.get_expert_model_parallel_group(), |
| 503 | + expert_tensor_parallel_group, |
| 504 | + ) |
| 505 | + self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input) |
| 506 | + self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight) |
| 507 | + self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output) |
| 508 | + self.output_quantizer.disable() |
| 509 | + |
| 510 | + # Remember the original weight dtype for modelopt_post_restore, since the |
| 511 | + # dtype can change later. |
| 512 | + self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype |
| 513 | + |
| 514 | + @property |
| 515 | + def functionals_to_replace(self): |
| 516 | + original_forward = te_grouped_linear._GroupedLinear.forward |
| 517 | + |
| 518 | + def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args): |
| 519 | + num_gemms = len(m_splits) |
| 520 | + weights_and_biases = args[-2 * num_gemms :] |
| 521 | + weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] |
| 522 | + quantized_inputs = self.input_quantizer(inp) |
| 523 | + quantized_weights = [self.weight_quantizer(weight) for weight in weights] |
| 524 | + |
| 525 | + output = original_forward( |
| 526 | + ctx, |
| 527 | + quantized_inputs, |
| 528 | + m_splits, |
| 529 | + *args[: -2 * num_gemms], |
| 530 | + *quantized_weights, |
| 531 | + *biases, |
| 532 | + ) |
| 533 | + return self.output_quantizer(output) |
| 534 | + |
| 535 | + return [ |
| 536 | + ( |
| 537 | + te_grouped_linear._GroupedLinear, |
| 538 | + "forward", |
| 539 | + te_grouped_quantized_linear_fn, |
| 540 | + ), |
| 541 | + ] |
| 542 | + |
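
The wrapper above relies on TE packing the per-GEMM weights and biases as the last `2 * num_gemms` positional arguments of `_GroupedLinear.forward`. A self-contained sketch of that slicing with placeholder values (the leading entries stand in for TE's bookkeeping arguments):

```python
m_splits = [4, 8, 4]        # tokens routed to each GEMM (made-up numbers)
num_gemms = len(m_splits)   # 3

args = ("use_bias", "is_first_microbatch", "w0", "w1", "w2", "b0", "b1", "b2")

weights_and_biases = args[-2 * num_gemms:]  # last 2 * num_gemms entries
weights = weights_and_biases[:num_gemms]    # ("w0", "w1", "w2") -> quantized per GEMM
biases = weights_and_biases[num_gemms:]     # ("b0", "b1", "b2") -> passed through
leading = args[: -2 * num_gemms]            # forwarded unchanged

assert weights == ("w0", "w1", "w2")
assert biases == ("b0", "b1", "b2")
assert leading == ("use_bias", "is_first_microbatch")
```
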
| 543 | + def modelopt_post_restore(self, prefix: str = ""): |
| 544 | + """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks. |
| 545 | + |
| 546 | + ModelOpt restores TensorQuantizer states such as `_amax` and `_pre_quant_scale` to the |
| 547 | + shapes they had at save time. However, this is not enough for MCore/distributed frameworks, |
| 548 | + since the tensor parallelism can change between saving and restoring; if it does, the shapes |
| 549 | + of the quantizer states change as well, so we need to re-calculate them here. |
| 550 | + """ |
| 551 | + from modelopt.torch.quantization.model_calib import max_calibrate |
| 552 | + |
| 553 | + def _check_unsupported_states(quantizer: TensorQuantizer): |
| 554 | + for k in quantizer.state_dict(): |
| 555 | + if k not in ["_amax", "_pre_quant_scale"]: |
| 556 | + warnings.warn( |
| 557 | + f"Restore of {k} for {prefix} is not supported. The restore of this layer might be " |
| 558 | + f"incorrect. Please implement a custom restore for {k}." |
| 559 | + ) |
| 560 | + |
| 561 | + def _has_state(quantizer, name): |
| 562 | + # Handling for SequentialQuantizer |
| 563 | + quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer |
| 564 | + return hasattr(quantizer, name) |
| 565 | + |
| 566 | + # weights for TEGroupedLinear are stored in weight0, weight1, etc. |
| 567 | + if self.weight0 is None: |
| 568 | + return |
| 569 | + for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]: |
| 570 | + _check_unsupported_states( |
| 571 | + quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0] |
| 572 | + ) |
| 573 | + if _has_state(self.weight_quantizer, "_amax"): |
| 574 | + self.weight_quantizer.reset_amax() |
| 575 | + for i in range(self.num_gemms): |
| 576 | + weight = getattr(self, f"weight{i}") |
| 577 | + assert weight is not None, "weight is None" |
| 578 | + |
| 579 | + max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False) |
| 580 | + if _has_state(self.input_quantizer, "_pre_quant_scale"): |
| 581 | + if hasattr(self.input_quantizer, "_pre_quant_scale"): |
| 582 | + delattr(self.input_quantizer, "_pre_quant_scale") |
| 583 | + pqs = torch.zeros( |
| 584 | + (weight.shape[1]), device=weight.device, dtype=self.original_weight_dtype |
| 585 | + ) |
| 586 | + self.input_quantizer.register_buffer("_pre_quant_scale", pqs) |
| 587 | + |
| 588 | + if _has_state(self.input_quantizer, "_amax"): |
| 589 | + self.input_quantizer.reset_amax() |
| 590 | + dummy_input = torch.ones( |
| 591 | + (1, 1, self.weight0.shape[1]), |
| 592 | + device=self.weight0.device, |
| 593 | + dtype=self.original_weight_dtype, |
| 594 | + ) |
| 595 | + max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False) |
| 596 | + if _has_state(self.output_quantizer, "_amax"): |
| 597 | + self.output_quantizer.reset_amax() |
| 598 | + dummy_input = torch.ones( |
| 599 | + (1, 1, self.weight0.shape[0]), |
| 600 | + device=self.weight0.device, |
| 601 | + dtype=self.original_weight_dtype, |
| 602 | + ) |
| 603 | + max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False) |
| 604 | + # If there are any other states, let's move them to the correct device. |
| 605 | + |
| 606 | + self.weight = None |
| 607 | + super().modelopt_post_restore(prefix=prefix) |
| 608 | + |
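
The dummy-input recalibration above can be reproduced with a bare `TensorQuantizer` outside Megatron. A sketch under the assumption that `max_calibrate` accepts a single quantizer as the "model" (as the calls above already do); the feature size is made up:

```python
import torch

from modelopt.torch.quantization.model_calib import max_calibrate
from modelopt.torch.quantization.nn import TensorQuantizer

quantizer = TensorQuantizer()   # default quantizer config
local_in_features = 1024        # hypothetical post-restore local width
dummy_input = torch.ones(1, 1, local_in_features)

# Drop the restored amax (its shape may no longer match the new sharding),
# then rebuild it by running max calibration over the dummy input.
quantizer.reset_amax()
max_calibrate(quantizer, lambda q: q(dummy_input), distributed_sync=False)

print(quantizer.state_dict().keys())  # expected to include "_amax" again
```
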
| 609 | + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): |
| 610 | + # _sharded_state_dict_grouped adds _extra_state{gemm_idx} keys for gemm_idx in |
| 611 | + # range(1, num_gemms) to the sharded state dict; they duplicate _extra_state and are |
| 612 | + # only needed for TE FP8 checkpoints. Drop them here so that a ModelOpt checkpoint |
| 613 | + # can be restored cleanly. |
| 614 | + filtered_state_dict = { |
| 615 | + k: v |
| 616 | + for k, v in state_dict.items() |
| 617 | + if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms)) |
| 618 | + } |
| 619 | + return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) |
| 620 | + |
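
A self-contained illustration of the filtering in `_load_from_state_dict`, with made-up key names; only the per-GEMM `_extra_state{i}` duplicates are dropped, while the canonical `_extra_state` and the weights survive:

```python
num_gemms = 4
state_dict = {
    "experts.linear_fc1._extra_state":  b"...",  # kept: canonical TE extra state
    "experts.linear_fc1._extra_state1": b"...",  # dropped
    "experts.linear_fc1._extra_state2": b"...",  # dropped
    "experts.linear_fc1._extra_state3": b"...",  # dropped
    "experts.linear_fc1.weight0":       b"...",  # kept
}

filtered = {
    k: v
    for k, v in state_dict.items()
    if not any(k.endswith(f"_extra_state{i}") for i in range(1, num_gemms))
}

assert set(filtered) == {"experts.linear_fc1._extra_state", "experts.linear_fc1.weight0"}
```
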
| 621 | + def _process_quantizer_amax(self, k, v, quantizer_state_dict): |
| 622 | + if v.ndim == 4: |
| 623 | + quantizer_state_dict[k] = v.squeeze(1).squeeze(-1) |
| 624 | + else: |
| 625 | + quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1) |
| 626 | + |
| 627 | + |
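
The shape handling in `_process_quantizer_amax` in isolation, with made-up shapes; a 4-D amax loses its singleton dims, a multi-element amax becomes a column vector, and a scalar amax stays 1-D:

```python
import torch

v4 = torch.ones(8, 1, 16, 1)              # e.g. amax stored with extra singleton dims
assert v4.squeeze(1).squeeze(-1).shape == (8, 16)

v_per_channel = torch.ones(64)            # numel > 1 -> column vector
assert v_per_channel.view(-1, 1).shape == (64, 1)

v_scalar = torch.ones(())                 # numel == 1 -> flat 1-D tensor
assert v_scalar.view(-1).shape == (1,)
```
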
| 628 | +@QuantModuleRegistry.register( |
| 629 | + {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"} |
| 630 | +) |
| 631 | +class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear): |
| 632 | + _is_column_parallel = True |
| 633 | + |
| 634 | + |
| 635 | +@QuantModuleRegistry.register( |
| 636 | + {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} |
| 637 | +) |
| 638 | +class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear): |
| 639 | + _is_row_parallel = True |
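
With these registrations, `mtq.quantize` should also convert TE grouped-linear MoE experts when it walks a Megatron model. A hedged usage sketch; the model and calibration loop are placeholders, and `FP8_DEFAULT_CFG` is just one of the stock configs:

```python
import modelopt.torch.quantization as mtq


def quantize_grouped_moe(model, forward_loop):
    """model: a Megatron-Core MoE model whose experts use TE grouped GEMMs.
    forward_loop: callable(model) that runs a few calibration batches."""
    return mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```
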