
Commit 1c821d8

code and test cleanup
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 38f839e commit 1c821d8

File tree

8 files changed, +125 -224 lines changed

modelopt/torch/quantization/mode.py

Lines changed: 7 additions & 1 deletion
@@ -208,6 +208,8 @@ def wrapped_calib_func(
     forward_loop and the relevant kwargs and are independent of the ModelOpt framework.
     So lets wrap them to be compatible with the ModelOpt convert entrypoint.
     """
+    from .plugins.custom import register_custom_post_calibration_plugins
+
     kwargs = config.model_dump()
     method = kwargs.pop("method")
     if method is not None and "awq" in method:
@@ -218,6 +220,7 @@ def wrapped_calib_func(
     # Call the function with forward_loop as a separate argument
     func(model, forward_loop=forward_loop, **kwargs)

+    register_custom_post_calibration_plugins(model)
     # Lets get the latest metadata for the quantizer states
     metadata = {}
     update_quantize_metadata(model, config, metadata)
@@ -290,7 +293,10 @@ def convert(self) -> ConvertEntrypoint:
         def wrapped_func(model, config, forward_loop=None):
             # Access _calib_func as a class attribute to avoid binding
             # Check if _calib_func is defined as a class attribute
-            return wrapped_calib_func(model, config, forward_loop, func=self.__class__._calib_func)
+            calib_results = wrapped_calib_func(
+                model, config, forward_loop, func=self.__class__._calib_func
+            )
+            return calib_results

         return wrapped_func
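Read together, the calibration wrapper now runs the new hook between calibration and metadata capture. A minimal sketch of that ordering, paraphrased from the hunks above (the AWQ branch, error handling, and the wrapper's return value are elided):

def calib_flow_sketch(model, config, forward_loop, func):
    """Paraphrase of wrapped_calib_func after this commit; not the literal source."""
    # Deferred import, as in the hunk above (presumably to avoid an import cycle).
    from modelopt.torch.quantization.plugins.custom import (
        register_custom_post_calibration_plugins,
    )

    kwargs = config.model_dump()
    kwargs.pop("method", None)                          # AWQ-specific handling elided
    func(model, forward_loop=forward_loop, **kwargs)    # 1. run calibration
    register_custom_post_calibration_plugins(model)     # 2. fire post-calibration callbacks
    # 3. update_quantize_metadata(...) runs next, so anything the callbacks adjust
    #    (e.g. synced amax values) is reflected in the captured quantizer state.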

modelopt/torch/quantization/model_calib.py

Lines changed: 0 additions & 4 deletions
@@ -89,10 +89,6 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
        if getattr(quantizer, "_amax", None) is not None:
            quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
            quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
-            if parallel_state.expert_tensor_parallel_group is not None:
-                quantizer.sync_amax_across_distributed_group(
-                    parallel_state.expert_tensor_parallel_group
-                )
    # TODO: create sync_bias_across_distributed_group

    for name, module in model.named_modules():
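The removed branch synchronized amax over the expert-tensor-parallel group here; after this commit that group is carried through the module's parallel_state instead (see the megatron plugin below), so only the generic DP/EP sync remains. Conceptually, each sync is a max all-reduce; a hedged sketch of the equivalent operation, not the library's actual method:

import torch
import torch.distributed as dist

def max_allreduce_amax(amax: torch.Tensor, group=None) -> torch.Tensor:
    # Conceptual stand-in for TensorQuantizer.sync_amax_across_distributed_group:
    # after the call every rank holds the element-wise maximum amax seen in `group`.
    # The real method additionally handles uninitialized groups and unset amax.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax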

modelopt/torch/quantization/plugins/custom.py

Lines changed: 7 additions & 0 deletions
@@ -30,6 +30,7 @@

 CUSTOM_MODEL_PLUGINS = set()
 CUSTOM_POST_CONVERSION_PLUGINS = set()
+CUSTOM_POST_CALIBRATION_PLUGINS = set()


 # TODO: This is a temporary solution
@@ -46,6 +47,12 @@ def register_custom_post_conversion_plugins(model):
         callback(model)


+def register_custom_post_calibration_plugins(model):
+    """Registers custom modules as QUANT_MODULE after calibration."""
+    for callback in CUSTOM_POST_CALIBRATION_PLUGINS:
+        callback(model)
+
+
 class _QuantFunctionalMixin(QuantModule):
     """Mixin class for quantized functionals.
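For downstream code, the new registry works like the existing post-conversion one: add a callable that takes the model, and it is invoked once per model right after calibration. A small hypothetical example (the callback name and body are illustrative; the registry and call pattern come from this diff):

from modelopt.torch.quantization.plugins.custom import CUSTOM_POST_CALIBRATION_PLUGINS

def my_post_calibration_callback(model):
    # Runs after the calibration forward loop, before quantizer metadata is captured.
    print(f"post-calibration hook saw {sum(p.numel() for p in model.parameters())} parameters")

CUSTOM_POST_CALIBRATION_PLUGINS.add(my_post_calibration_callback)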

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 69 additions & 40 deletions
@@ -40,13 +40,51 @@
 from ..nn import QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear
 from ..qtensor import QTensorWrapper
-from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
+from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear

 logger = logging.getLogger(__name__)

 __all__ = []


+def sync_amax_across_sequential_mlp(model: torch.nn.Module):
+    """Sync amax across experts in a SequentialMLP."""
+    amax_dict = {
+        "linear_fc1.input_quantizer": {},
+        "linear_fc1.weight_quantizer": {},
+        "linear_fc2.input_quantizer": {},
+        "linear_fc2.weight_quantizer": {},
+    }
+    # gather amax values from SequentialMLP experts
+    for name, module in model.named_modules():
+        if (
+            not isinstance(module, TensorQuantizer)
+            or not hasattr(module, "_amax")
+            or "local_experts" not in name
+        ):
+            continue
+        expert_name, local_expert_name = name.split("local_experts")
+        for key in amax_dict:
+            if key in local_expert_name:
+                amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax)
+
+    # sync amax values across experts in SequentialMLP
+    for name, module in model.named_modules():
+        if (
+            not isinstance(module, TensorQuantizer)
+            or not hasattr(module, "_amax")
+            or "local_experts" not in name
+        ):
+            continue
+        expert_name, local_expert_name = name.split("local_experts")
+        for key in amax_dict:
+            if key in local_expert_name:
+                module.amax = amax_dict[key][expert_name]
+
+
+CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp)
+
+
 def real_quant_module_get_extra_state(self) -> dict:
     """Populating real_quantizer_state and q_tensor_state."""
     extra_state = {}
@@ -223,24 +261,19 @@ class _MegatronParallelLinear(_ParallelLinear):
     ]

     def _setup(self):
-        data_parallel_group = None
-        try:
-            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-        except AssertionError:
-            logger.warning("Context parallel group is not initialized, using data parallel group")
-            data_parallel_group = get_data_parallel_group()
-
-        try:
-            expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
-        except AssertionError:
-            expert_tensor_parallel_group = None
-
-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_expert_model_parallel_group(),
-            expert_tensor_parallel_group,
-        )
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            data_parallel_group = None
+            try:
+                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+            except AssertionError:
+                logger.warning(
+                    "Context parallel group is not initialized, using data parallel group"
+                )
+                data_parallel_group = get_data_parallel_group()
+            self.parallel_state = ParallelState(
+                data_parallel_group,
+                mcore_parallel.get_tensor_model_parallel_group(),
+            )
         super()._setup()

     def _process_quantizer_amax(self, k, v, quantizer_state_dict):
@@ -488,26 +521,22 @@ def forward(self, input, *args, **kwargs):
 @QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"})
 class _QuantTEGroupedLinear(_MegatronParallelLinear):
     def _setup(self):
-        data_parallel_group = None
-        try:
-            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-        except AssertionError:
-            data_parallel_group = get_data_parallel_group()
-
-        try:
-            expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
-        except AssertionError:
-            expert_tensor_parallel_group = None
-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_expert_model_parallel_group(),
-            expert_tensor_parallel_group,
-        )
-        self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input)
-        self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight)
-        self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output)
-        self.output_quantizer.disable()
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            data_parallel_group = None
+            try:
+                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+            except AssertionError:
+                data_parallel_group = get_data_parallel_group()
+
+            self.parallel_state = ParallelState(
+                data_parallel_group,
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+            )
+        self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input)
+        self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight)
+        self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output)
+        self.output_quantizer.disable()

         # Memorize the original weight.dtype for modelopt_post_restore given that
         # the dtype can change later.
@@ -580,5 +609,5 @@ class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumn
 @QuantModuleRegistry.register(
     {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
 )
-class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear):
+class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear):
     _is_row_parallel = True
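The amax sync above keys experts by everything before the literal substring "local_experts" in the module path. A short illustration with a hypothetical SequentialMLP quantizer name (the path below is illustrative, not from the source):

name = "decoder.layers.0.mlp.experts.local_experts.3.linear_fc1.input_quantizer"
expert_name, local_expert_name = name.split("local_experts")
# expert_name       == "decoder.layers.0.mlp.experts."
# local_expert_name == ".3.linear_fc1.input_quantizer"
# "linear_fc1.input_quantizer" is a substring of local_expert_name, so every expert
# under the same expert_name contributes to, and then receives, one shared max amax
# for that quantizer.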

modelopt/torch/utils/distributed.py

Lines changed: 0 additions & 8 deletions
@@ -242,26 +242,18 @@ def __init__(
         data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
         tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
         expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
-        expert_tensor_parallel_group: torch.distributed.ProcessGroup | int | None = None,
     ):
         """Initialize the parallel state."""
         self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
         self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group)
-        self.expert_tensor_parallel_group = None
-        if expert_tensor_parallel_group is not None:
-            self.expert_tensor_parallel_group = DistributedProcessGroup(
-                expert_tensor_parallel_group
-            )

     def __repr__(self) -> str:
         parallel_groups = (
             f"data_parallel_group: {self.data_parallel_group}, "
             f"tensor_parallel_group: {self.tensor_parallel_group}, "
             f"expert_model_parallel_group: {self.expert_model_parallel_group}"
         )
-        if self.expert_tensor_parallel_group:
-            parallel_groups += f"expert_tensor_parallel_group: {self.expert_tensor_parallel_group}"
         return parallel_groups
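With the expert_tensor_parallel_group field gone, callers that still need expert-tensor-parallel amax sync pass that group through the tensor_parallel_group slot, as _QuantTEGroupedLinear now does. A minimal construction sketch (placeholder arguments; -1 is the "disabled group" sentinel used elsewhere in this diff, e.g. in the removed test helper):

from modelopt.torch.utils.distributed import ParallelState

state = ParallelState(
    data_parallel_group=None,          # default process group
    tensor_parallel_group=-1,          # grouped-linear path passes the ETP group here
    expert_model_parallel_group=-1,
)
print(state)  # repr now lists only data/tensor/expert_model parallel groups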

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 4 additions & 104 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import re
 from warnings import warn

 import torch
@@ -57,7 +58,6 @@
     save_sharded_modelopt_state,
 )
 from modelopt.torch.utils import to_empty_if_meta_device
-from modelopt.torch.utils.distributed import DistributedProcessGroup

 try:
     from megatron.core.extensions.transformer_engine import TENorm
@@ -143,7 +143,7 @@ def get_mcore_gpt_model(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     expert_model_parallel_size: int = 1,
-    expert_tensor_parallel_size: int = 1,
+    expert_tensor_parallel_size: int | None = None,
     initialize_megatron: bool = False,
     *,
     num_layers: int = 2,
@@ -497,61 +497,6 @@ def convert_maybe_fp8(v):
     )


-def compare_model_outputs(grouped_model, non_grouped_model, forward_fn, tolerance=1e-6):
-    """Compare outputs of grouped and non-grouped models."""
-    # Set both models to eval mode
-    grouped_model.eval()
-    non_grouped_model.eval()
-
-    with torch.no_grad():
-        # Get outputs from both models
-        grouped_output = forward_fn(grouped_model)
-        non_grouped_output = forward_fn(non_grouped_model)
-
-        # Compare outputs
-        if isinstance(grouped_output, tuple):
-            grouped_output = grouped_output[0]
-        if isinstance(non_grouped_output, tuple):
-            non_grouped_output = non_grouped_output[0]
-
-        output_close = torch.allclose(
-            grouped_output, non_grouped_output, atol=tolerance, rtol=tolerance
-        )
-        return output_close
-
-
-def sync_amax(model):
-    amax_dict = {
-        "linear_fc1.input_quantizer": {},
-        "linear_fc1.weight_quantizer": {},
-        "linear_fc2.input_quantizer": {},
-        "linear_fc2.weight_quantizer": {},
-    }
-    for name, module in model.named_modules():
-        if not isinstance(module, mtq.nn.TensorQuantizer):
-            continue
-        if not hasattr(module, "_amax"):
-            continue
-        if "local_experts" not in name:
-            continue
-        expert_name, local_expert_name = name.split("local_experts")
-        for key in amax_dict:
-            if key in local_expert_name:
-                amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax)
-
-    for name, module in model.named_modules():
-        if not isinstance(module, mtq.nn.TensorQuantizer):
-            continue
-        if not hasattr(module, "_amax"):
-            continue
-        if "local_experts" not in name:
-            continue
-        expert_name, local_expert_name = name.split("local_experts")
-        for key in amax_dict:
-            if key in local_expert_name:
-                module.amax = amax_dict[key][expert_name]
-
-
 def copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model):
     """Copy weights from grouped MoE model to non-grouped MoE model."""
     grouped_state = grouped_model.state_dict()
@@ -625,8 +570,6 @@ def compare_amax_sync_across_expert_parallel(model):
         # Create quantizer type key by normalizing the name
         if "local_experts" in name:
             # Non-grouped MoE: replace expert index with wildcard
-            import re
-
             quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name)
         else:
             # Grouped MoE: use the name as-is since experts are grouped
@@ -641,50 +584,7 @@ def compare_amax_sync_across_expert_parallel(model):
        if len(rank_values) > 1:  # Only check if we have multiple ranks
            values = list(rank_values.values())
            max_diff = max(values) - min(values)
-
            if max_diff > 1e-6:  # Allow for small floating point differences
-                return False
+                return False, quantizer_type, rank_values

-    return True
-
-
-def disable_distributed_parallel_sync(model, expert_parallel_type: str = "tensor"):
-    """Disable distributed parallel synchronization groups."""
-    module_parallel_groups = {}
-
-    for name, module in model.named_modules():
-        if isinstance(module, mtq.nn.QuantModule):
-            # Store original groups
-            module_parallel_groups[name] = {
-                "data_parallel_group": module.parallel_state.data_parallel_group,
-                "expert_tensor_parallel_group": module.parallel_state.expert_tensor_parallel_group,
-                "expert_model_parallel_group": module.parallel_state.expert_model_parallel_group,
-            }
-
-            # Disable groups
-            module.parallel_state.data_parallel_group = DistributedProcessGroup(-1)
-
-            if expert_parallel_type in ["tensor", "both"]:
-                module.parallel_state.expert_tensor_parallel_group = DistributedProcessGroup(-1)
-            if expert_parallel_type in ["model", "both"]:
-                module.parallel_state.expert_model_parallel_group = DistributedProcessGroup(-1)
-
-    return module_parallel_groups
-
-
-def enable_distributed_parallel_sync(
-    model, module_parallel_groups, expert_parallel_type: str = "tensor"
-):
-    """Re-enable distributed parallel synchronization groups."""
-    for name, module in model.named_modules():
-        if isinstance(module, mtq.nn.QuantModule) and name in module_parallel_groups:
-            groups = module_parallel_groups[name]
-
-            if expert_parallel_type in ["tensor", "both"]:
-                module.parallel_state.expert_tensor_parallel_group = groups[
-                    "expert_tensor_parallel_group"
-                ]
-            if expert_parallel_type in ["model", "both"]:
-                module.parallel_state.expert_model_parallel_group = groups[
-                    "expert_model_parallel_group"
-                ]
+    return True, None, None
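compare_amax_sync_across_expert_parallel now reports which quantizer and which per-rank values disagreed instead of a bare bool. A hypothetical consumer (the assert message is illustrative; `model` and the helper come from this test utility module):

is_synced, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model)
assert is_synced, f"amax out of sync for {quantizer_type}: per-rank values {rank_values}"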

tests/gpu/torch/conftest.py

Lines changed: 0 additions & 6 deletions
@@ -40,12 +40,6 @@ def need_8_gpus():
         pytest.skip("Need at least 8 GPUs to run this test")


-@pytest.fixture
-def need_4_gpus():
-    if torch.cuda.device_count() < 4:
-        pytest.skip("Need at least 4 GPUs to run this test")
-
-
 @pytest.fixture(scope="module")
 def set_torch_dtype(request):
     orig_dtype = torch.get_default_dtype()
