Merged
28 commits
9cac53c - sync amax in context parallel and awq act scale (jenchen13, Sep 24, 2025)
be5e838 - lint (jenchen13, Sep 25, 2025)
4a2a8d7 - test weight quantizer too (jenchen13, Sep 25, 2025)
cacee61 - fix test (jenchen13, Sep 26, 2025)
41cc9bd - awq test (jenchen13, Sep 29, 2025)
4a706ef - move awq test inside megatron tests (jenchen13, Sep 29, 2025)
7b2c969 - fix amax tests (jenchen13, Sep 30, 2025)
e6dc5e5 - fix awq lite param (jenchen13, Sep 30, 2025)
f17320e - fix test (jenchen13, Sep 30, 2025)
a1fdf18 - add print (jenchen13, Oct 1, 2025)
cd31159 - docstring (jenchen13, Oct 1, 2025)
5a67acf - fix tests (jenchen13, Oct 2, 2025)
9d7dff1 - fix multiprocess size (jenchen13, Oct 2, 2025)
3bf16e6 - Added quantization support for TEGroupedMoE for megatron-lm (kinjalpatel27, Oct 7, 2025)
70776c3 - code cleanup (kinjalpatel27, Oct 7, 2025)
bab9ca2 - code and test cleanup (kinjalpatel27, Oct 8, 2025)
f9ba6e8 - Updated moe names in tests (kinjalpatel27, Oct 9, 2025)
a917c2b - updated parallel state for experts (kinjalpatel27, Oct 9, 2025)
1ea4ed1 - fixed bug for is_quantized_linear check (kinjalpatel27, Oct 9, 2025)
169677c - code cleanup and bug fixes (kinjalpatel27, Oct 11, 2025)
153e376 - rebase bug fixes (kinjalpatel27, Oct 11, 2025)
5bc99e0 - fixing test and comments (kinjalpatel27, Oct 11, 2025)
23daf38 - Code cleanup (kinjalpatel27, Oct 13, 2025)
15ffb87 - Code cleanup and test update (kinjalpatel27, Oct 14, 2025)
28c8bbf - remove post calib hook (kinjalpatel27, Oct 14, 2025)
5481d10 - fixed tests for per-channel support (kinjalpatel27, Oct 16, 2025)
91837c3 - minor fix (kinjalpatel27, Oct 16, 2025)
ca55348 - Addressed MR comments (kinjalpatel27, Oct 17, 2025)
3 changes: 3 additions & 0 deletions modelopt/torch/quantization/mode.py
Contributor: We should not need register_custom_post_calibration_plugins. Let's not introduce new infrastructure unnecessarily.
Contributor: I see the point of post_calibration plugins now. Let's keep them as we discussed.
@@ -208,6 +208,8 @@ def wrapped_calib_func(
forward_loop and the relevant kwargs and are independent of the ModelOpt framework.
So let's wrap them to be compatible with the ModelOpt convert entrypoint.
"""
from .plugins.custom import register_custom_post_calibration_plugins

kwargs = config.model_dump()
method = kwargs.pop("method")
if method is not None and "awq" in method:
@@ -218,6 +220,7 @@ def wrapped_calib_func(
# Call the function with forward_loop as a separate argument
func(model, forward_loop=forward_loop, **kwargs)

register_custom_post_calibration_plugins(model)
# Let's get the latest metadata for the quantizer states
metadata = {}
update_quantize_metadata(model, config, metadata)
10 changes: 6 additions & 4 deletions modelopt/torch/quantization/model_calib.py
Contributor: this change looks good!
@@ -80,21 +80,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
if not distributed_sync:
return

def sync_quantizer_amax_across_dp(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel group."""
def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp(_q, parallel_state)
sync_quantizer_amax_across_dp_ep(_q, parallel_state)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
# TODO: create sync_bias_across_distributed_group

for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp(child, module.parallel_state)
sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
# TP sync:
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -117,6 +118,7 @@ def sync_quantizer_amax_across_tp(
# Syncing amax across TP for sequential quantizer
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
"Syncing amax across TP for sequential quantizer"
sync_quantizer_amax_across_tp(
_q, linear_name, quantizer_type, axes_for_sync, parallel_state
)
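For orientation, a minimal sketch of what the per-group amax sync amounts to, assuming the calibrated amax lives in quantizer._amax and torch.distributed is initialized; the actual sync_amax_across_distributed_group helper on TensorQuantizer may differ in details. As in the diff above, the same reduction is applied once for the data-parallel group and once for the expert-model-parallel group.

import torch.distributed as dist


def sync_amax_max_reduce(quantizer, group) -> None:
    """Sketch: all-reduce the calibrated amax with MAX over one process group."""
    amax = getattr(quantizer, "_amax", None)
    if amax is None or group is None:
        return  # nothing calibrated yet, or the group is not initialized
    reduced = amax.detach().clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.MAX, group=group)
    quantizer._amax.copy_(reduced)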
7 changes: 7 additions & 0 deletions modelopt/torch/quantization/plugins/custom.py
@@ -30,6 +30,7 @@

CUSTOM_MODEL_PLUGINS = set()
CUSTOM_POST_CONVERSION_PLUGINS = set()
CUSTOM_POST_CALIBRATION_PLUGINS = set()


# TODO: This is a temporary solution
@@ -46,6 +47,12 @@ def register_custom_post_conversion_plugins(model):
callback(model)


def register_custom_post_calibration_plugins(model):
"""Registers custom modules as QUANT_MODULE after calibration."""
for callback in CUSTOM_POST_CALIBRATION_PLUGINS:
callback(model)


class _QuantFunctionalMixin(QuantModule):
"""Mixin class for quantized functionals.

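For context, hooking into the new post-calibration step only requires adding a callable to this set. A minimal sketch with a hypothetical logging callback (the callback name and body are illustrative, not part of this change):

from modelopt.torch.quantization.plugins.custom import CUSTOM_POST_CALIBRATION_PLUGINS


def log_calibrated_modules(model) -> None:
    # Hypothetical callback: invoked once per model after calibration, when
    # wrapped_calib_func calls register_custom_post_calibration_plugins(model).
    for name, module in model.named_modules():
        if hasattr(module, "weight_quantizer"):
            print(f"calibrated: {name}")


CUSTOM_POST_CALIBRATION_PLUGINS.add(log_calibrated_modules)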
193 changes: 182 additions & 11 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -22,7 +22,10 @@
import megatron.core.parallel_state as mcore_parallel
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import megatron.core.transformer.moe.experts as megatron_moe
import torch
import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
from megatron.core.extensions import transformer_engine as megatron_te
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
@@ -38,13 +41,49 @@
from ..nn import QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear

logger = logging.getLogger(__name__)

__all__ = []


def sync_amax_across_sequential_mlp(model: torch.nn.Module):
Contributor: We should do this only for per-tensor amax.

Contributor: Per-channel weight amax: element-wise maximum for RowParallel (fc2; the RowParallel Cout dim is shared across experts). Per-channel weight amax for ColumnParallel: no-op.

Contributor (author): Will do in the follow-up MR.

Contributor (author): Edit: per-tensor amax also works now. I have modified the test case to correctly check that.

(A rough sketch of that per-channel follow-up appears after this file's diff.)

"""Sync amax across experts in a SequentialMLP."""
amax_dict = {}

def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module):
if (
isinstance(module, TensorQuantizer)
and hasattr(module, "_amax")
and ".local_experts." in name
):
expert_name, local_expert_name = name.split(".local_experts.")
# extract quantizer name by removing local_expert number from the name
local_expert_name = ".".join(local_expert_name.split(".")[1:])
return f"{expert_name}.{local_expert_name}"
return None

# gather amax values from SequentialMLP experts
for name, module in model.named_modules():
expert_name = get_sequential_mlp_expert_names(name, module)
if expert_name and module.amax is not None:
stored_amax = amax_dict.get(expert_name)
amax_tensor = module.amax.detach().clone()
amax_dict[expert_name] = (
amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor)
)

# sync amax values across experts in SequentialMLP
for name, module in model.named_modules():
expert_name = get_sequential_mlp_expert_names(name, module)
if expert_name and module.amax is not None:
module.amax = amax_dict[expert_name].detach().clone().to(module.amax.device)


CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp)


def real_quant_module_get_extra_state(self) -> dict:
"""Populating real_quantizer_state and q_tensor_state."""
extra_state = {}
@@ -221,16 +260,19 @@ class _MegatronParallelLinear(_ParallelLinear):
]

def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
logger.warning("Context parallel group is not initialized, using data parallel group")
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
)
if not hasattr(self, "parallel_state") or self.parallel_state is None:
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
logger.warning(
"Context parallel group is not initialized, using data parallel group"
)
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
)
super()._setup()

def _process_quantizer_amax(self, k, v, quantizer_state_dict):
@@ -472,3 +514,132 @@ class _RealQuantMegatronRowParallelLinear(

def forward(self, input, *args, **kwargs):
return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)


# Register the public te.pytorch.GroupedLinear class
@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
class _QuantMegatronTEGroupedLinear(_MegatronParallelLinear):
_functionals_to_replace = [
(te_grouped_linear._GroupedLinear, "forward"),
(te_grouped_linear._GroupedLinear, "apply"),
]

def _setup(self):
# GroupedMLP stores the weights as weight0, weight1, etc. Setup uses self.weight to
# extract shape, dtype, etc. when initializing the quantizer states, so temporarily
# alias self.weight0 to self.weight.
self.weight = self.weight0
# Memorize the original weight.dtype for modelopt_post_restore given that
# the dtype can change later.
super()._setup()
# Remove self.weight after setup.
delattr(self, "weight")

def modelopt_post_restore(self, prefix: str = ""):
# GroupedMLP stores the weights as weight0, weight1, etc. post_restore uses self.weight to
# extract shape, dtype, etc. when initializing the quantizer states, so temporarily
# alias self.weight0 to self.weight.
self.weight = self.weight0
super().modelopt_post_restore(prefix=prefix)
# Remove self.weight after post_restore.
delattr(self, "weight")

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx in [1, num_gemms] to the
# sharded_state_dict; these entries duplicate _extra_state and are only needed for TE FP8
# checkpoints, so they are filtered out here for ModelOpt checkpoint restore.
filtered_state_dict = {
k: v
for k, v in state_dict.items()
if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
}
return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)

def _process_quantizer_amax(self, k, v, quantizer_state_dict):
assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
quantizer_state_dict[k] = v.view(-1)

@staticmethod
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
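# Argument layout assumed here: for "_forward", args[0] is the autograd ctx, so the input
# tensor sits at args[idx]; args[idx + 1] is the per-GEMM split list (one entry per expert,
# hence num_gemms); the final 2 * num_gemms entries are the per-expert weights followed by
# the per-expert biases.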
idx = 1 if func_name == "_forward" else 0
inp = args[idx]
num_gemms = len(args[idx + 1])
weights_and_biases = args[-2 * num_gemms :]
weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
quantized_inputs = self.input_quantizer(inp)
quantized_weights = [self.weight_quantizer(weight) for weight in weights]

output = getattr(package, func_name)(
*(
args[0],
quantized_inputs,
)
if func_name == "_forward"
else (quantized_inputs,),
*args[idx + 1 : -2 * num_gemms],
*quantized_weights,
*biases,
)
return self.output_quantizer(output)

# Override the quantized linear function
_quantized_linear_fn = te_grouped_quantized_linear_fn


@QuantModuleRegistry.register(
{megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
)
class _MegatronTEGroupedColumnParallelLinear(
_QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear
):
_is_column_parallel = True


@QuantModuleRegistry.register(
{megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
)
class _MegatronTEGroupedRowParallelLinear(
_QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear
):
_is_row_parallel = True


# Register the public megatron_moe.TEGroupedMLP class
@QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
class _MegatronTEGroupedMLP(_MegatronMLP):
def _setup(self):
if not hasattr(self, "parallel_state") or self.parallel_state is None:
self.parallel_state = ParallelState(
mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
check_initialized=False
),
expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
check_initialized=False
),
)
# initialize parallel state for submodules linear_fc1 and linear_fc2
self.linear_fc1.parallel_state = self.parallel_state
self.linear_fc2.parallel_state = self.parallel_state


# Register the public megatron_moe.SequentialMLP class
@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
class _MegatronSequentialMLP(_MegatronMLP):
def _setup(self):
if not hasattr(self, "parallel_state") or self.parallel_state is None:
self.parallel_state = ParallelState(
mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
check_initialized=False
),
expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
check_initialized=False
),
)

# Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2
for expert in self.local_experts:
expert.linear_fc1.parallel_state = self.parallel_state
expert.linear_fc2.parallel_state = self.parallel_state
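
Relating to the review thread on sync_amax_across_sequential_mlp above, here is a rough sketch of the per-channel follow-up discussed there; it is not part of this MR. It takes an element-wise maximum of per-channel weight amax across the experts' RowParallel fc2 quantizers and leaves ColumnParallel fc1 untouched. The fc2 name filter and the helper names are assumptions, and the code reuses torch and TensorQuantizer as imported in this module.

def _per_channel_fc2_amax_key(name: str, module: torch.nn.Module) -> str | None:
    # Assumption: only RowParallel fc2 weight quantizers with per-channel amax qualify.
    if (
        isinstance(module, TensorQuantizer)
        and getattr(module, "_amax", None) is not None
        and module.amax.numel() > 1
        and ".local_experts." in name
        and ".linear_fc2." in name
    ):
        expert_name, local_expert_name = name.split(".local_experts.")
        # drop the local expert index so all experts map to the same key
        return f"{expert_name}." + ".".join(local_expert_name.split(".")[1:])
    return None


def sync_per_channel_fc2_amax_across_sequential_mlp(model: torch.nn.Module):
    """Element-wise max of per-channel fc2 weight amax across a SequentialMLP's experts (sketch)."""
    amax_dict: dict[str, torch.Tensor] = {}
    for name, module in model.named_modules():
        key = _per_channel_fc2_amax_key(name, module)
        if key is not None:
            stored = amax_dict.get(key)
            current = module.amax.detach().clone()
            amax_dict[key] = current if stored is None else torch.maximum(stored, current)
    for name, module in model.named_modules():
        key = _per_channel_fc2_amax_key(name, module)
        if key is not None:
            module.amax = amax_dict[key].detach().clone().to(module.amax.device)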
6 changes: 4 additions & 2 deletions modelopt/torch/quantization/utils.py
@@ -251,8 +251,10 @@ def is_quantized_linear(module):
isinstance(module, QuantModule)
and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer)
and hasattr(module, "weight_quantizer")
and getattr(module, "weight", None) is not None
and module.weight.dim() == 2
and (
(getattr(module, "weight", None) is not None and module.weight.dim() == 2)
or (getattr(module, "weight0", None) is not None and module.weight0.dim() == 2)
)
)


6 changes: 5 additions & 1 deletion modelopt/torch/utils/distributed.py
@@ -241,16 +241,20 @@ def __init__(
self,
data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
):
"""Initialize the parallel state."""
self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group)

def __repr__(self) -> str:
return (
parallel_groups = (
f"data_parallel_group: {self.data_parallel_group}, "
f"tensor_parallel_group: {self.tensor_parallel_group}, "
f"expert_model_parallel_group: {self.expert_model_parallel_group}"
)
return parallel_groups


def get_group(ranks: list[int]):
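A minimal usage sketch of the extended ParallelState, mirroring how the Megatron MoE plugins above construct it; it assumes Megatron's parallel state has already been initialized (e.g. via initialize_model_parallel), and the import paths follow the file locations in this diff.

import megatron.core.parallel_state as mcore_parallel

from modelopt.torch.utils.distributed import ParallelState

state = ParallelState(
    mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
    tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(check_initialized=False),
    expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(check_initialized=False),
)
print(state)  # data_parallel_group: ..., tensor_parallel_group: ..., expert_model_parallel_group: ...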