
Commit db7155c

sync amax in context parallel
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 7d5f636 commit db7155c

File tree

4 files changed: +10 -6 lines changed


examples/nemo_run/qat/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -92,7 +92,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
 To perform QAD training, run:
 
 ```bash
-python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
+python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
 ```
 
 ## Supported models
````

modelopt/torch/quantization/model_calib.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -79,21 +79,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
+        """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp(child, module.parallel_state)
-
+                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
 
```
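
For reference, the amax sync that `sync_amax_across_distributed_group` performs for each group amounts to an all-reduce with a MAX reduction, so every rank in the data parallel and (with this commit) context parallel group calibrates with the same scale. Below is a minimal sketch of that idea in plain `torch.distributed`; the helper name and guard logic are illustrative assumptions, not the library's actual implementation:

```python
import torch
import torch.distributed as dist


def sync_amax_across_group(amax: torch.Tensor, group) -> torch.Tensor:
    """Illustrative amax sync: all-reduce with MAX over `group` (hypothetical helper)."""
    if group is None or not dist.is_initialized():
        return amax  # single-process run or unset group: nothing to sync
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax


# Mirrors the two calls in the diff above: sync over the DP group, then the CP
# group, so the final amax is the maximum across both.
# amax = sync_amax_across_group(amax, data_parallel_group)
# amax = sync_amax_across_group(amax, context_parallel_group)
```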

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -220,6 +220,7 @@ def _setup(self):
         self.parallel_state = ParallelState(
             getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
             mcore_parallel.get_tensor_model_parallel_group(),
+            mcore_parallel.get_context_parallel_group(),
         )
         super()._setup()
 
```

modelopt/torch/utils/distributed.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -241,13 +241,15 @@ def __init__(
         self,
         data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
         tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
+        context_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
     ):
         """Initialize the parallel state."""
         self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
+        self.context_parallel_group = DistributedProcessGroup(context_parallel_group)
 
     def __repr__(self) -> str:
-        return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
+        return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}, context_parallel_group: {self.context_parallel_group}"
 
 
 def get_group(ranks: list[int]):
```
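
Taken together with the `plugins/megatron.py` change above, the new argument lets callers pass Megatron-Core's context parallel group straight through. A hedged usage sketch, assuming `ParallelState` is importable from `modelopt.torch.utils.distributed` and that Megatron-Core model parallelism has already been initialized:

```python
# Sketch only: mirrors the plugins/megatron.py change rather than copying it verbatim.
from megatron.core import parallel_state as mcore_parallel

from modelopt.torch.utils.distributed import ParallelState  # assumed import path

parallel_state = ParallelState(
    mcore_parallel.get_data_parallel_group(),
    mcore_parallel.get_tensor_model_parallel_group(),
    mcore_parallel.get_context_parallel_group(),
)
print(parallel_state)  # __repr__ now reports the context_parallel_group as well
```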
