9 changes: 5 additions & 4 deletions examples/nemo_run/qat/README.md
@@ -56,15 +56,16 @@ The resulting exported checkpoint also is much smaller in memory at 6.4GB compar

You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container.
To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.09 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount them into your Docker container.

- `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
- `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`
- `git clone https://github.com/NVIDIA-NeMo/NeMo.git`
- `git clone https://github.com/NVIDIA/Megatron-LM.git`

Example docker command:

```bash
docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt -v /home/user/Megatron-LM:/opt/megatron-lm --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
```

You will also need to set your Hugging Face token with `export HF_TOKEN=<your-token>`. You may also need to give the Docker container write access to the `examples/nemo_run` folder (e.g. `chmod 777 nemo_run`) so that logs can be written.
@@ -92,7 +93,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
To perform QAD training, run:

```bash
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
```

## Supported models
31 changes: 25 additions & 6 deletions modelopt/torch/quantization/model_calib.py
@@ -80,21 +80,26 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
if not distributed_sync:
return

def sync_quantizer_amax_across_dp(quantizer, parallel_state):
def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and context parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp(_q, parallel_state)
sync_quantizer_amax_across_dp_ep(_q, parallel_state)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
if parallel_state.expert_tensor_parallel_group is not None:
quantizer.sync_amax_across_distributed_group(
parallel_state.expert_tensor_parallel_group
)
[Reviewer comment from a Contributor on lines +92 to +95]: tensor parallel sync here is not handled correctly across various cases. See the comments before `sync_quantizer_amax_across_tp` for more details.

# TODO: create sync_bias_across_distributed_group

for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp(child, module.parallel_state)

sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
# TP sync:
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
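For orientation, syncing `amax` across a process group (whether DP/EP as above or TP as below) is conceptually an elementwise max-reduction: each rank contributes the largest magnitude it observed during calibration, and afterwards every rank in the group holds the same value, so all replicas quantize with identical scales. A minimal sketch of that idea, assuming an initialized `torch.distributed` backend (the helper name and signature are illustrative, not the modelopt API):

```python
import torch
import torch.distributed as dist


def sync_amax_across_group(amax: torch.Tensor, group: dist.ProcessGroup | None) -> torch.Tensor:
    """Elementwise max-reduce an amax buffer so every rank in `group` ends up identical."""
    if group is None or not dist.is_initialized():
        return amax  # nothing to sync in single-process or uninitialized runs
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax
```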

@@ -116,6 +121,7 @@ def sync_quantizer_amax_across_tp(
):
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
"Syncing amax across TP for sequential quantizer"
sync_quantizer_amax_across_tp(
_q, linear_name, quantizer_type, axes_for_sync, parallel_state
)
@@ -598,19 +604,32 @@ def forward(self, input, *args, **kwargs):
# This will also perform distributed amax sync for input_quantizers
max_calibrate(model, lambda model: None)

def sync_act_scale_across_dp(module, data_parallel_group):
"""Sync activation scale across Data Parallel (DP)."""
if data_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
)

for name, module in model.named_modules():
if (
is_quantized_linear(module)
and hasattr(module, "awq_lite")
and module.awq_lite.num_cache_steps > 0
):
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
):
module.awq_lite.is_enabled = False
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
else:
sync_act_scale_across_dp(
module,
module.parallel_state.data_parallel_group,
)

AWQLiteHelper.cache_mode = False
print_rank_0("awq_lite: Searching parameters...")
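To make the AWQ-lite bookkeeping above easier to follow: the cached activation scale is first averaged over the number of calibration steps; if the result contains NaNs (the diff also checks the weight scale) the layer's AWQ-lite search is disabled, otherwise the scale is additionally averaged across data-parallel ranks. A hedged sketch of that flow using plain `torch.distributed` (the helper and its return convention are illustrative):

```python
import torch
import torch.distributed as dist


def finalize_act_scale(
    act_scale: torch.Tensor, num_cache_steps: int, dp_group: dist.ProcessGroup | None
) -> tuple[torch.Tensor, bool]:
    """Average cached activation scales over steps, then across DP ranks; return (scale, enabled)."""
    act_scale = act_scale / num_cache_steps
    if torch.isnan(act_scale).any():
        return act_scale, False  # disable AWQ-lite for this layer, as in the diff
    if dp_group is not None and dist.is_initialized():
        dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=dp_group)
    return act_scale, True
```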
124 changes: 122 additions & 2 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -15,13 +15,17 @@

"""Support quantization for megatron linear layers."""

import logging
import warnings
from typing import Any

import megatron.core.parallel_state as mcore_parallel
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
from megatron.core.extensions import transformer_engine as megatron_te
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -34,10 +38,12 @@
from modelopt.torch.utils.distributed import ParallelState

from ..nn import QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear

logger = logging.getLogger(__name__)

__all__ = []


@@ -217,9 +223,23 @@ class _MegatronParallelLinear(_ParallelLinear):
]

def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
logger.warning("Context parallel group is not initialized, using data parallel group")
data_parallel_group = get_data_parallel_group()

try:
expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
except AssertionError:
expert_tensor_parallel_group = None

self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_expert_model_parallel_group(),
expert_tensor_parallel_group,
)
super()._setup()

@@ -462,3 +482,103 @@ class _RealQuantMegatronRowParallelLinear(

def forward(self, input, *args, **kwargs):
return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)


# Register the public te.pytorch.GroupedLinear class
@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"})
class _QuantTEGroupedLinear(_MegatronParallelLinear):
def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()

try:
expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
except AssertionError:
expert_tensor_parallel_group = None
self.parallel_state = ParallelState(
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_expert_model_parallel_group(),
expert_tensor_parallel_group,
)
self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input)
self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight)
self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output)
self.output_quantizer.disable()

# Memorize the original weight.dtype for modelopt_post_restore given that
# the dtype can change later.
self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype

@property
def functionals_to_replace(self):
original_forward = te_grouped_linear._GroupedLinear.forward

def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args):
num_gemms = len(m_splits)
weights_and_biases = args[-2 * num_gemms :]
weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
quantized_inputs = self.input_quantizer(inp)
quantized_weights = [self.weight_quantizer(weight) for weight in weights]

output = original_forward(
ctx,
quantized_inputs,
m_splits,
*args[: -2 * num_gemms],
*quantized_weights,
*biases,
)
return self.output_quantizer(output)

return [
(
te_grouped_linear._GroupedLinear,
"forward",
te_grouped_quantized_linear_fn,
),
]

def modelopt_post_restore(self, prefix: str = ""):
# GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
# initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
# self.weight0 to self.weight to run the quantizer states initialization.
self.weight = self.weight0
super().modelopt_post_restore(prefix=prefix)
# Revert the weight to None after post_restore to avoid the weight being None during forward pass.
self.weight = None

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
# sharded_state_dict which is same as _extra_state. The _extra_state{gemm_idx} is used for
# TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms]
# for modelopt checkpoint restore
filtered_state_dict = {
k: v
for k, v in state_dict.items()
if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
}
return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)

def _process_quantizer_amax(self, k, v, quantizer_state_dict):
if v.ndim == 4:
quantizer_state_dict[k] = v.squeeze(1).squeeze(-1)
else:
quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1)
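To illustrate the amax reshaping above with concrete shapes (plain `torch`, no quantizer state involved): 4-D checkpointed values have their singleton dims 1 and -1 squeezed, multi-element values become a column vector, and single-element values stay 1-D.

```python
import torch

v4 = torch.zeros(8, 1, 64, 1)             # e.g. grouped/blocked amax stored with singleton dims
print(v4.squeeze(1).squeeze(-1).shape)     # torch.Size([8, 64])

v = torch.zeros(128)                       # per-channel amax
print(v.view(-1, 1).shape)                 # torch.Size([128, 1])

s = torch.zeros(1)                         # per-tensor amax
print(s.view(-1).shape)                    # torch.Size([1])
```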


@QuantModuleRegistry.register(
{megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
)
class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear):
_is_column_parallel = True


@QuantModuleRegistry.register(
{megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
)
class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear):
_is_row_parallel = True
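The grouped-linear forward patch above depends on how TE packs per-GEMM parameters: with `num_gemms = len(m_splits)`, the final `2 * num_gemms` positional arguments are the weights followed by the biases, and everything before them is passed through untouched. A toy, dependency-free sketch of that slicing (names are illustrative):

```python
def split_grouped_args(m_splits, *args):
    """Mimic how the patched forward separates weights and biases from the tail of *args."""
    num_gemms = len(m_splits)
    weights_and_biases = args[-2 * num_gemms :]
    weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
    passthrough = args[: -2 * num_gemms]
    return passthrough, weights, biases


# Two GEMMs: two passthrough args, then w0, w1, b0, b1.
passthrough, weights, biases = split_grouped_args([16, 32], "ctx_arg", True, "w0", "w1", "b0", "b1")
assert passthrough == ("ctx_arg", True)
assert weights == ("w0", "w1") and biases == ("b0", "b1")
```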
17 changes: 16 additions & 1 deletion modelopt/torch/utils/distributed.py
@@ -241,13 +241,28 @@ def __init__(
self,
data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
expert_tensor_parallel_group: torch.distributed.ProcessGroup | int | None = None,
):
"""Initialize the parallel state."""
self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group)
self.expert_tensor_parallel_group = None
if expert_tensor_parallel_group is not None:
self.expert_tensor_parallel_group = DistributedProcessGroup(
expert_tensor_parallel_group
)

def __repr__(self) -> str:
return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
parallel_groups = (
f"data_parallel_group: {self.data_parallel_group}, "
f"tensor_parallel_group: {self.tensor_parallel_group}, "
f"expert_model_parallel_group: {self.expert_model_parallel_group}"
)
if self.expert_tensor_parallel_group:
parallel_groups += f"expert_tensor_parallel_group: {self.expert_tensor_parallel_group}"
return parallel_groups
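A hedged usage sketch of the extended `ParallelState` from this diff: the two new arguments slot in after the tensor-parallel group, and the expert tensor-parallel group is optional (the Megatron getters shown in the comments assume an initialized Megatron parallel state):

```python
from modelopt.torch.utils.distributed import ParallelState

# Defaults: no data-parallel group, sentinel (-1) tensor/expert-model-parallel groups,
# and no expert-tensor-parallel group; repr() now also lists the expert groups.
state = ParallelState()
print(state)

# In a Megatron setting the groups would typically come from megatron.core.parallel_state,
# e.g. (illustrative, requires initialized model parallelism):
# from megatron.core import parallel_state as mcore_parallel
# state = ParallelState(
#     mcore_parallel.get_data_parallel_group(with_context_parallel=True),
#     mcore_parallel.get_tensor_model_parallel_group(),
#     mcore_parallel.get_expert_model_parallel_group(),
#     mcore_parallel.get_expert_tensor_parallel_group(),
# )
```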


def get_group(ranks: list[int]):