
Commit 7edf59c

[5704162] Create a copy to avoid leaking ProcessGroup into state dict (NVIDIA#640)
## What does this PR do?

**Type of change:** Bug Fix

**Overview:** Fixes a `TypeError: cannot pickle 'torch._C._distributed_c10d.ProcessGroup' object` error that occurs during checkpoint saving when using modelopt==0.40.0 with Megatron-LM. The `ensure_metadata_has_dp_cp_group()` function (introduced in NVIDIA#606) modified the metadata dict in-place by adding a ProcessGroup object. The ProcessGroup therefore leaked into the `common_state_dict`, which is broadcast via `torch.distributed.broadcast_object_list()` during checkpoint validation; since ProcessGroup objects cannot be pickled, the save fails. `ensure_metadata_has_dp_cp_group()` now creates a copy of the metadata dict instead of modifying it in-place.

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---------

Signed-off-by: Fridah-nv <[email protected]>
1 parent c1c5ca0 commit 7edf59c
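
The failure mode is easy to reproduce outside Megatron-LM. The sketch below is not part of the commit: a `threading.Lock` stands in for the unpicklable ProcessGroup, and the dict keys are illustrative. Because `torch.distributed.broadcast_object_list()` pickles its payload, any such object left in the metadata dict fails the same way at save time.

```python
# Minimal, hypothetical sketch: why an unpicklable value in checkpoint metadata
# breaks the save. threading.Lock is a stand-in for a torch ProcessGroup,
# which likewise cannot be pickled.
import pickle
import threading

metadata = {"prefix": "decoder.layers.0.", "dp_cp_group": threading.Lock()}

try:
    # broadcast_object_list() pickles its payload under the hood, so anything
    # unpicklable in the dict surfaces as a TypeError at save time.
    pickle.dumps(metadata)
except TypeError as err:
    print(f"Pickling failed: {err}")  # e.g. "cannot pickle '_thread.lock' object"
```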

File tree

3 files changed: +28 −50 lines


modelopt/torch/opt/plugins/megatron.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -22,10 +22,36 @@
 import megatron.core.transformer.mlp as megatron_mlp
 import regex as re
 import torch
+from megatron.core.parallel_state import get_data_parallel_group

 from ..dynamic import DynamicModule


+def ensure_metadata_has_dp_cp_group(metadata):
+    """Ensure `metadata` is a dict containing `dp_cp_group` entry.
+
+    This function is adapted from megatron-lm's megatron.core.transformer.utils to avoid
+    dependency on megatron-lm's specific version.
+
+    Note:
+        This is a temporary method and will be removed once this function is merged to
+        megatron.core.transformer.utils in the main branch of megatron-lm.
+    """
+    # Create a copy to avoid modifying the original metadata dict
+    # This prevents ProcessGroup from leaking into state dict
+    if metadata is None:
+        new_metadata = {}
+    else:
+        new_metadata = dict(metadata)
+    if "dp_cp_group" not in new_metadata:
+        try:
+            new_metadata["dp_cp_group"] = get_data_parallel_group(with_context_parallel=True)
+        except (AssertionError, RuntimeError):
+            # Fallback if context parallel is not initialized
+            new_metadata["dp_cp_group"] = get_data_parallel_group()
+    return new_metadata
+
+
 def _modelopt_get_extra_state(self):
     """Populating the extra_state when state_dict() is called.

```
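
The essential behavioral change above is that the helper now enriches a shallow copy rather than the caller's dict. For reference, a self-contained sketch of that copy-before-mutate pattern; the helper name `add_dp_cp_group`, the `sharded_offsets` key, and the placeholder string are illustrative only, while the real function obtains the group from `get_data_parallel_group()`.

```python
# Hypothetical stand-alone sketch of the copy-before-mutate pattern used by
# ensure_metadata_has_dp_cp_group(); the string below is a placeholder for the
# real process group, so the example runs without torch.distributed.
def add_dp_cp_group(metadata):
    new_metadata = {} if metadata is None else dict(metadata)  # shallow copy
    new_metadata.setdefault("dp_cp_group", "<dp_cp process group>")
    return new_metadata


caller_metadata = {"sharded_offsets": ()}
enriched = add_dp_cp_group(caller_metadata)

assert "dp_cp_group" in enriched             # the copy carries the group
assert "dp_cp_group" not in caller_metadata  # the caller's dict stays picklable
```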

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 1 addition & 24 deletions
```diff
@@ -33,6 +33,7 @@

 from modelopt.torch.opt.plugins.megatron import (
     _MegatronMLP,
+    ensure_metadata_has_dp_cp_group,
     register_modelopt_extra_state_callbacks,
 )
 from modelopt.torch.utils.distributed import ParallelState
@@ -230,30 +231,6 @@ def _register_extra_state_callbacks(model: torch.nn.Module):
 CUSTOM_MODEL_PLUGINS.add(megatron_replace_quant_module_hook)


-def ensure_metadata_has_dp_cp_group(metadata):
-    """Ensure `metadata` is a dict containing `dp_cp_group` entry.
-
-    If `metadata` is None, a new dict is returned with `dp_cp_group` set.
-    If `metadata` is a dict and missing `dp_cp_group`, it is updated in-place.
-
-    This function is adapted from megatron-lm's megatron.core.transformer.utils to avoid
-    dependency on megatron-lm's specific version.
-
-    Note:
-        This is a temporary method and will be removed once this function is merged to
-        megatron.core.transformer.utils in the main branch of megatron-lm.
-    """
-    if metadata is None:
-        metadata = {}
-    if "dp_cp_group" not in metadata:
-        try:
-            metadata["dp_cp_group"] = get_data_parallel_group(with_context_parallel=True)
-        except (AssertionError, RuntimeError):
-            # Fallback if context parallel is not initialized
-            metadata["dp_cp_group"] = get_data_parallel_group()
-    return metadata
-
-
 class _MegatronParallelLinear(_ParallelLinear):
     _functionals_to_replace = [
         (megatron_parallel, "linear_with_grad_accumulation_and_async_allreduce"),
```

modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py

Lines changed: 1 addition & 26 deletions
```diff
@@ -16,40 +16,15 @@
 """Support sparsify and save/resore for Megatron."""

 import megatron.core.transformer.mlp as megatron_mlp
-from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

-from modelopt.torch.opt.plugins.megatron import _MegatronMLP
+from modelopt.torch.opt.plugins.megatron import _MegatronMLP, ensure_metadata_has_dp_cp_group

 from ..config import SparseGPTConfig, SparseMagnitudeConfig
 from ..module import SparseModule, SpDMRegistry


-def ensure_metadata_has_dp_cp_group(metadata):
-    """Ensure `metadata` is a dict containing `dp_cp_group` entry.
-
-    If `metadata` is None, a new dict is returned with `dp_cp_group` set.
-    If `metadata` is a dict and missing `dp_cp_group`, it is updated in-place.
-
-    This function is adapted from megatron-lm's megatron.core.transformer.utils to avoid
-    dependency on megatron-lm's specific version.
-
-    Note:
-        This is a temporary method and will be removed once this function is merged to
-        megatron.core.transformer.utils in the main branch of megatron-lm.
-    """
-    if metadata is None:
-        metadata = {}
-    if "dp_cp_group" not in metadata:
-        try:
-            metadata["dp_cp_group"] = get_data_parallel_group(with_context_parallel=True)
-        except (AssertionError, RuntimeError):
-            # Fallback if context parallel is not initialized
-            metadata["dp_cp_group"] = get_data_parallel_group()
-    return metadata
-
-
 class _MegatronParallelLinear(SparseModule):
     def _get_shard_axis_dict(self, state_dict):
         raise NotImplementedError
```
