2 changes: 1 addition & 1 deletion examples/nemo_run/qat/README.md
@@ -92,7 +92,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
To perform QAD training, run:

```bash
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
```

## Supported models
27 changes: 23 additions & 4 deletions modelopt/torch/quantization/model_calib.py
@@ -79,21 +79,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
if not distributed_sync:
return

def sync_quantizer_amax_across_dp(quantizer, parallel_state):
def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and context parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp(_q, parallel_state)
sync_quantizer_amax_across_dp_cp(_q, parallel_state)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
# TODO: create sync_bias_across_distributed_group

for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp(child, module.parallel_state)

sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
# TP sync:
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -624,13 +625,31 @@ def forward(self, input, *args, **kwargs):
# This will also perform distributed amax sync for input_quantizers
max_calibrate(model, lambda model: None)

def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
# Sync across Data Parallel (DP)
if data_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
)
# Sync across Context Parallel (CP)
if context_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group
)

for name, module in model.named_modules():
if (
is_quantized_linear(module)
and hasattr(module, "awq_lite")
and module.awq_lite.num_cache_steps > 0
):
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
sync_act_scale_across_dp_cp(
module,
module.parallel_state.data_parallel_group,
module.parallel_state.context_parallel_group,
)

# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True

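As context for the `sync_act_scale_across_dp_cp` helper added above, here is a minimal standalone sketch of the averaging pattern it relies on (illustration only, not part of the diff): the group handle and tensor are placeholders, whereas the real helper takes them from `module.parallel_state` and `module.awq_lite.act_scale` and checks the wrapped group's `is_initialized()` rather than `dist.is_initialized()`.

```python
import torch
import torch.distributed as dist


def average_across_group(tensor: torch.Tensor, group) -> torch.Tensor:
    """Average ``tensor`` in place across ``group``; no-op when no group is available."""
    if group is not None and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.AVG, group=group)
    return tensor


# Usage sketch mirroring sync_act_scale_across_dp_cp: average the cached
# activation scale over the data-parallel group, then over the context-parallel group.
# act_scale = average_across_group(act_scale, data_parallel_group)
# act_scale = average_across_group(act_scale, context_parallel_group)
```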
9 changes: 8 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -22,6 +22,7 @@
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -217,9 +218,15 @@ class _MegatronParallelLinear(_ParallelLinear):
]

def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_context_parallel_group(),
)
Comment on lines 224 to 233

⚠️ Potential issue | 🟠 Major

Guard get_context_parallel_group() when CP is disabled

get_context_parallel_group() asserts that context parallelism was initialized. When we run TP/DP-only (the default in plenty of setups), that assertion fires and _MegatronParallelLinear._setup() will crash. Please mirror the DP guard and fall back to -1 (unused) when the call raises.

Something along these lines keeps the DP-only path working:

-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
-        )
+        try:
+            context_parallel_group = mcore_parallel.get_context_parallel_group()
+        except AssertionError:
+            context_parallel_group = -1
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
+            context_parallel_group,
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
mcore_parallel.get_context_parallel_group(),
)
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
data_parallel_group = get_data_parallel_group()
try:
context_parallel_group = mcore_parallel.get_context_parallel_group()
except AssertionError:
context_parallel_group = -1
self.parallel_state = ParallelState(
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
context_parallel_group,
)
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around lines 221 to 230, the
call to mcore_parallel.get_context_parallel_group() is unguarded and will assert
(and crash) when context-parallelism is disabled; mirror the data-parallel
guard: try to call get_context_parallel_group() and if it raises
(AssertionError) set the context group to -1 (or the sentinel used for
"unused"), then pass that value into ParallelState so TP/DP-only setups won't
fail. Ensure you only catch the assertion from the context-group call and keep
the existing fallback for get_data_parallel_group() unchanged.

super()._setup()

8 changes: 7 additions & 1 deletion modelopt/torch/utils/distributed.py
@@ -241,13 +241,19 @@ def __init__(
self,
data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
context_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
):
"""Initialize the parallel state."""
self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
self.context_parallel_group = DistributedProcessGroup(context_parallel_group)

def __repr__(self) -> str:
return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
return (
f"data_parallel_group: {self.data_parallel_group}, "
f"tensor_parallel_group: {self.tensor_parallel_group}, "
f"context_parallel_group: {self.context_parallel_group}"
)


def get_group(ranks: list[int]):
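To illustrate the new `context_parallel_group` argument, a minimal usage sketch follows (not part of the diff). It assumes `ParallelState` is importable from `modelopt.torch.utils.distributed`, that Megatron model parallelism has already been initialized, and it mirrors the guarded group lookups discussed in the review comment above.

```python
from megatron.core import parallel_state as mcore_parallel

from modelopt.torch.utils.distributed import ParallelState

# Prefer the data-parallel group that includes CP ranks; fall back when CP is absent.
try:
    data_parallel_group = mcore_parallel.get_data_parallel_group(with_context_parallel=True)
except AssertionError:
    data_parallel_group = mcore_parallel.get_data_parallel_group()

# Guard the CP lookup as suggested in the review comment; -1 marks the group as unused.
try:
    context_parallel_group = mcore_parallel.get_context_parallel_group()
except AssertionError:
    context_parallel_group = -1

parallel_state = ParallelState(
    data_parallel_group,
    mcore_parallel.get_tensor_model_parallel_group(),
    context_parallel_group,
)
print(parallel_state)  # the repr now also reports the context_parallel_group
```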
11 changes: 8 additions & 3 deletions tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -83,9 +83,10 @@


class MegatronModel(MegatronModule):
def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False):
config = TransformerConfig(
tensor_model_parallel_size=tp_size,
context_parallel_size=cp_size,
pipeline_model_parallel_size=1,
normalization="LayerNorm",
# Unused parameters below are set to avoid ZeroDivisionError in __post_init__
@@ -383,13 +384,17 @@ def run_mcore_inference_with_dummy_input(


def initialize_for_megatron(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1, seed=1234
):
"""Initialize Megatron model parallelism.

NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
"""
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
context_parallel_size=context_parallel_size,
)
model_parallel_cuda_manual_seed(seed)


91 changes: 89 additions & 2 deletions tests/_test_utils/torch_quantization/quantize_common.py
@@ -116,8 +116,8 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
mto.restore_from_modelopt_state(model_ref, state_dict)


def tensor_parallel_test_helper(model, config, tp_group, dp_group):
# The input to fist layer, the column parallel should be the same across all tp ranks
def tensor_parallel_test_helper(model, config, tp_group):
# The input to first layer, the column parallel should be the same across all tp ranks
calib_data = model.get_dummy_input().cuda()
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

@@ -150,6 +150,93 @@ def forward_loop(model):
dist.destroy_process_group()


def data_parallel_test_helper(model, config, dp_group):
calib_data = model.get_dummy_input().cuda()

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
fc1_amax = model.fc1.input_quantizer.amax.clone()
dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
fc2_amax = model.fc2.input_quantizer.amax.clone()
dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)

# Weight quantizer amax
fc1_amax = model.fc1.weight_quantizer.amax.clone()
dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax)
fc2_amax = model.fc2.weight_quantizer.amax.clone()
dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax)


def context_parallel_test_helper(model, config, cp_group):
calib_data = model.get_dummy_input().cuda()

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
fc1_amax = model.fc1.input_quantizer.amax.clone()
dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
fc2_amax = model.fc2.input_quantizer.amax.clone()
dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)

# Weight quantizer amax
fc1_weight_amax = model.fc1.weight_quantizer.amax.clone()
dist.all_reduce(fc1_weight_amax, op=dist.ReduceOp.MAX, group=cp_group)
assert torch.allclose(fc1_weight_amax, model.fc1.weight_quantizer.amax)
fc2_weight_amax = model.fc2.weight_quantizer.amax.clone()
dist.all_reduce(fc2_weight_amax, op=dist.ReduceOp.MAX, group=cp_group)
assert torch.allclose(fc2_weight_amax, model.fc2.weight_quantizer.amax)


def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
calib_data = model.get_dummy_input().cuda()
    # data should be the same across all TP ranks
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
fc1_amax = model.fc1.input_quantizer.amax.clone()
dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group)
dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
fc2_amax = model.fc2.input_quantizer.amax.clone()
dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group)
dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)

fc1_amax = model.fc1.weight_quantizer.amax.clone()
dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group)
dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax)
fc2_amax = model.fc2.weight_quantizer.amax.clone()
dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group)
dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax)


def auto_quantize_helper(model):
model, search_state = mtq.auto_quantize(
model,
6 changes: 6 additions & 0 deletions tests/gpu/torch/conftest.py
@@ -34,6 +34,12 @@ def need_2_gpus():
pytest.skip("Need at least 2 GPUs to run this test")


@pytest.fixture
def need_8_gpus():
if torch.cuda.device_count() < 8:
pytest.skip("Need at least 8 GPUs to run this test")


@pytest.fixture(scope="module")
def set_torch_dtype(request):
orig_dtype = torch.get_default_dtype()
94 changes: 90 additions & 4 deletions tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -31,6 +31,9 @@
from _test_utils.torch_quantization.quant_utils import get_model_size
from _test_utils.torch_quantization.quantize_common import (
auto_quantize_helper,
context_parallel_test_helper,
data_parallel_test_helper,
data_tensor_context_parallel_test_helper,
tensor_parallel_test_helper,
)
from packaging.version import Version
@@ -40,6 +43,7 @@
import megatron.core
from megatron.core.parallel_state import (
destroy_model_parallel,
get_context_parallel_group,
get_data_parallel_group,
get_tensor_model_parallel_group,
)
@@ -92,13 +96,12 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
destroy_model_parallel()


# 1. Tensor Parallel Test
def _test_tensor_parallel_helper(config, rank, size):
initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
model = MegatronModel(size).cuda()
model = MegatronModel(tp_size=size).cuda()

tensor_parallel_test_helper(
model, config, get_tensor_model_parallel_group(), get_data_parallel_group()
)
tensor_parallel_test_helper(model, config, get_tensor_model_parallel_group())


@pytest.mark.parametrize(
@@ -119,6 +122,89 @@ def test_tensor_parallel(need_2_gpus, config):
)


# 2. Data Parallel Test
def _test_data_parallel_helper(config, rank, size):
# TODO does this model automatically get copied to both DP ranks?
initialize_for_megatron(seed=SEED)
model = MegatronModel().cuda()

data_parallel_test_helper(model, config, get_data_parallel_group())


@pytest.mark.parametrize(
"config",
[
mtq.INT8_DEFAULT_CFG,
mtq.FP8_DEFAULT_CFG,
mtq.W4A8_AWQ_BETA_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.INT4_AWQ_CFG,
mtq.NVFP4_DEFAULT_CFG,
],
)
def test_data_parallel(need_2_gpus, config):
spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl")


# 3. Context Parallel Test
def _test_context_parallel_helper(config, rank, size):
initialize_for_megatron(context_parallel_size=size, seed=SEED)
model = MegatronModel(cp_size=size).cuda()

context_parallel_test_helper(model, config, get_context_parallel_group())


@pytest.mark.parametrize(
"config",
[
mtq.INT8_DEFAULT_CFG,
mtq.FP8_DEFAULT_CFG,
mtq.W4A8_AWQ_BETA_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.INT4_AWQ_CFG,
mtq.NVFP4_DEFAULT_CFG,
],
)
def test_context_parallel(need_2_gpus, config):
spawn_multiprocess_job(
size=2, job=partial(_test_context_parallel_helper, config), backend="nccl"
)


# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
def _test_data_tensor_context_parallel_helper(config, rank, size):
initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
model = MegatronModel(tp_size=2, cp_size=2).cuda()

data_tensor_context_parallel_test_helper(
model,
config,
get_data_parallel_group(),
get_tensor_model_parallel_group(),
get_context_parallel_group(),
)


@pytest.mark.parametrize(
"config",
[
mtq.INT8_DEFAULT_CFG,
mtq.FP8_DEFAULT_CFG,
mtq.W4A8_AWQ_BETA_CFG,
mtq.INT8_SMOOTHQUANT_CFG,
mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.INT4_AWQ_CFG,
mtq.NVFP4_DEFAULT_CFG,
],
)
def test_data_tensor_context_parallel(need_8_gpus, config):
spawn_multiprocess_job(
size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
)


def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
"""Build the model."""
