From e39204de120617830eea28842104ef936c1987a4 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 10 Sep 2025 20:53:01 +0000 Subject: [PATCH 01/15] Fixed the CICD for Diffusion Signed-off-by: jingyu Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/diffusion_trt.py | 6 +++++- examples/diffusers/quantization/quantize.py | 8 ++++++-- examples/diffusers/quantization/requirements.txt | 1 + 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/diffusers/quantization/diffusion_trt.py b/examples/diffusers/quantization/diffusion_trt.py index 4db12e9c..1a0ec852 100644 --- a/examples/diffusers/quantization/diffusion_trt.py +++ b/examples/diffusers/quantization/diffusion_trt.py @@ -105,7 +105,11 @@ def main(): image_name = args.save_image_as if args.save_image_as else f"{args.model}.png" - pipe = PipelineManager.create_pipeline_from(MODEL_ID[args.model], dtype_map[args.model_dtype]) + pipe = PipelineManager.create_pipeline_from( + MODEL_ID[args.model], + dtype_map[args.model_dtype], + override_model_path=args.override_model_path, + ) # Save the backbone of the pipeline and move it to the GPU add_embedding = None diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 81c59392..f94a4a1a 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -309,7 +309,9 @@ def __init__(self, config: ModelConfig, logger: logging.Logger): @staticmethod def create_pipeline_from( - model_type: ModelType, torch_dtype: torch.dtype = torch.bfloat16 + model_type: ModelType, + torch_dtype: torch.dtype = torch.bfloat16, + override_model_path: str | None = None, ) -> DiffusionPipeline: """ Create and return an appropriate pipeline based on configuration. 
@@ -321,7 +323,9 @@ def create_pipeline_from( ValueError: If model type is unsupported """ try: - model_id = MODEL_REGISTRY[model_type] + model_id = ( + MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path + ) if model_type == ModelType.SD3_MEDIUM: pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype) elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: diff --git a/examples/diffusers/quantization/requirements.txt b/examples/diffusers/quantization/requirements.txt index 9c9a60b8..35d2ca4a 100644 --- a/examples/diffusers/quantization/requirements.txt +++ b/examples/diffusers/quantization/requirements.txt @@ -1,4 +1,5 @@ cuda-python +diffusers==0.34.0 nvtx onnx_graphsurgeon opencv-python>=4.8.1.78,<4.12.0.88 From 43ed09f5896288f91dc6f1d11f696771575ed59a Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 10 Sep 2025 21:01:13 +0000 Subject: [PATCH 02/15] Update req for diffusers Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/requirements.txt b/examples/diffusers/quantization/requirements.txt index 35d2ca4a..52921fe7 100644 --- a/examples/diffusers/quantization/requirements.txt +++ b/examples/diffusers/quantization/requirements.txt @@ -1,5 +1,5 @@ cuda-python -diffusers==0.34.0 +diffusers<=0.34.0 nvtx onnx_graphsurgeon opencv-python>=4.8.1.78,<4.12.0.88 From ececc24c229f2b7bc4ad8f2d8c18a2ad6cc8900c Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 15 Sep 2025 16:37:38 +0000 Subject: [PATCH 03/15] Add megatron lora support Signed-off-by: Jingyu Xin --- modelopt/torch/peft/__init__.py | 23 ++ modelopt/torch/peft/config.py | 434 ++++++++++++++++++++++ modelopt/torch/peft/conversion.py | 108 ++++++ modelopt/torch/peft/convert.py | 80 ++++ modelopt/torch/peft/lora/__init__.py | 3 + modelopt/torch/peft/lora/layer.py | 135 +++++++ modelopt/torch/peft/lora/tp_layer.py | 155 ++++++++ modelopt/torch/peft/mode.py | 73 ++++ modelopt/torch/quantization/conversion.py | 2 +- test.py | 189 ++++++++++ 10 files changed, 1201 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/peft/__init__.py create mode 100644 modelopt/torch/peft/config.py create mode 100644 modelopt/torch/peft/conversion.py create mode 100644 modelopt/torch/peft/convert.py create mode 100644 modelopt/torch/peft/lora/__init__.py create mode 100644 modelopt/torch/peft/lora/layer.py create mode 100644 modelopt/torch/peft/lora/tp_layer.py create mode 100644 modelopt/torch/peft/mode.py create mode 100644 test.py diff --git a/modelopt/torch/peft/__init__.py b/modelopt/torch/peft/__init__.py new file mode 100644 index 00000000..e2b8a0e9 --- /dev/null +++ b/modelopt/torch/peft/__init__.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""PEFT API subpackage for torch."""
+
+from . import mode
+from .config import *
+from .convert import *
+# isort: off
+# Import plugins last to avoid circular imports
+# from . import plugins
diff --git a/modelopt/torch/peft/config.py b/modelopt/torch/peft/config.py
new file mode 100644
index 00000000..cd007156
--- /dev/null
+++ b/modelopt/torch/peft/config.py
@@ -0,0 +1,434 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from typing import Literal
+
+from pydantic import ValidationInfo, field_validator, model_validator
+
+from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
+from modelopt.torch.utils.network import ConstructorLike
+BiasType = Literal["static", "dynamic"]
+BiasMethod = Literal["mean", "max_min"]
+
+class QuantizerAttributeConfig(ModeloptBaseConfig):
+    """Quantizer attribute type."""
+
+    enable: bool = ModeloptField(
+        default=True,
+        title="Enable quantizer.",
+        description="""If True, enables the quantizer. If False, bypasses the quantizer and returns the input tensor.""",
+    )
+
+    num_bits: int | tuple[int, int] = ModeloptField(
+        default=8,
+        title="An integer or a tuple of two integers specifying the number of quantization bits.",
+        description="""`num_bits` can be:
+
+        #. A positive integer argument for integer quantization. `num_bits` specifies
+        the number of bits used for integer quantization.
+
+        #. Constant integer tuple (E,M) for floating point quantization emulating
+        Nvidia's FPx quantization. E is the number of exponent bits and M is the number
+        of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).""",
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values):
+        """Validate quantizer config."""
+
+        def _validate_recursive(value):
+            """Recursively validate config structure."""
+            if value is None:
+                return
+
+            if isinstance(value, list):
+                for item in value:
+                    _validate_recursive(item)
+            elif isinstance(value, dict):
+                if len(value) == 1 and "enable" in value and value["enable"] is True:
+                    raise ValueError(
+                        "Invalid quantizer config: Cannot specify only {'enable': True}. "
+                        "Additional parameters are required when enabling quantization."
+ ) + # Recurse into nested dicts + for v in value.values(): + _validate_recursive(v) + + _validate_recursive(values) + return values + + @model_validator(mode="after") + def validate_num_bits(self): + """Validate `num_bits`.""" + num_bits = self.num_bits + + if isinstance(num_bits, int) and num_bits < 1: + raise ValueError("num_bits must be a positive integer or a tuple of positive integers.") + + if not isinstance(num_bits, tuple): + return self + + if not all(x > 0 for x in num_bits): + raise ValueError("num_bits must be a positive integer or a tuple of positive integers.") + + block_sizes = self.block_sizes + if num_bits not in [ + (4, 3), + (5, 2), + (2, 1), + (1, 2), + (0, 3), + (3, 0), + (3, 2), + (2, 3), + ]: + raise ValueError( + "Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)." + ) + elif num_bits != (4, 3) and ( + block_sizes is None or block_sizes.get("type", None) != "dynamic" + ): + raise ValueError( + "Only blockwise dynamic quantization is supported with quantization " + "formats E{num_bis[0]}M{num_bits[1]}." + ) + return self + + axis: int | tuple[int, ...] | None = ModeloptField( + default=None, + title="None, integer or a tuple of integers specifying the axis to quantize.", + description="""This field is for static per-channel quantization. *It cannot coexist with `block_sizes`*. + You should set axis if you want a fixed shape of scale factor. + + For example, if axis is set to None, the scale factor will be a scalar (per-tensor quantization) + if the axis is set to 0, the scale factor will be a vector of shape (dim0, ) (per-channel quantization). + if the axis is set to (-2, -1), the scale factor will be a vector of shape (dim-2, dim-1) + + axis value must be in the range [-rank(input_tensor), rank(input_tensor)) + """, + ) + + fake_quant: bool = ModeloptField( + default=True, + title="Enable fake quantization.", + description="""If True, enable fake quantization.""", + ) + + unsigned: bool = ModeloptField( + default=False, + title="Enable unsigned quantization.", + description="""If True, enable unsigned quantization. Used only for integer quantization.""", + ) + + narrow_range: bool = ModeloptField( + default=False, + title="Enable narrow range quantization.", + description="""If True, enable narrow range quantization. Used only for integer quantization.""", + ) + + learn_amax: bool = ModeloptField( + default=False, + title="Enable learning amax.", + description="""``learn_amax`` is deprecated and reserved for backward compatibility.""", + ) + + @field_validator("learn_amax") + @classmethod + def validate_learn_amax(cls, v): + """Validate learn_amax.""" + assert v is not True, "learn_amax is deprecated and reserved for backward compatibility." + return v + + type: str = ModeloptField( + default="static", + title="""Specify whether the quantization is static or dynamic.""", + description="""The value is a string from ``["static", "dynamic"]``. + If ``"dynamic"``, dynamic quantization will be enabled which does not collect any statistics during + calibration.""", + pattern=r"^static$|^dynamic$", + ) + + block_sizes: dict[int | str, int | tuple[int, int] | str | dict[int, int] | None] | None = ( + ModeloptField( + default=None, + title="Optional dictionary specifying block quantization parameters.", + description="""This field is for static or dynamic block quantization. *It cannot coexist with ``axis``*. + You should set block_sizes if you want fixed number of elements to share every scale factor. 
+ + The keys are the axes for block quantization and the + values are block sizes for quantization along the respective axes. Keys must be in the + range ``[-tensor.dim(), tensor.dim())``. Values, which are the block sizes for quantization must be + positive integers or ``None``. A positive block size specifies the block size for quantization along that + axis. ``None`` means that the block size will be the maximum possible size in that dimension - this is + useful for specifying certain quantization formats such per-token dynamic quantization which has the `amax` + shared along the last dimension. + + In addition, there can be special string keys ``"type"``, ``"scale_bits"`` and ``"scale_block_sizes"``. + + Key ``"type"`` should map to ``"dynamic"`` or ``"static"`` where ``"dynamic"`` + indicates dynamic block quantization and "static" + indicates static calibrated block quantization. By default, the type is ``"static"``. + + Key ``"scale_bits"`` specify the quantization bits for the per-block quantization scale factor + (i.e a double quantization scheme). + + Key ``"scale_block_sizes"`` specify the block size for double quantization. + By default per-block quantization scale is not quantized. + + For example, ``block_sizes = {-1: 32}`` will quantize the last axis of the input tensor in + blocks of size 32 with static calibration, with a total of ``numel(tensor) / 32`` scale factors. + ``block_sizes = {-1: 32, "type": "dynamic"}`` will perform dynamic block quantization. + ``block_sizes = {-1: None, "type": "dynamic"}`` can be used to + specify per-token dynamic quantization. + """, + ) + ) + + bias: dict[int | str, BiasType | BiasMethod | tuple[int, ...] | bool | int | None] | None = ( + ModeloptField( + default=None, + title="Bias configuration.", + description="""Configuration for bias handling in affine quantization. The keys are: + - "enable": Boolean to enable/disable bias handling, default is False + - "type": Specify the type of bias ["static", "dynamic"], default is "static" + - "method": Specify the method of bias calibration ["mean", "max_min"], default is "mean" + - "axis": Tuple of integers specifying axes for bias computation, default is None + + Examples: + bias = {"enable": True} + bias = {"enable": True, "type": "static", "axis": -1} + bias = {"enable": True, "type": "dynamic", "axis": (-1, -3)} + """, + ) + ) + + @staticmethod + def _get_block_quant_axes_and_sizes(block_sizes): + if block_sizes is None: + return None + return { + k: v + for k, v in block_sizes.items() + if k not in ["type", "scale_bits", "scale_block_sizes"] + } + + @field_validator("block_sizes") + @classmethod + def validate_block_sizes(cls, v, info: ValidationInfo): + """Validate block sizes.""" + if v is None: + return v + assert info.data["axis"] is None, "axis must be None when block_sizes is not None." + if v.get("type", None) == "dynamic": + assert len(cls._get_block_quant_axes_and_sizes(v)) == 1, ( + "Dynamic block quantization only supports quantization last axis." 
+ ) + for _k, _v in v.items(): + if isinstance(_k, str): + assert _k in ["type", "scale_bits", "scale_block_sizes"] + else: + assert isinstance(_k, int) and (_v is None or isinstance(_v, int)) + return v + + @field_validator("bias") + @classmethod + def validate_bias(cls, v): + """Validate bias.""" + if v is None: + return v + + if "type" in v and v["type"] not in ["static", "dynamic"]: + raise ValueError(f"Invalid bias type: {v['type']}, expected 'static' or 'dynamic'") + + if "method" in v and v["method"] not in ["mean", "max_min"]: + raise ValueError(f"Invalid bias method: {v['method']}, expected 'mean' or 'max_min'") + + axis = [k for k in v.keys() if k not in ["type", "method"]] # noqa: SIM118 + assert len(axis) > 0, "The axis for bias computation is not specified." + for x in axis: + if not isinstance(x, int): + raise ValueError(f"Invalid axis type {type(axis)}, expected int") + + return v + + trt_high_precision_dtype: str = ModeloptField( + default="Float", + title="TRT StronglyType requires all weights and amax to be in the same dtype.", + description="""The value is a string from ``["Float", "Half", "BFloat16"]``. + The QDQs will be assigned the appropriate data type, and this variable will only be + used when the user is exporting the quantized ONNX model.""", + pattern=r"^Float$|^Half$|^BFloat16$", + ) + + calibrator: str | ConstructorLike = ModeloptField( + default="max", + title="""Specify the calibrator to use.""", + description="""The calibrator can be a string from ``["max", "histogram"]`` or a constructor + to create a calibrator which subclasses :class:`_Calibrator `. + See :meth:`standardize_constructor_args ` + for more information on how to specify the constructor.""", + ) + + @field_validator("calibrator") + @classmethod + def validate_calibrator(cls, v, info: ValidationInfo): + """Validate calibrator.""" + if isinstance(v, str): + assert v in ["max", "histogram"] + return v + + rotate: bool = ModeloptField( + default=False, + title="""If rotate the input before quantization.""", + description=""""If true, the input of the quantizer will be rotated with a hadamard matrix + given by scipy.linalg.hadamard, i.e. + ``input = input @ scipy.linalg.hadamard(input.shape[-1]) / sqrt(input.shape[-1])``. + + This can be used for ratation based PTQ methods, e.g. QuaRot or SpinQuant. + See https://arxiv.org/abs/2404.00456 for example.""", + ) + + pass_through_bwd: bool = ModeloptField( + default=False, + title="If set to true, fake quantization will be a pass through for gradient computation.", + description=""" + Gradient computation where fake quantization is pass through is called + 'Straight-Through Estimator (STE)'. STE does not require saving of the input tensor for + performing backward pass and hence consumes less memory. + + If set to False, we will use STE with zeroed outlier gradients. This setting could + yield better QAT accuracy depending on the quantization format. However, this setting + requires saving of the input tensor for computing gradients which uses more memory. + + For dynamic quantization formats like MXFP4, STE with zeroed outlier gradients + is not needed since fake quantization with dynamic amax results in minimal/no clipping. + """, + ) + + +class QuantizeAlgorithmConfig(ModeloptBaseConfig): + """Calibration algorithm config base.""" + + method: Literal[None] = ModeloptField( + None, + title="This field specifies the name of the calibration algorithm. 
If None, no calibration is performed.", + ) + + +class SVDQuantConfig(QuantizeAlgorithmConfig): + """The config for SVDQuant. + + Refer to the `SVDQuant paper `_ for more details. + """ + + method: Literal["svdquant"] = ModeloptField("svdquant") + + lowrank: int | None = ModeloptField( + default=32, + title="Low-rank dimension for the SVD LoRA", + description=( + "Specifies the rank of the LoRA used in the SVDQuant method, " + "which captures outliers from the original weights." + ), + ) + + +# QuantizeQuantCfgType = dict[ +# str | Callable, +# QuantizerAttributeConfig +# | list[QuantizerAttributeConfig] +# | dict[str | Callable, QuantizerAttributeConfig | list[QuantizerAttributeConfig]], +# ] + +# _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None + +# QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None + +#TODO Jingyu Xin +class PEFTConfig(ModeloptBaseConfig): + """Default configuration for ``quantize`` mode.""" + + adapter_name: str = ModeloptField( + default="default", + title="Placeholder", + validate_default=True, + ) + + adapter_cfg: dict = ModeloptField( + default={"default": {"rank": 128}}, + title="Placeholder", + validate_default=True, + ) + + adapter_type: str = ModeloptField( + default="lora", + title="Placeholder", + validate_default=True, + ) + +class ExportPEFTConfig(ModeloptBaseConfig): + """An empty config.""" +class CompressConfig(ModeloptBaseConfig): + """Default configuration for ``compress`` mode.""" + + compress: dict[str, bool] = ModeloptField( + default={"*": True}, + title="""Enable weight compression for the given pattern. Default is False for all weights. + Call `compress` function to compress the model weights.""", + ) + + quant_gemm: bool = ModeloptField( + default=True, + title="Enable quantized GEMM.", + description="If True, quantized GEMM compute will be enabled. Otherwise, we only do weight-only quantization.", + ) + + +CompressCfgType = dict[str, bool] | None | CompressConfig + + +class _QuantizeExportConfig(ModeloptBaseConfig): + """An empty config.""" + + +def need_calibration(config): + """Check if calibration is needed for the given config.""" + if config["algorithm"] is not None and config["algorithm"] != "max": + return True + + def _not_dynamic(cfg): + return ( + cfg.get("enable", True) + and cfg.get("type", "") != "dynamic" + and cfg.get("*", {}).get("enable", True) + ) + + for name, cfg in config.get("quant_cfg", {}).items(): + if "weight_quantizer" in name: + # We don't calibrate weight quantizer + continue + # quantization like W4A8 has a list of weight quantizers + if isinstance(cfg, list): + for _config in cfg: + if _not_dynamic(_config): + print(f"{cfg}: True") + return True + elif _not_dynamic(cfg): + print(f"{cfg}: True") + return True + + return False diff --git a/modelopt/torch/peft/conversion.py b/modelopt/torch/peft/conversion.py new file mode 100644 index 00000000..5081c0c2 --- /dev/null +++ b/modelopt/torch/peft/conversion.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantization conversion/restore utilities.""" + +import fnmatch +from collections.abc import Callable +from contextlib import contextmanager +from typing import Any + +import torch.nn as nn + +from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule, ModeloptStateManager +from modelopt.torch.opt.dynamic import _DMRegistryCls +from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict +from modelopt.torch.utils import get_unwrapped_name + +from .config import ( + PEFTConfig, + _QuantizeExportConfig, +) +from .lora.layer import LoRAModuleRegistry + +__all__ = [ + "replace_lora_module", +] + + +def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> ConvertReturnType: + """Convert the model to a quantized one as per `config`.""" + # initialize the true module if necessary + model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + + # TODO: Replace to LoRA module + replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config) + # set_quantizer_by_cfg(model, config.get("quant_cfg", {})) + + metadata = {} + # update_quantize_metadata(model, config, metadata) + + return model, metadata + +def restore_peft_model( + model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict +) -> nn.Module: + #TODO: implemente the restore logic + pass + + + +def update_peft_metadata( + model: nn.Module, config: PEFTConfig, metadata: MetadataDict +) -> None: + """Update the quantizer state in the metadata dict.""" + pass + + +def replace_lora_module(model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry): + """Recursively replace the module with quantized module.""" + #TODO: register the extra state for megatron-lm + + if type(model) in registry: + model = registry.convert(model) + _replace_lora_module(model, version=version, registry=registry) + +def export_peft_model(model: nn.Module, config): + """Export the quantized model to a quantized model.""" + raise NotImplementedError("Exporting a quantized model is not supported yet.") + + +def restore_export_peft_model( + model: nn.Module, config, metadata: MetadataDict +): + """Restores the quantized model from the given state dict.""" + raise NotImplementedError("Restoring a quantized & exported model is not supported yet.") + + +def _replace_lora_module(model: nn.Module, version=None,registry=LoRAModuleRegistry): + for name, child in model.named_children(): + if type(child) in registry: + lora_module = registry.convert(child) + setattr(model, name, lora_module) + + _replace_lora_module(getattr(model, name), version=version, registry=registry) + + +def export_quantized_model(model: nn.Module, config: _QuantizeExportConfig) -> ConvertReturnType: + """Export the quantized model to a quantized model.""" + raise NotImplementedError("Exporting a quantized model is not supported yet.") + + +def restore_export_quantized_model( + model: nn.Module, config: _QuantizeExportConfig, metadata: MetadataDict +) -> nn.Module: + """Restores the quantized model from the given state dict.""" + raise NotImplementedError("Restoring a 
quantized & exported model is not supported yet.") diff --git a/modelopt/torch/peft/convert.py b/modelopt/torch/peft/convert.py new file mode 100644 index 00000000..e3b98461 --- /dev/null +++ b/modelopt/torch/peft/convert.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""User-facing quantization API.""" + +import fnmatch +import inspect +import warnings +from collections.abc import Callable, Iterable +from typing import Any + +import torch +import torch.nn as nn + +# import modelopt.torch.quantization as mtq +from modelopt.torch.opt import apply_mode +# from modelopt.torch.opt.searcher import ForwardLoop +# from modelopt.torch.opt.utils import forward_with_reshard +from modelopt.torch.peft.config import PEFTConfig +# from modelopt.torch.quantization.conversion import set_quantizer_by_cfg + +# from . import config +# from .algorithms import AutoQuantizeSearcher +# from .config import QuantizeAlgoCfgType +# from .conversion import set_quantizer_attribute +from .mode import PEFTModeRegistry +from .lora.layer import LoRAModule +# from .nn import QuantModule, TensorQuantizer + +# __all__ = [ +# "auto_quantize", +# "calibrate", +# "disable_quantizer", +# "enable_quantizer", +# "fold_weight", +# "postprocess_amax", +# "print_quant_summary", +# "quantize", +# ] + +def update_model( + model: nn.Module, + config: dict[str, Any | PEFTConfig], +): + #TODO: deal with extra state, how to save the model + #TODO: sharded dict + #TODO: metadate + #TODO: how to restore the model + apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry) + return add_adapter(model, config) + +def add_adapter(model, config): + adapter_cfg = config["adapter_cfg"] + adapter_name = config["adapter_name"] + + for name, module in model.named_modules(): + if isinstance(module, LoRAModule): + for wildcard_or_filter_func, adapter_setting in adapter_cfg.items(): + if isinstance(wildcard_or_filter_func, str): + if not fnmatch.fnmatch(name, wildcard_or_filter_func): + continue + elif callable(wildcard_or_filter_func): + if not wildcard_or_filter_func(name): + continue + else: + raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") + module.update_layer_lora(adapter_name, adapter_setting["rank"]) + return model \ No newline at end of file diff --git a/modelopt/torch/peft/lora/__init__.py b/modelopt/torch/peft/lora/__init__.py new file mode 100644 index 00000000..2523392a --- /dev/null +++ b/modelopt/torch/peft/lora/__init__.py @@ -0,0 +1,3 @@ +from . import layer +from . import tp_layer +# from . 
import linear_layer \ No newline at end of file diff --git a/modelopt/torch/peft/lora/layer.py b/modelopt/torch/peft/lora/layer.py new file mode 100644 index 00000000..84676491 --- /dev/null +++ b/modelopt/torch/peft/lora/layer.py @@ -0,0 +1,135 @@ +"""LoRA (Low-Rank Adaptation) module implementation.""" + +from abc import abstractmethod +from typing import Dict, Tuple, Any, Optional +import torch +import torch.nn as nn + +from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls + +__all__ = [ + "LoRAModule", + "LoRAModuleRegistry", +] + + +class LoRAModule(DynamicModule): + """Base class for LoRA (Low-Rank Adaptation) modules. + + This module wraps existing layers and adds trainable low-rank decomposition + matrices (LoRA adapters) that are added to the original layer's output. + + Attributes: + _lora_adapters: Dictionary mapping adapter names to their LoRA A and B matrices + _active_adapters: Set of currently active adapter names + """ + + def _setup(self) -> None: + """Initialize LoRA-specific attributes.""" + self._lora_adapters: Dict[str, Dict[str, nn.Module]] = {} + self._active_adapters: set = set() + + @property + def adapter_names(self) -> set: + """Return the set of all registered adapter names.""" + return set(self._lora_adapters.keys()) + + @property + def active_adapters(self) -> set: + """Return the set of currently active adapter names.""" + return self._active_adapters.copy() + + def activate_adapter(self, adapter_name: str) -> None: + """Activate a specific adapter. + + Args: + adapter_name: Name of the adapter to activate + + Raises: + ValueError: If adapter_name is not registered + """ + if adapter_name not in self._lora_adapters: + raise ValueError(f"Adapter '{adapter_name}' not found. Available: {list(self._lora_adapters.keys())}") + self._active_adapters.add(adapter_name) + + def deactivate_adapter(self, adapter_name: str) -> None: + """Deactivate a specific adapter. + + Args: + adapter_name: Name of the adapter to deactivate + """ + self._active_adapters.discard(adapter_name) + + def activate_all_adapters(self) -> None: + """Activate all registered adapters.""" + self._active_adapters = self.adapter_names.copy() + + def deactivate_all_adapters(self) -> None: + """Deactivate all adapters.""" + self._active_adapters.clear() + + @abstractmethod + def update_layer_lora(self, adapter_name: str, rank: int = 64) -> None: + """Create and register a new LoRA adapter. + + This method must be implemented by subclasses to create the appropriate + LoRA A and B matrices for the specific layer type. + + Args: + adapter_name: Name for the new adapter + rank: Rank of the LoRA decomposition (default: 64) + """ + raise NotImplementedError("Subclasses must implement update_layer_lora") + + def forward(self, x: torch.Tensor, *args, **kwargs) -> Any: + """Forward pass with LoRA adaptation. 
+ + Args: + x: Input tensor + *args: Additional positional arguments for the base layer + **kwargs: Additional keyword arguments for the base layer + + Returns: + Output from the base layer plus active LoRA adaptations + """ + # Call the base layer's forward method + output = super().forward(x, *args, **kwargs) + + # Handle different output types from base layer + if isinstance(output, tuple): + # If output is a tuple, assume first element is the main result + result = output[0] + other_outputs = output[1:] + else: + # If output is a single tensor + result = output + other_outputs = () + + # Apply active LoRA adapters + if self._active_adapters and self._lora_adapters: + for adapter_name in self._active_adapters: + if adapter_name in self._lora_adapters: + adapter = self._lora_adapters[adapter_name] + # LoRA computation: result = result + B(A(x)) + lora_a = adapter['lora_a'] + lora_b = adapter['lora_b'] + + # Handle different forward signatures + lora_a_output = lora_a(x) + if isinstance(lora_a_output, tuple): + lora_a_output = lora_a_output[0] + + lora_b_output = lora_b(lora_a_output) + if isinstance(lora_b_output, tuple): + lora_b_output = lora_b_output[0] + + result = result + lora_b_output + + # Return output in the same format as the base layer + if other_outputs: + return (result,) + other_outputs + else: + return result + + +LoRAModuleRegistry = _DMRegistryCls("LoRA", LoRAModule) \ No newline at end of file diff --git a/modelopt/torch/peft/lora/tp_layer.py b/modelopt/torch/peft/lora/tp_layer.py new file mode 100644 index 00000000..9e2ce7f2 --- /dev/null +++ b/modelopt/torch/peft/lora/tp_layer.py @@ -0,0 +1,155 @@ +"""Tensor Parallel LoRA implementations for Megatron layers.""" + +import math +from typing import Optional, Callable + +import torch +import torch.nn as nn +import torch.nn.init as init + +from megatron.core.tensor_parallel.layers import RowParallelLinear, ColumnParallelLinear + +from .layer import LoRAModuleRegistry, LoRAModule + + +# Default rank for LoRA decomposition +DEFAULT_LORA_RANK = 64 + + +class _MegatronParallelLoRABase(LoRAModule): + """Base class for Megatron tensor parallel LoRA implementations. + + This class provides common functionality for both ColumnParallel and RowParallel + LoRA implementations, reducing code duplication. + """ + + def _get_init_methods(self) -> tuple[Callable, Callable]: + """Get initialization methods for LoRA A and B matrices. + + Returns: + Tuple of (lora_a_init, lora_b_init) initialization functions + """ + # LoRA A uses Kaiming uniform initialization + lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5)) + # LoRA B is initialized to zero for stable training start + lora_b_init = lambda weight: init.zeros_(weight) + return lora_a_init, lora_b_init + + def _register_adapter(self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module) -> None: + """Register LoRA adapter modules. 
+ + Args: + adapter_name: Name of the adapter + lora_a: LoRA A module (down-projection) + lora_b: LoRA B module (up-projection) + """ + # Move LoRA modules to the same device as the parent module + # Try to get device from parent module's parameters or buffers + device = None + for p in self.parameters(): + device = p.device + break + if device is None: + for b in self.buffers(): + device = b.device + break + + # If we found a device, move LoRA modules to it + if device is not None: + lora_a = lora_a.to(device) + lora_b = lora_b.to(device) + + # Add as submodules for proper parameter registration + self.add_module(f'lora_a_{adapter_name}', lora_a) + self.add_module(f'lora_b_{adapter_name}', lora_b) + + # Store in adapter dictionary + self._lora_adapters[adapter_name] = { + "lora_a": lora_a, + "lora_b": lora_b + } + + # Automatically activate new adapters + self.activate_adapter(adapter_name) + + +@LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"}) +class _MegatronColumnParallelLinear(_MegatronParallelLoRABase): + """LoRA implementation for Megatron ColumnParallelLinear layers. + + This implementation creates column-parallel LoRA adapters that match + the parallelization scheme of the base layer. + """ + + def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> None: + """Create and register a new LoRA adapter for ColumnParallelLinear. + + Args: + adapter_name: Name for the new adapter + rank: Rank of the LoRA decomposition + """ + lora_a_init, lora_b_init = self._get_init_methods() + + # Create LoRA A: input_size -> rank (with gather for full reduction) + lora_a = ColumnParallelLinear( + self.input_size, + rank, + config=self.config, + bias=False, + gather_output=True, # Gather outputs for complete transformation + init_method=lora_a_init, + disable_grad_reduce=getattr(self.config, 'sequence_parallel', False), + ) + + # Create LoRA B: rank -> output_size (no gather, stays distributed) + lora_b = ColumnParallelLinear( + rank, + self.output_size, + config=self.config, + bias=False, + gather_output=False, # Keep output distributed like base layer + init_method=lora_b_init, + ) + + self._register_adapter(adapter_name, lora_a, lora_b) + + +@LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"}) +class _MegatronRowParallelLinear(_MegatronParallelLoRABase): + """LoRA implementation for Megatron RowParallelLinear layers. + + This implementation creates row-parallel LoRA adapters that match + the parallelization scheme of the base layer. + """ + + def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> None: + """Create and register a new LoRA adapter for RowParallelLinear. 
+ + Args: + adapter_name: Name for the new adapter + rank: Rank of the LoRA decomposition + """ + lora_a_init, lora_b_init = self._get_init_methods() + + # Create LoRA A: input_size -> rank (row parallel, input already distributed) + lora_a = RowParallelLinear( + self.input_size, + rank, + config=self.config, + input_is_parallel=True, # Input is already distributed + skip_bias_add=True, + bias=False, + init_method=lora_a_init, + ) + + # Create LoRA B: rank -> output_size (column parallel with gather) + lora_b = ColumnParallelLinear( + rank, + self.output_size, + config=self.config, + bias=False, + gather_output=True, # Gather to match base layer output + init_method=lora_b_init, + ) + + self._register_adapter(adapter_name, lora_a, lora_b) \ No newline at end of file diff --git a/modelopt/torch/peft/mode.py b/modelopt/torch/peft/mode.py new file mode 100644 index 00000000..8aefb211 --- /dev/null +++ b/modelopt/torch/peft/mode.py @@ -0,0 +1,73 @@ +from modelopt.torch.opt.config import ModeloptBaseConfig +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ConvertReturnType, + ModeConfigList, + ModeDescriptor, + RestoreEntrypoint, + UpdateEntrypoint, + _ModeRegistryCls, +) +from .config import PEFTConfig, ExportPEFTConfig +from .conversion import convert_to_peft_model, restore_peft_model, update_peft_metadata, export_peft_model, restore_export_peft_model + +PEFTModeRegistry = _ModeRegistryCls("PEFT") + +@PEFTModeRegistry.register_mode +class PEFTModeDescriptor(ModeDescriptor): + @property + def name(self) -> str: + return "peft" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + return PEFTConfig + + @property + def export_mode(self) -> str | None: + return "export_peft" + + @property + def convert(self) -> ConvertEntrypoint: + return convert_to_peft_model + + @property + def restore(self) -> RestoreEntrypoint: + return restore_peft_model + + @property + def update_for_save(self) -> UpdateEntrypoint: + return update_peft_metadata + + @property + def update_for_new_mode(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the models state before new mode.""" + return update_peft_metadata + +@PEFTModeRegistry.register_mode +class ExportPEFTModeDescriptor(ModeDescriptor): + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "export_peft" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return ExportPEFTConfig + + @property + def is_export_mode(self) -> bool: + """Specifies whether the mode is an export mode.""" + return True + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return export_peft_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_export_peft_model \ No newline at end of file diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 3c520877..3e37c2e0 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -176,7 +176,7 @@ def replace_quant_module(model: nn.Module, version=None, registry=QuantModuleReg """Recursively replace the module with quantized module.""" from .plugins.custom import ( register_custom_model_plugins_on_the_fly, - register_custom_post_conversion_plugins, + register_custom_post_conversion_plugins, ## not needed for lora megatron ) assert not is_quantized(model), 
"Model must not be quantized!" diff --git a/test.py b/test.py new file mode 100644 index 00000000..6b5188b2 --- /dev/null +++ b/test.py @@ -0,0 +1,189 @@ +# dummy_megatron_model.py +import os +import torch +import torch.nn.init as init +from megatron.core import parallel_state, tensor_parallel +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig + +import modelopt.torch.peft as mtp +import modelopt.torch.quantization as mtq + + +class DummyMegatronModel(MegatronModule): + """ + A simple dummy Megatron model with parallel linear layers for testing. + """ + def __init__(self, config: TransformerConfig): + super().__init__(config) + + # Column parallel linear layer (splits output dimension) + self.linear_0 = tensor_parallel.ColumnParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + gather_output=False, + ) + self.linear_1 = tensor_parallel.RowParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + input_is_parallel=True, + skip_bias_add=True, + ) + # Row parallel linear layer (splits input dimension) + self.lm_head_0 = tensor_parallel.ColumnParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + gather_output=False, + ) + self.lm_head_1 = tensor_parallel.RowParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + input_is_parallel=True, + skip_bias_add=True, + ) + + def forward(self, input): + x = self.linear_0(input)[0] + x = self.linear_1(x)[0] + x = self.lm_head_0(x)[0] + x = self.lm_head_1(x)[0] + return x + + +def initialize_distributed(rank=0, world_size=1): + """Initialize torch distributed for parallel training.""" + if torch.distributed.is_initialized(): + return + + print(f"Initializing torch.distributed with rank: {rank}, world_size: {world_size}") + torch.cuda.set_device(rank) + + init_method = "tcp://" + master_ip = os.getenv("MASTER_ADDR", "localhost") + master_port = os.getenv("MASTER_PORT", "6001") + init_method += master_ip + ":" + master_port + + torch.distributed.init_process_group( + backend="nccl", + world_size=world_size, + rank=rank, + init_method=init_method + ) + + +def initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None, +): + """Initialize Megatron's model parallel groups.""" + # Destroy existing model parallel if any + parallel_state.destroy_model_parallel() + + # Initialize distributed if not already done + if not torch.distributed.is_initialized(): + initialize_distributed() + + # Initialize model parallel groups + parallel_state.initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size, + pipeline_model_parallel_split_rank, + ) + + +def create_dummy_megatron_model(): + """ + Create and return a dummy Megatron model. 
+ + Returns: + DummyMegatronModel: The initialized model on CUDA + """ + # Initialize model parallel (single GPU by default) + initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1 + ) + + # Set random seed for reproducibility + model_parallel_cuda_manual_seed(123) + + # Configure the transformer + transformer_config = { + "num_layers": 2, + "hidden_size": 12, + "num_attention_heads": 4, + "use_cpu_initialization": True, + } + config = TransformerConfig(**transformer_config) + + # Create and return the model + model = DummyMegatronModel(config=config) + + if torch.cuda.is_available(): + model = model.cuda() + + return model + + +def cleanup(): + """Clean up distributed and model parallel groups.""" + parallel_state.destroy_model_parallel() + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + # Example usage + try: + # Create the model + model = create_dummy_megatron_model() + print(f"Created dummy Megatron model: {model}") + # Test forward pass + if torch.cuda.is_available(): + x = torch.randn(2, 4, 10).cuda() + output = model(x) + print(f"Input shape: {x.shape}") + print(f"Output shape: {output.shape}") + + # # Print model structure + # print("\nModel structure:") + # for name, module in model.named_modules(): + # print(f" {name}: {module.__class__.__name__}") + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*transformer*qkv*": {"rank": 64}, + "*ffn*": {"rank": 128}, + "*linear*": {"rank": 128} + } + } + # model = mtp.update(model, mode=[("peft", lora_config)]) + model = mtp.update_model(model, lora_config) + if torch.cuda.is_available(): + x = torch.randn(2, 4, 10).cuda() + output = model(x) + print(f"Input shape: {x.shape}") + print(f"Output shape: {output.shape}") + # mtq.quantize(model, mtq.MXFP4_DEFAULT_CFG) + finally: + # Clean up + cleanup() + print("\nCleaned up distributed environment") \ No newline at end of file From e70abb3c2a31dc086dcbb97b3e62429162d45ee2 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 17 Sep 2025 05:45:55 +0000 Subject: [PATCH 04/15] Update Signed-off-by: Jingyu Xin --- modelopt/torch/peft/__init__.py | 3 +- modelopt/torch/peft/conversion.py | 119 ++++++-- modelopt/torch/peft/convert.py | 55 +++- modelopt/torch/peft/custom.py | 29 ++ modelopt/torch/peft/lora/layer.py | 182 ++++++++++-- modelopt/torch/peft/lora/tp_layer.py | 73 ++--- modelopt/torch/peft/plugins/__init__.py | 22 ++ modelopt/torch/peft/plugins/megatron.py | 92 ++++++ run_tp_test.sh | 35 +++ test.py | 190 ++++++++---- test_single.py | 178 ++++++++++++ tests/gpu/torch/peft/test_forward_megatron.py | 274 ++++++++++++++++++ 12 files changed, 1078 insertions(+), 174 deletions(-) create mode 100644 modelopt/torch/peft/custom.py create mode 100644 modelopt/torch/peft/plugins/__init__.py create mode 100644 modelopt/torch/peft/plugins/megatron.py create mode 100644 run_tp_test.sh create mode 100644 test_single.py create mode 100644 tests/gpu/torch/peft/test_forward_megatron.py diff --git a/modelopt/torch/peft/__init__.py b/modelopt/torch/peft/__init__.py index e2b8a0e9..3a500359 100644 --- a/modelopt/torch/peft/__init__.py +++ b/modelopt/torch/peft/__init__.py @@ -18,6 +18,7 @@ from . import mode from .config import * from .convert import * + # isort: off # Import plugins last to avoid circular imports -# from . import plugins +from . 
import plugins diff --git a/modelopt/torch/peft/conversion.py b/modelopt/torch/peft/conversion.py index 5081c0c2..40873bf4 100644 --- a/modelopt/torch/peft/conversion.py +++ b/modelopt/torch/peft/conversion.py @@ -15,26 +15,20 @@ """Quantization conversion/restore utilities.""" -import fnmatch -from collections.abc import Callable -from contextlib import contextmanager from typing import Any import torch.nn as nn from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule, ModeloptStateManager -from modelopt.torch.opt.dynamic import _DMRegistryCls from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict from modelopt.torch.utils import get_unwrapped_name -from .config import ( - PEFTConfig, - _QuantizeExportConfig, -) -from .lora.layer import LoRAModuleRegistry +from .config import PEFTConfig, _QuantizeExportConfig +from .lora.layer import LoRAModule, LoRAModuleRegistry __all__ = [ "replace_lora_module", + "update_peft_metadata_in_model", ] @@ -48,46 +42,88 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert # set_quantizer_by_cfg(model, config.get("quant_cfg", {})) metadata = {} - # update_quantize_metadata(model, config, metadata) + update_peft_metadata(model, config, metadata) return model, metadata + def restore_peft_model( model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict ) -> nn.Module: - #TODO: implemente the restore logic - pass - - - -def update_peft_metadata( - model: nn.Module, config: PEFTConfig, metadata: MetadataDict -) -> None: - """Update the quantizer state in the metadata dict.""" - pass - + convert_to_peft_model(model, config) + return restore_peft_state(model, metadata) + + +def restore_peft_state(model: ModelLikeModule, metadata: MetadataDict): + """Restore PEFT state from metadata or extra_state. + For backward compatibility, we check metadata first. For distributed + checkpoints (NeMo-MCore), the state will be in extra_state of each LoRAModule + and will be restored automatically via set_extra_state() during load_state_dict(). 
+ + Args: + model: Model with LoRA modules to restore + metadata: Metadata dictionary that may contain peft_state + Returns: + The model with restored PEFT state + """ + if "peft_state" not in metadata: + # For distributed checkpoints (NeMo-MCore), peft_state is stored + # in each LoRAModule's extra_state and will be restored via + # set_extra_state() during load_state_dict() + return model + + # Legacy path: restore from metadata + peft_state_dict = metadata["peft_state"] + for name, module in model.named_modules(): + if isinstance(module, LoRAModule): + unwrapped_name = get_unwrapped_name(name) + if unwrapped_name in peft_state_dict: + try: + module.set_from_peft_state(peft_state_dict[unwrapped_name]) + except Exception as e: + raise ApplyModeError(f"Failed to restore PEFT state for module {name}: {e}") + + return model + + +def update_peft_metadata(model: nn.Module, config: PEFTConfig, metadata: MetadataDict) -> None: + """Update the PEFT/LoRA state in the metadata dict.""" + metadata["peft_state"] = peft_state(model) + + +def peft_state(model: nn.Module) -> dict[str, Any]: + return { + get_unwrapped_name(n): m.get_peft_state() + for n, m in model.named_modules() + if isinstance(m, LoRAModule) + } + + +def replace_lora_module( + model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry +): + """Recursively replace the module with LoRA module.""" + # Register custom plugins (e.g., for Megatron distributed checkpointing) + from .custom import register_custom_model_plugins_on_the_fly -def replace_lora_module(model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry): - """Recursively replace the module with quantized module.""" - #TODO: register the extra state for megatron-lm + register_custom_model_plugins_on_the_fly(model) if type(model) in registry: model = registry.convert(model) _replace_lora_module(model, version=version, registry=registry) + def export_peft_model(model: nn.Module, config): """Export the quantized model to a quantized model.""" raise NotImplementedError("Exporting a quantized model is not supported yet.") -def restore_export_peft_model( - model: nn.Module, config, metadata: MetadataDict -): +def restore_export_peft_model(model: nn.Module, config, metadata: MetadataDict): """Restores the quantized model from the given state dict.""" raise NotImplementedError("Restoring a quantized & exported model is not supported yet.") -def _replace_lora_module(model: nn.Module, version=None,registry=LoRAModuleRegistry): +def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegistry): for name, child in model.named_children(): if type(child) in registry: lora_module = registry.convert(child) @@ -106,3 +142,30 @@ def restore_export_quantized_model( ) -> nn.Module: """Restores the quantized model from the given state dict.""" raise NotImplementedError("Restoring a quantized & exported model is not supported yet.") + + +def update_peft_metadata_in_model(model: nn.Module) -> None: + """Update the PEFT metadata in the model's ModeloptStateManager. + This function should be called after manually modifying LoRA adapters to ensure + the metadata stored in the ModeloptStateManager reflects the current state. + + Args: + model: Model with LoRA modules whose metadata needs updating + Example: + >>> # After manually adding/modifying adapters + >>> for module in model.modules(): + ... if isinstance(module, LoRAModule): + ... 
module.update_layer_lora("custom_adapter", rank=32) + >>> # Update metadata to reflect changes + >>> update_peft_metadata_in_model(model) + """ + # Check if model has ModeloptStateManager (has been converted with peft mode) + if not ModeloptStateManager.is_converted(model): + return + + # Get the state manager + manager = ModeloptStateManager(model) + + # Update the metadata with current PEFT state + if manager._state and manager._last_metadata is not None: + manager._last_metadata["peft_state"] = peft_state(model) diff --git a/modelopt/torch/peft/convert.py b/modelopt/torch/peft/convert.py index e3b98461..7e656136 100644 --- a/modelopt/torch/peft/convert.py +++ b/modelopt/torch/peft/convert.py @@ -16,27 +16,28 @@ """User-facing quantization API.""" import fnmatch -import inspect -import warnings -from collections.abc import Callable, Iterable from typing import Any -import torch import torch.nn as nn # import modelopt.torch.quantization as mtq from modelopt.torch.opt import apply_mode + +# from modelopt.torch.quantization.conversion import set_quantizer_by_cfg +from modelopt.torch.opt.conversion import ModeloptStateManager + # from modelopt.torch.opt.searcher import ForwardLoop # from modelopt.torch.opt.utils import forward_with_reshard from modelopt.torch.peft.config import PEFTConfig -# from modelopt.torch.quantization.conversion import set_quantizer_by_cfg + +from .lora.layer import LoRAModule # from . import config # from .algorithms import AutoQuantizeSearcher # from .config import QuantizeAlgoCfgType # from .conversion import set_quantizer_attribute from .mode import PEFTModeRegistry -from .lora.layer import LoRAModule + # from .nn import QuantModule, TensorQuantizer # __all__ = [ @@ -50,17 +51,19 @@ # "quantize", # ] + def update_model( model: nn.Module, config: dict[str, Any | PEFTConfig], ): - #TODO: deal with extra state, how to save the model - #TODO: sharded dict - #TODO: metadate - #TODO: how to restore the model + # TODO: deal with extra state, how to save the model + # TODO: sharded dict + # TODO: metadate + # TODO: how to restore the model apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry) return add_adapter(model, config) + def add_adapter(model, config): adapter_cfg = config["adapter_cfg"] adapter_name = config["adapter_name"] @@ -77,4 +80,34 @@ def add_adapter(model, config): else: raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") module.update_layer_lora(adapter_name, adapter_setting["rank"]) - return model \ No newline at end of file + + # Update the metadata in ModeloptStateManager after adding adapters + _update_peft_metadata_in_state(model) + return model + + +def _update_peft_metadata_in_state(model: nn.Module) -> None: + """Update the PEFT metadata in the ModeloptStateManager. + + This function updates the metadata to reflect the current state of LoRA adapters + after they have been added or modified. 
+ """ + # Check if model has ModeloptStateManager (has been converted with peft mode) + if not ModeloptStateManager.is_converted(model): + return + + # Get the state manager + manager = ModeloptStateManager(model) + + # Get current PEFT state from all LoRA modules + current_peft_state = {} + for name, module in model.named_modules(): + if isinstance(module, LoRAModule): + from modelopt.torch.utils import get_unwrapped_name + + unwrapped_name = get_unwrapped_name(name) + current_peft_state[unwrapped_name] = module.get_peft_state() + + # Update the metadata in the last mode state (which should be 'peft') + if manager._state and manager._last_metadata is not None: + manager._last_metadata["peft_state"] = current_peft_state diff --git a/modelopt/torch/peft/custom.py b/modelopt/torch/peft/custom.py new file mode 100644 index 00000000..580efcf5 --- /dev/null +++ b/modelopt/torch/peft/custom.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom PEFT/LoRA plugins registry.""" + +# Registry for custom model plugins +CUSTOM_MODEL_PLUGINS = set() + + +def register_custom_model_plugins_on_the_fly(model): + """Registers custom PEFT/LoRA plugins on the fly. + + This is called before LoRAModule replacement to allow plugins + to configure the model (e.g., for distributed checkpointing). + """ + for callback in CUSTOM_MODEL_PLUGINS: + callback(model) diff --git a/modelopt/torch/peft/lora/layer.py b/modelopt/torch/peft/lora/layer.py index 84676491..82f41812 100644 --- a/modelopt/torch/peft/lora/layer.py +++ b/modelopt/torch/peft/lora/layer.py @@ -1,7 +1,9 @@ """LoRA (Low-Rank Adaptation) module implementation.""" +import warnings from abc import abstractmethod -from typing import Dict, Tuple, Any, Optional +from typing import Any + import torch import torch.nn as nn @@ -15,86 +17,206 @@ class LoRAModule(DynamicModule): """Base class for LoRA (Low-Rank Adaptation) modules. - + This module wraps existing layers and adds trainable low-rank decomposition matrices (LoRA adapters) that are added to the original layer's output. - + Attributes: _lora_adapters: Dictionary mapping adapter names to their LoRA A and B matrices _active_adapters: Set of currently active adapter names """ - + def _setup(self) -> None: """Initialize LoRA-specific attributes.""" - self._lora_adapters: Dict[str, Dict[str, nn.Module]] = {} + self._lora_adapters: dict[str, dict[str, Any]] = {} self._active_adapters: set = set() - + @property def adapter_names(self) -> set: """Return the set of all registered adapter names.""" return set(self._lora_adapters.keys()) - + @property def active_adapters(self) -> set: """Return the set of currently active adapter names.""" return self._active_adapters.copy() - + def activate_adapter(self, adapter_name: str) -> None: """Activate a specific adapter. 
- + Args: adapter_name: Name of the adapter to activate - + Raises: ValueError: If adapter_name is not registered """ if adapter_name not in self._lora_adapters: - raise ValueError(f"Adapter '{adapter_name}' not found. Available: {list(self._lora_adapters.keys())}") + raise ValueError( + f"Adapter '{adapter_name}' not found. Available: {list(self._lora_adapters.keys())}" + ) self._active_adapters.add(adapter_name) - + def deactivate_adapter(self, adapter_name: str) -> None: """Deactivate a specific adapter. - + Args: adapter_name: Name of the adapter to deactivate """ self._active_adapters.discard(adapter_name) - + def activate_all_adapters(self) -> None: """Activate all registered adapters.""" self._active_adapters = self.adapter_names.copy() - + def deactivate_all_adapters(self) -> None: """Deactivate all adapters.""" self._active_adapters.clear() - + + def _register_adapter( + self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int + ) -> None: + """Register a new LoRA adapter with explicit rank tracking. + + Args: + adapter_name: Name of the adapter + lora_a: LoRA A module (down-projection) + lora_b: LoRA B module (up-projection) + rank: Rank of the LoRA decomposition + """ + # Add as submodules for proper parameter registration + self.add_module(f"lora_a_{adapter_name}", lora_a) + self.add_module(f"lora_b_{adapter_name}", lora_b) + + # Store in adapter dictionary with explicit rank + self._lora_adapters[adapter_name] = { + "lora_a": lora_a, + "lora_b": lora_b, + "rank": rank, # Store rank explicitly for reliability + } + + # Automatically activate new adapters + self.activate_adapter(adapter_name) + @abstractmethod def update_layer_lora(self, adapter_name: str, rank: int = 64) -> None: """Create and register a new LoRA adapter. - + This method must be implemented by subclasses to create the appropriate LoRA A and B matrices for the specific layer type. - + Args: adapter_name: Name for the new adapter rank: Rank of the LoRA decomposition (default: 64) """ raise NotImplementedError("Subclasses must implement update_layer_lora") - + + def get_peft_state(self) -> dict[str, Any]: + """Get PEFT/LoRA state to be saved in checkpoint. + + This method returns the configuration and state of all LoRA adapters + without including the actual weight tensors. 
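+
+        A rough sketch of the returned structure (adapter name and values are
+        illustrative; see the code below for the full set of fields):
+
+            {
+                "adapters": {"default": {"rank": 64, "is_active": True, ...}},
+                "active_adapters": ["default"],
+            }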
+ + Returns: + Dictionary containing: + - adapters: Dict mapping adapter names to their configuration + - active_adapters: List of currently active adapter names + """ + modelopt_state = {} + + # Store adapter configurations + adapters_config = {} + for adapter_name, adapter_modules in self._lora_adapters.items(): + lora_a = adapter_modules["lora_a"] + lora_b = adapter_modules["lora_b"] + + # Get explicitly stored rank for reliability + rank = adapter_modules.get("rank", None) + + # If rank is not stored (legacy case), try to infer it + if rank is None: + if hasattr(lora_a, "output_size"): + rank = lora_a.output_size + elif hasattr(lora_b, "input_size"): + rank = lora_b.input_size + elif hasattr(lora_a, "out_features"): + rank = lora_a.out_features + elif hasattr(lora_b, "in_features"): + rank = lora_b.in_features + + adapters_config[adapter_name] = { + "rank": rank, + "is_active": adapter_name in self._active_adapters, + "lora_a_type": type(lora_a).__name__, + "lora_b_type": type(lora_b).__name__, + } + + modelopt_state["adapters"] = adapters_config + modelopt_state["active_adapters"] = list(self._active_adapters) + + # Store the base module type for validation + modelopt_state["base_module_type"] = type(self).__name__ + + return modelopt_state + + def get_extra_state(self) -> dict[str, Any]: + """Get extra state for distributed checkpointing. + + For distributed/sharded checkpoints (like NeMo-MCore), we store the PEFT state + as extra_state instead of in metadata. This handles cases where module names + change with different parallelism settings (TP, PP, EP). + + Returns: + Dictionary containing the PEFT/LoRA adapter state + """ + # Only return state if we have adapters + if not self._lora_adapters: + return {} + + # Get the current PEFT state + peft_state = self.get_peft_state() + + return {"modelopt_peft_state": peft_state} + + def set_extra_state(self, state: dict[str, Any]) -> None: + """Restore extra state for distributed checkpointing. + + This method is called during load_state_dict() to restore the PEFT/LoRA state + from distributed checkpoints. It handles the adapter configuration but not + the actual weights (which are restored through the normal state_dict mechanism). + + Args: + state: Dictionary containing the extra state to restore + """ + if state is None: + return + + peft_state = state.get("modelopt_peft_state") + if peft_state is None: + return + + # Restore the PEFT state + try: + self.set_from_peft_state(peft_state) + except Exception as e: + warnings.warn( + f"Failed to restore PEFT state from extra_state: {e}. " + "This might happen if the model structure has changed." + ) + def forward(self, x: torch.Tensor, *args, **kwargs) -> Any: """Forward pass with LoRA adaptation. 
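+
+        Conceptually this computes (a sketch; the code below also unwraps tuple
+        outputs returned by the base layer):
+
+            y = base_layer(x) + sum(lora_b(lora_a(x)) for each active adapter)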
- + Args: x: Input tensor *args: Additional positional arguments for the base layer **kwargs: Additional keyword arguments for the base layer - + Returns: Output from the base layer plus active LoRA adaptations """ # Call the base layer's forward method output = super().forward(x, *args, **kwargs) - + # Handle different output types from base layer if isinstance(output, tuple): # If output is a tuple, assume first element is the main result @@ -104,27 +226,27 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any: # If output is a single tensor result = output other_outputs = () - + # Apply active LoRA adapters if self._active_adapters and self._lora_adapters: for adapter_name in self._active_adapters: if adapter_name in self._lora_adapters: adapter = self._lora_adapters[adapter_name] # LoRA computation: result = result + B(A(x)) - lora_a = adapter['lora_a'] - lora_b = adapter['lora_b'] - + lora_a = adapter["lora_a"] + lora_b = adapter["lora_b"] + # Handle different forward signatures lora_a_output = lora_a(x) if isinstance(lora_a_output, tuple): lora_a_output = lora_a_output[0] - + lora_b_output = lora_b(lora_a_output) if isinstance(lora_b_output, tuple): lora_b_output = lora_b_output[0] - + result = result + lora_b_output - + # Return output in the same format as the base layer if other_outputs: return (result,) + other_outputs @@ -132,4 +254,4 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any: return result -LoRAModuleRegistry = _DMRegistryCls("LoRA", LoRAModule) \ No newline at end of file +LoRAModuleRegistry = _DMRegistryCls("LoRA", LoRAModule) diff --git a/modelopt/torch/peft/lora/tp_layer.py b/modelopt/torch/peft/lora/tp_layer.py index 9e2ce7f2..cceeac1a 100644 --- a/modelopt/torch/peft/lora/tp_layer.py +++ b/modelopt/torch/peft/lora/tp_layer.py @@ -1,16 +1,13 @@ """Tensor Parallel LoRA implementations for Megatron layers.""" import math -from typing import Optional, Callable +from collections.abc import Callable -import torch import torch.nn as nn import torch.nn.init as init +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.tensor_parallel.layers import RowParallelLinear, ColumnParallelLinear - -from .layer import LoRAModuleRegistry, LoRAModule - +from .layer import LoRAModule, LoRAModuleRegistry # Default rank for LoRA decomposition DEFAULT_LORA_RANK = 64 @@ -18,14 +15,14 @@ class _MegatronParallelLoRABase(LoRAModule): """Base class for Megatron tensor parallel LoRA implementations. - + This class provides common functionality for both ColumnParallel and RowParallel LoRA implementations, reducing code duplication. """ - + def _get_init_methods(self) -> tuple[Callable, Callable]: """Get initialization methods for LoRA A and B matrices. - + Returns: Tuple of (lora_a_init, lora_b_init) initialization functions """ @@ -34,14 +31,17 @@ def _get_init_methods(self) -> tuple[Callable, Callable]: # LoRA B is initialized to zero for stable training start lora_b_init = lambda weight: init.zeros_(weight) return lora_a_init, lora_b_init - - def _register_adapter(self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module) -> None: - """Register LoRA adapter modules. - + + def _register_adapter_with_device( + self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int + ) -> None: + """Register LoRA adapter modules and ensure correct device placement. 
+ Args: adapter_name: Name of the adapter lora_a: LoRA A module (down-projection) lora_b: LoRA B module (up-projection) + rank: Rank of the LoRA decomposition """ # Move LoRA modules to the same device as the parent module # Try to get device from parent module's parameters or buffers @@ -53,43 +53,32 @@ def _register_adapter(self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Mod for b in self.buffers(): device = b.device break - + # If we found a device, move LoRA modules to it if device is not None: lora_a = lora_a.to(device) lora_b = lora_b.to(device) - - # Add as submodules for proper parameter registration - self.add_module(f'lora_a_{adapter_name}', lora_a) - self.add_module(f'lora_b_{adapter_name}', lora_b) - - # Store in adapter dictionary - self._lora_adapters[adapter_name] = { - "lora_a": lora_a, - "lora_b": lora_b - } - - # Automatically activate new adapters - self.activate_adapter(adapter_name) + + super()._register_adapter(adapter_name, lora_a, lora_b, rank) @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"}) class _MegatronColumnParallelLinear(_MegatronParallelLoRABase): """LoRA implementation for Megatron ColumnParallelLinear layers. - + This implementation creates column-parallel LoRA adapters that match the parallelization scheme of the base layer. """ - + def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> None: """Create and register a new LoRA adapter for ColumnParallelLinear. - + Args: adapter_name: Name for the new adapter rank: Rank of the LoRA decomposition """ lora_a_init, lora_b_init = self._get_init_methods() - + # Create LoRA A: input_size -> rank (with gather for full reduction) lora_a = ColumnParallelLinear( self.input_size, @@ -98,9 +87,9 @@ def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> bias=False, gather_output=True, # Gather outputs for complete transformation init_method=lora_a_init, - disable_grad_reduce=getattr(self.config, 'sequence_parallel', False), + disable_grad_reduce=getattr(self.config, "sequence_parallel", False), ) - + # Create LoRA B: rank -> output_size (no gather, stays distributed) lora_b = ColumnParallelLinear( rank, @@ -110,27 +99,27 @@ def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> gather_output=False, # Keep output distributed like base layer init_method=lora_b_init, ) - - self._register_adapter(adapter_name, lora_a, lora_b) + + self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank) @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"}) class _MegatronRowParallelLinear(_MegatronParallelLoRABase): """LoRA implementation for Megatron RowParallelLinear layers. - + This implementation creates row-parallel LoRA adapters that match the parallelization scheme of the base layer. """ - + def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> None: """Create and register a new LoRA adapter for RowParallelLinear. 
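+
+        Layout sketch (mirroring the construction below): LoRA A is a
+        RowParallelLinear mapping input_size -> rank with input_is_parallel=True,
+        and LoRA B is a ColumnParallelLinear mapping rank -> output_size with
+        gather_output=True, so the adapter output lines up with the base layer's
+        gathered output.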
- + Args: adapter_name: Name for the new adapter rank: Rank of the LoRA decomposition """ lora_a_init, lora_b_init = self._get_init_methods() - + # Create LoRA A: input_size -> rank (row parallel, input already distributed) lora_a = RowParallelLinear( self.input_size, @@ -141,7 +130,7 @@ def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> bias=False, init_method=lora_a_init, ) - + # Create LoRA B: rank -> output_size (column parallel with gather) lora_b = ColumnParallelLinear( rank, @@ -151,5 +140,5 @@ def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> gather_output=True, # Gather to match base layer output init_method=lora_b_init, ) - - self._register_adapter(adapter_name, lora_a, lora_b) \ No newline at end of file + + self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank) diff --git a/modelopt/torch/peft/plugins/__init__.py b/modelopt/torch/peft/plugins/__init__.py new file mode 100644 index 00000000..d760b877 --- /dev/null +++ b/modelopt/torch/peft/plugins/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PEFT/LoRA plugins for various frameworks.""" + +# Import plugins to register them +try: + from . import megatron +except ImportError: + pass # Megatron not available diff --git a/modelopt/torch/peft/plugins/megatron.py b/modelopt/torch/peft/plugins/megatron.py new file mode 100644 index 00000000..6e349cb2 --- /dev/null +++ b/modelopt/torch/peft/plugins/megatron.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron-Core specific PEFT/LoRA plugins.""" + +from typing import Any + +import torch + +from modelopt.torch.opt.plugins.megatron import register_modelopt_extra_state_callbacks +from modelopt.torch.peft.lora.layer import LoRAModuleRegistry + +# Import MegatronModule if available +try: + from megatron.core.transformer.module import MegatronModule + + MEGATRON_AVAILABLE = True +except ImportError: + MegatronModule = None + MEGATRON_AVAILABLE = False + +from ..custom import CUSTOM_MODEL_PLUGINS + +__all__ = [] + + +def lora_module_get_extra_state(self) -> dict: + """Get extra state for LoRA modules. 
+ + This is called by the modelopt extra state framework to gather + PEFT/LoRA state for distributed checkpointing. + """ + # LoRAModule already has get_extra_state method + return self.get_extra_state() + + +def lora_module_set_extra_state(self, state: Any): + """Set extra state for LoRA modules. + + This is called by the modelopt extra state framework to restore + PEFT/LoRA state from distributed checkpoints. + """ + # LoRAModule already has set_extra_state method + self.set_extra_state(state) + + +def megatron_replace_lora_module_hook(model: torch.nn.Module): + """Configure Megatron-Core model PEFT/LoRA support. + + This callback is called before the LoRAModule replacement to configure + distributed checkpointing support. For each MegatronModule: + 1. We enable heterogeneous distributed checkpointing + 2. We register extra_state callbacks for all LoRAModule submodules + """ + if not MEGATRON_AVAILABLE: + return + + def _register_extra_state_callbacks(model: torch.nn.Module): + """Register extra state callbacks for LoRA modules.""" + for name, module in model.named_modules(): + if type(module) in LoRAModuleRegistry: + # This module will be replaced as a LoRAModule + register_modelopt_extra_state_callbacks( + module, + lora_module_get_extra_state, + lora_module_set_extra_state, + ) + + for name, module in model.named_modules(): + if isinstance(module, MegatronModule): + # Enable heterogeneous distributed checkpointing + if hasattr(module, "config") and hasattr( + module.config, "hetereogenous_dist_checkpoint" + ): + module.config.hetereogenous_dist_checkpoint = True + _register_extra_state_callbacks(module) + + +# Register the hook +CUSTOM_MODEL_PLUGINS.add(megatron_replace_lora_module_hook) diff --git a/run_tp_test.sh b/run_tp_test.sh new file mode 100644 index 00000000..b38c55d5 --- /dev/null +++ b/run_tp_test.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Script to run the test with tensor parallelism + +# Set the number of GPUs for tensor parallelism +NUM_GPUS=2 + +echo "Running Megatron model with Tensor Parallelism (TP=$NUM_GPUS)" +echo "This will use $NUM_GPUS GPUs" + +# Check if torchrun is available +if command -v torchrun &> /dev/null; then + echo "Using torchrun to launch the distributed job..." + torchrun --nproc_per_node=$NUM_GPUS test.py +else + echo "torchrun not found, using manual distributed launch..." + + # Set environment variables + export MASTER_ADDR=localhost + export MASTER_PORT=6001 + export WORLD_SIZE=$NUM_GPUS + + # Launch processes + for ((rank=0; rank<$NUM_GPUS; rank++)); do + echo "Launching rank $rank..." + RANK=$rank python test.py & + pids[$rank]=$! + done + + # Wait for all processes to complete + for pid in ${pids[*]}; do + wait $pid + done +fi + +echo "Test completed!" diff --git a/test.py b/test.py index 6b5188b2..538f84fa 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,21 @@ -# dummy_megatron_model.py +"""Megatron Tensor Parallel Model Test Script + +This script demonstrates: +1. Creating a Megatron model with tensor parallelism (TP=2) +2. Applying LoRA adapters to tensor parallel layers +3. Testing the model with proper distributed initialization + +To run with tensor parallelism: + torchrun --nproc_per_node=2 test.py + or + bash run_tp_test.sh + +The model uses ColumnParallelLinear and RowParallelLinear layers which +automatically handle weight sharding across GPUs when TP > 1. 
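+
+As a rough sketch, the LoRA portion of this script boils down to the following
+(the config values mirror the ones used in __main__ below):
+
+    lora_config = {
+        "adapter_type": "lora",
+        "adapter_name": "default",
+        "adapter_cfg": {"*linear*": {"rank": 64}, "*lm_head*": {"rank": 128}},
+    }
+    model = mtp.update_model(model, lora_config)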
+""" + import os + import torch import torch.nn.init as init from megatron.core import parallel_state, tensor_parallel @@ -8,28 +24,32 @@ from megatron.core.transformer.transformer_config import TransformerConfig import modelopt.torch.peft as mtp -import modelopt.torch.quantization as mtq class DummyMegatronModel(MegatronModule): + """A simple dummy Megatron model with parallel linear layers for testing. + Uses larger dimensions to better demonstrate tensor parallelism. """ - A simple dummy Megatron model with parallel linear layers for testing. - """ + def __init__(self, config: TransformerConfig): super().__init__(config) - + + # Larger dimensions for better tensor parallel demonstration + hidden_size = 1024 # Divisible by 2 for TP=2 + intermediate_size = 4096 # 4x hidden size, typical for transformers + # Column parallel linear layer (splits output dimension) self.linear_0 = tensor_parallel.ColumnParallelLinear( - input_size=10, - output_size=10, + input_size=hidden_size, + output_size=intermediate_size, config=config, init_method=init.xavier_normal_, bias=False, gather_output=False, ) self.linear_1 = tensor_parallel.RowParallelLinear( - input_size=10, - output_size=10, + input_size=intermediate_size, + output_size=hidden_size, config=config, init_method=init.xavier_normal_, bias=False, @@ -38,16 +58,16 @@ def __init__(self, config: TransformerConfig): ) # Row parallel linear layer (splits input dimension) self.lm_head_0 = tensor_parallel.ColumnParallelLinear( - input_size=10, - output_size=10, + input_size=hidden_size, + output_size=intermediate_size, config=config, init_method=init.xavier_normal_, bias=False, gather_output=False, ) self.lm_head_1 = tensor_parallel.RowParallelLinear( - input_size=10, - output_size=10, + input_size=intermediate_size, + output_size=hidden_size, config=config, init_method=init.xavier_normal_, bias=False, @@ -67,20 +87,17 @@ def initialize_distributed(rank=0, world_size=1): """Initialize torch distributed for parallel training.""" if torch.distributed.is_initialized(): return - + print(f"Initializing torch.distributed with rank: {rank}, world_size: {world_size}") torch.cuda.set_device(rank) - + init_method = "tcp://" master_ip = os.getenv("MASTER_ADDR", "localhost") master_port = os.getenv("MASTER_PORT", "6001") init_method += master_ip + ":" + master_port - + torch.distributed.init_process_group( - backend="nccl", - world_size=world_size, - rank=rank, - init_method=init_method + backend="nccl", world_size=world_size, rank=rank, init_method=init_method ) @@ -93,11 +110,11 @@ def initialize_model_parallel( """Initialize Megatron's model parallel groups.""" # Destroy existing model parallel if any parallel_state.destroy_model_parallel() - + # Initialize distributed if not already done if not torch.distributed.is_initialized(): initialize_distributed() - + # Initialize model parallel groups parallel_state.initialize_model_parallel( tensor_model_parallel_size, @@ -107,37 +124,46 @@ def initialize_model_parallel( ) -def create_dummy_megatron_model(): - """ - Create and return a dummy Megatron model. - +def create_dummy_megatron_model(tensor_model_parallel_size=2): + """Create and return a dummy Megatron model with tensor parallelism. 
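+
+    Typical usage in this script (illustrative):
+
+        model = create_dummy_megatron_model(tensor_model_parallel_size=2)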
+ + Args: + tensor_model_parallel_size: Size of tensor model parallelism (default: 2) + Returns: DummyMegatronModel: The initialized model on CUDA """ - # Initialize model parallel (single GPU by default) + # Get rank from environment or default to 0 + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", str(tensor_model_parallel_size))) + + # Initialize distributed and model parallel + initialize_distributed(rank=rank, world_size=world_size) initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1 + tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=1 ) - + # Set random seed for reproducibility model_parallel_cuda_manual_seed(123) - - # Configure the transformer + + # Configure the transformer with larger dimensions transformer_config = { - "num_layers": 2, - "hidden_size": 12, - "num_attention_heads": 4, + "num_layers": 4, + "hidden_size": 1024, # Must match model dimensions + "num_attention_heads": 16, "use_cpu_initialization": True, + "sequence_parallel": False, # Set to True for sequence parallelism } config = TransformerConfig(**transformer_config) - + # Create and return the model model = DummyMegatronModel(config=config) - + if torch.cuda.is_available(): model = model.cuda() - + + print(f"Model created on rank {rank} with TP size {tensor_model_parallel_size}") + return model @@ -150,40 +176,80 @@ def cleanup(): if __name__ == "__main__": - # Example usage + """ + To run with tensor parallelism size 2, use: + torchrun --nproc_per_node=2 test.py + + Or manually with: + RANK=0 WORLD_SIZE=2 MASTER_ADDR=localhost MASTER_PORT=6001 python test.py & + RANK=1 WORLD_SIZE=2 MASTER_ADDR=localhost MASTER_PORT=6001 python test.py + """ try: - # Create the model - model = create_dummy_megatron_model() - print(f"Created dummy Megatron model: {model}") + # Create the model with TP=2 + tensor_parallel_size = 2 + model = create_dummy_megatron_model(tensor_model_parallel_size=tensor_parallel_size) + + # Get rank for printing + rank = int(os.environ.get("RANK", "0")) + + if rank == 0: + print(f"\nCreated dummy Megatron model with TP={tensor_parallel_size}") + print("Model structure:") + for name, module in model.named_modules(): + if hasattr(module, "__class__"): + print(f" {name}: {module.__class__.__name__}") + # Test forward pass if torch.cuda.is_available(): - x = torch.randn(2, 4, 10).cuda() + batch_size = 2 + seq_length = 512 + hidden_size = 1024 # Must match model hidden size + + # Create input tensor + x = torch.randn(batch_size, seq_length, hidden_size).cuda() + + # Synchronize before forward pass + if torch.distributed.is_initialized(): + torch.distributed.barrier() + output = model(x) - print(f"Input shape: {x.shape}") - print(f"Output shape: {output.shape}") - - # # Print model structure - # print("\nModel structure:") - # for name, module in model.named_modules(): - # print(f" {name}: {module.__class__.__name__}") + + if rank == 0: + print("\nForward pass successful!") + print(f"Input shape: {x.shape}") + print(f"Output shape: {output.shape}") + + # Test LoRA with tensor parallel model lora_config = { "adapter_type": "lora", "adapter_name": "default", - "adapter_cfg": { - "*transformer*qkv*": {"rank": 64}, - "*ffn*": {"rank": 128}, - "*linear*": {"rank": 128} - } + "adapter_cfg": {"*linear*": {"rank": 64}, "*lm_head*": {"rank": 128}}, } - # model = mtp.update(model, mode=[("peft", lora_config)]) + + if rank == 0: + print("\nApplying LoRA configuration...") + model = 
mtp.update_model(model, lora_config) + + # Test forward pass with LoRA if torch.cuda.is_available(): - x = torch.randn(2, 4, 10).cuda() - output = model(x) - print(f"Input shape: {x.shape}") - print(f"Output shape: {output.shape}") - # mtq.quantize(model, mtq.MXFP4_DEFAULT_CFG) + output_lora = model(x) + if rank == 0: + print("LoRA forward pass successful!") + print(model) + print(model.linear_0.lora_a_default) + print(f"Output shape with LoRA: {output_lora.shape}") + + # Optional: Test quantization (commented out) + # if rank == 0: + # print(f"\nApplying quantization...") + # mtq.quantize(model, mtq.INT8_DEFAULT_CFG) + + except Exception as e: + print(f"Error on rank {os.environ.get('RANK', '0')}: {e}") + raise finally: # Clean up cleanup() - print("\nCleaned up distributed environment") \ No newline at end of file + if int(os.environ.get("RANK", "0")) == 0: + print("\nCleaned up distributed environment") diff --git a/test_single.py b/test_single.py new file mode 100644 index 00000000..729eab87 --- /dev/null +++ b/test_single.py @@ -0,0 +1,178 @@ +# dummy_megatron_model.py +import os + +import torch +import torch.nn.init as init +from megatron.core import parallel_state, tensor_parallel +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig + + +class DummyMegatronModel(MegatronModule): + """A simple dummy Megatron model with parallel linear layers for testing.""" + + def __init__(self, config: TransformerConfig): + super().__init__(config) + + # Column parallel linear layer (splits output dimension) + self.linear_0 = tensor_parallel.ColumnParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + gather_output=False, + ) + self.linear_1 = tensor_parallel.RowParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + input_is_parallel=True, + skip_bias_add=True, + ) + # Row parallel linear layer (splits input dimension) + self.lm_head_0 = tensor_parallel.ColumnParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + gather_output=False, + ) + self.lm_head_1 = tensor_parallel.RowParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + input_is_parallel=True, + skip_bias_add=True, + ) + + def forward(self, input): + x = self.linear_0(input)[0] + x = self.linear_1(x)[0] + x = self.lm_head_0(x)[0] + x = self.lm_head_1(x)[0] + return x + + +def initialize_distributed(rank=0, world_size=1): + """Initialize torch distributed for parallel training.""" + if torch.distributed.is_initialized(): + return + + print(f"Initializing torch.distributed with rank: {rank}, world_size: {world_size}") + torch.cuda.set_device(rank) + + init_method = "tcp://" + master_ip = os.getenv("MASTER_ADDR", "localhost") + master_port = os.getenv("MASTER_PORT", "6001") + init_method += master_ip + ":" + master_port + + torch.distributed.init_process_group( + backend="nccl", world_size=world_size, rank=rank, init_method=init_method + ) + + +def initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None, +): + """Initialize Megatron's model parallel groups.""" + # Destroy existing model parallel if any + 
parallel_state.destroy_model_parallel() + + # Initialize distributed if not already done + if not torch.distributed.is_initialized(): + initialize_distributed() + + # Initialize model parallel groups + parallel_state.initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size, + pipeline_model_parallel_split_rank, + ) + + +def create_dummy_megatron_model(): + """Create and return a dummy Megatron model. + + Returns: + DummyMegatronModel: The initialized model on CUDA + """ + # Initialize model parallel (single GPU by default) + initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + # Set random seed for reproducibility + model_parallel_cuda_manual_seed(123) + + # Configure the transformer + transformer_config = { + "num_layers": 2, + "hidden_size": 12, + "num_attention_heads": 4, + "use_cpu_initialization": True, + } + config = TransformerConfig(**transformer_config) + + # Create and return the model + model = DummyMegatronModel(config=config) + + if torch.cuda.is_available(): + model = model.cuda() + + return model + + +def cleanup(): + """Clean up distributed and model parallel groups.""" + parallel_state.destroy_model_parallel() + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + # Example usage + try: + # Create the model + model = create_dummy_megatron_model() + print(f"Created dummy Megatron model: {model}") + # Test forward pass + if torch.cuda.is_available(): + x = torch.randn(2, 4, 10).cuda() + output = model(x) + print(f"Input shape: {x.shape}") + print(f"Output shape: {output.shape}") + + # # Print model structure + # print("\nModel structure:") + # for name, module in model.named_modules(): + # print(f" {name}: {module.__class__.__name__}") + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*transformer*qkv*": {"rank": 64}, + "*ffn*": {"rank": 128}, + "*linear*": {"rank": 128}, + }, + } + # model = mtp.update_model(model, lora_config) + # if torch.cuda.is_available(): + # x = torch.randn(2, 4, 10).cuda() + # output = model(x) + # print(f"Input shape: {x.shape}") + # print(f"Output shape: {output.shape}") + # mtq.quantize(model, mtq.MXFP4_DEFAULT_CFG) + finally: + # Clean up + cleanup() + print("\nCleaned up distributed environment") diff --git a/tests/gpu/torch/peft/test_forward_megatron.py b/tests/gpu/torch/peft/test_forward_megatron.py new file mode 100644 index 00000000..3b6b2be3 --- /dev/null +++ b/tests/gpu/torch/peft/test_forward_megatron.py @@ -0,0 +1,274 @@ +# import json +# from copy import deepcopy +# from functools import partial + +# import pytest +# import torch +# import transformers +# # from _test_utils.import_helper import skip_if_no_megatron +# # from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job +# from _test_utils.torch_dist.plugins.megatron_common import get_mcore_gpt_model +# # from _test_utils.torch_model.transformers_models import create_tiny_llama_dir + +# # skip_if_no_megatron(apex_or_te_required=True) + +# import modelopt.torch.speculative as mtsp +# from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +# from modelopt.torch.speculative.eagle.default_config import default_eagle_config +# from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel +# from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel +from 
warnings import warn + +import torch +import torch.nn.functional as F +from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.parallel_state import ( + initialize_model_parallel, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_config import TransformerConfig + +from modelopt.torch.utils.plugins import megatron_prefill + +try: + from megatron.core.extensions.transformer_engine import TENorm + from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec + + HAS_TE = True +except ImportError as e: + warn(f"Transformer Engine not installed: {e}") + HAS_TE = False + +try: + from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec + from megatron.core.ssm.mamba_layer import MambaLayer + + HAS_MAMBA = True +except ImportError as e: + warn(f"Mamba not installed: {e}") + HAS_MAMBA = False + +try: + import apex # noqa: F401 + + HAS_APEX = True +except ImportError as e: + warn(f"Apex not installed: {e}") + HAS_APEX = False +import modelopt.torch.peft as mtp + +lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*": {"rank": 64}, + }, +} + + +def initialize_for_megatron( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234 +): + """Initialize Megatron model parallelism. + + NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`. + """ + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size) + model_parallel_cuda_manual_seed(seed) + + +def get_mcore_gpt_model( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + initialize_megatron: bool = False, + *, + num_layers: int = 2, + num_layers_in_first_pipeline_stage: int | None = None, + num_layers_in_last_pipeline_stage: int | None = None, + hidden_size: int = 64, + num_attention_heads: int = 8, + num_query_groups: int | None = None, + ffn_hidden_size: int | None = 128, + max_sequence_length: int = 16, + vocab_size: int = 64, + activation_func: str = "swiglu", + normalization: str = "LayerNorm", + transformer_impl: str = "modelopt" if HAS_TE else "local", + use_cpu_initialization: bool = False, + bf16: bool = True, +) -> GPTModel: + assert activation_func in ["swiglu", "squared_relu"] + assert normalization in ["LayerNorm", "RMSNorm"] + assert transformer_impl in ["local", "transformer_engine", "modelopt"] + print(f"Using `{transformer_impl=}` model spec for building GPT Model.") + + if initialize_megatron: + initialize_for_megatron(tensor_model_parallel_size, pipeline_model_parallel_size) + + def squared_relu(x): + return torch.pow(F.relu(x), 2) + + config = TransformerConfig( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + sequence_parallel=False, + num_layers=num_layers, + num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, + num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + activation_func=squared_relu if activation_func == "squared_relu" else F.silu, + normalization=normalization, + 
gated_linear_unit=(activation_func == "swiglu"), + add_bias_linear=False, + use_cpu_initialization=use_cpu_initialization, + pipeline_dtype=torch.bfloat16 if bf16 else torch.float32, + bf16=bf16, + ) + + if transformer_impl == "local": + assert HAS_APEX, "Apex not installed" + transformer_layer_spec = get_gpt_layer_local_spec(normalization=normalization) + else: + assert HAS_TE, "Transformer Engine not installed" + transformer_layer_spec = ( + get_gpt_modelopt_spec(config, remap_te_layernorm=True) + if transformer_impl == "modelopt" + else get_gpt_layer_with_transformer_engine_spec() + ) + + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + pre_process=is_pipeline_first_stage(), + post_process=is_pipeline_last_stage(), + share_embeddings_and_output_weights=False, + position_embedding_type="rope", + ) + if bf16: + model = model.to(torch.bfloat16) + + return model + + +def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): + """Build the model.""" + + if meta_device: + with torch.device("meta"): + gpt_model = get_mcore_gpt_model( + tensor_model_parallel_size=tp_size, + num_layers=4, + ffn_hidden_size=None, + num_attention_heads=4, + activation_func="squared_relu", + transformer_impl="local", + hidden_size=hidden_size, + vocab_size=vocab_size, + use_cpu_initialization=meta_device, + ) + else: + gpt_model = get_mcore_gpt_model( + tensor_model_parallel_size=tp_size, + num_layers=4, + ffn_hidden_size=None, + num_attention_heads=4, + activation_func="squared_relu", + transformer_impl="local", + hidden_size=hidden_size, + vocab_size=vocab_size, + ).cuda() + return gpt_model.eval() + + +import os + +from megatron.core import parallel_state + +from tests.gpu.torch.peft.test_forward_megatron import ( + _gpt_model_provider, + initialize_for_megatron, + megatron_prefill, +) + + +def _test_lora_forward(): + """Test LoRA forward pass with Megatron model.""" + # Initialize model parallel groups with proper CUDA RNG seed + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + # Create model + model = _gpt_model_provider(tp_size=1) + + # Create input tokens + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + # Run forward pass + output = megatron_prefill(model, prompt_tokens) + print( + f"Forward pass successful! Output shape: {output.shape if hasattr(output, 'shape') else 'N/A'}" + ) + + # Now test with LoRA + + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": {"*attention*": {"rank": 32}, "*mlp*": {"rank": 64}}, + } + + # # Apply LoRA + model = mtp.update_model(model, lora_config) + print("LoRA adapters added successfully!") + + # Test forward pass with LoRA + output_lora = megatron_prefill(model, prompt_tokens) + print( + f"LoRA forward pass successful! 
Output shape: {output_lora.shape if hasattr(output_lora, 'shape') else 'N/A'}" + ) + + # Check if LoRA modules were added + lora_count = 0 + for name, module in model.named_modules(): + if hasattr(module, "_lora_adapters"): + lora_count += 1 + print(f"LoRA module found: {name}") + + print(f"\nTotal LoRA modules: {lora_count}") + print("Test passed!") + + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def main(): + """Main function to setup distributed and run test.""" + # Setup distributed environment + if not torch.distributed.is_initialized(): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1) + + try: + _test_lora_forward() + finally: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() From 63e64c529307ab727c249674778eb170af5af789 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 18 Sep 2025 04:33:42 +0000 Subject: [PATCH 05/15] Add more functions Signed-off-by: Jingyu Xin --- modelopt/torch/peft/config.py | 11 +- modelopt/torch/peft/conversion.py | 1 + modelopt/torch/peft/convert.py | 65 +++-- modelopt/torch/peft/lora/layer.py | 44 ++- modelopt/torch/peft/lora/tp_layer.py | 53 +++- modelopt/torch/peft/plugins/megatron.py | 41 +-- tests/gpu/torch/peft/test_forward_megatron.py | 264 +++++++++++++++++- 7 files changed, 388 insertions(+), 91 deletions(-) diff --git a/modelopt/torch/peft/config.py b/modelopt/torch/peft/config.py index cd007156..f30a20c4 100644 --- a/modelopt/torch/peft/config.py +++ b/modelopt/torch/peft/config.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections.abc import Callable from typing import Literal from pydantic import ValidationInfo, field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.utils.network import ConstructorLike + BiasType = Literal["static", "dynamic"] BiasMethod = Literal["mean", "max_min"] + class QuantizerAttributeConfig(ModeloptBaseConfig): """Quantizer attribute type.""" @@ -358,9 +359,10 @@ class SVDQuantConfig(QuantizeAlgorithmConfig): # QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None -#TODO Jingyu Xin + +# TODO Jingyu Xin class PEFTConfig(ModeloptBaseConfig): - """Default configuration for ``quantize`` mode.""" + """Default configuration for ``peft`` mode.""" adapter_name: str = ModeloptField( default="default", @@ -380,8 +382,11 @@ class PEFTConfig(ModeloptBaseConfig): validate_default=True, ) + class ExportPEFTConfig(ModeloptBaseConfig): """An empty config.""" + + class CompressConfig(ModeloptBaseConfig): """Default configuration for ``compress`` mode.""" diff --git a/modelopt/torch/peft/conversion.py b/modelopt/torch/peft/conversion.py index 40873bf4..91e930ba 100644 --- a/modelopt/torch/peft/conversion.py +++ b/modelopt/torch/peft/conversion.py @@ -42,6 +42,7 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert # set_quantizer_by_cfg(model, config.get("quant_cfg", {})) metadata = {} + # Should return adapaters, active_adapters update_peft_metadata(model, config, metadata) return model, metadata diff --git a/modelopt/torch/peft/convert.py b/modelopt/torch/peft/convert.py index 7e656136..43850234 100644 --- a/modelopt/torch/peft/convert.py +++ b/modelopt/torch/peft/convert.py @@ -20,47 +20,32 @@ import torch.nn as nn -# import modelopt.torch.quantization as mtq from modelopt.torch.opt import apply_mode - -# from modelopt.torch.quantization.conversion import set_quantizer_by_cfg from modelopt.torch.opt.conversion import ModeloptStateManager - -# from modelopt.torch.opt.searcher import ForwardLoop -# from modelopt.torch.opt.utils import forward_with_reshard from modelopt.torch.peft.config import PEFTConfig from .lora.layer import LoRAModule - -# from . import config -# from .algorithms import AutoQuantizeSearcher -# from .config import QuantizeAlgoCfgType -# from .conversion import set_quantizer_attribute from .mode import PEFTModeRegistry -# from .nn import QuantModule, TensorQuantizer - -# __all__ = [ -# "auto_quantize", -# "calibrate", -# "disable_quantizer", -# "enable_quantizer", -# "fold_weight", -# "postprocess_amax", -# "print_quant_summary", -# "quantize", -# ] - def update_model( model: nn.Module, config: dict[str, Any | PEFTConfig], ): - # TODO: deal with extra state, how to save the model - # TODO: sharded dict - # TODO: metadate - # TODO: how to restore the model - apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry) + """Update model with PEFT/LoRA adapters. 
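+
+    A minimal usage sketch (the wildcard patterns, ranks, and scales below are
+    illustrative values taken from the accompanying tests; "scale" is the new
+    per-pattern option handled by add_adapter):
+
+        lora_config = {
+            "adapter_type": "lora",
+            "adapter_name": "default",
+            "adapter_cfg": {
+                "*attention*": {"rank": 32, "scale": 1},
+                "*mlp*": {"rank": 64, "scale": 1},
+            },
+        }
+        model = mtp.update_model(model, lora_config)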
+ This function handles both initial PEFT conversion and adding additional adapters: + - First call: Converts modules to LoRAModules and adds the first adapter + - Subsequent calls: Adds new adapters to existing LoRAModules + Args: + model: The model to update + config: PEFT configuration containing adapter settings + Returns: + The updated model with LoRA adapters + """ + # Check if model is already in PEFT mode by looking for LoRA modules + if not is_peft_model(model): + # First time - need to convert to PEFT mode + apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry) return add_adapter(model, config) @@ -79,7 +64,9 @@ def add_adapter(model, config): continue else: raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") - module.update_layer_lora(adapter_name, adapter_setting["rank"]) + module.update_layer_lora( + adapter_name, adapter_setting["rank"], adapter_setting.get("scale", 1.0) + ) # Update the metadata in ModeloptStateManager after adding adapters _update_peft_metadata_in_state(model) @@ -111,3 +98,21 @@ def _update_peft_metadata_in_state(model: nn.Module) -> None: # Update the metadata in the last mode state (which should be 'peft') if manager._state and manager._last_metadata is not None: manager._last_metadata["peft_state"] = current_peft_state + + +def is_peft_model(model: nn.Module) -> bool: + """Check if the model has been converted to PEFT/LoRA model. + + This function checks if any modules in the model are LoRAModule instances, + which indicates the model has already been converted to PEFT mode. + + Args: + model: The model to check + + Returns: + True if the model contains LoRA modules, False otherwise + """ + for _, module in model.named_modules(): + if isinstance(module, LoRAModule): + return True + return False diff --git a/modelopt/torch/peft/lora/layer.py b/modelopt/torch/peft/lora/layer.py index 82f41812..470e6484 100644 --- a/modelopt/torch/peft/lora/layer.py +++ b/modelopt/torch/peft/lora/layer.py @@ -73,7 +73,7 @@ def deactivate_all_adapters(self) -> None: self._active_adapters.clear() def _register_adapter( - self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int + self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float = 1.0 ) -> None: """Register a new LoRA adapter with explicit rank tracking. @@ -82,6 +82,7 @@ def _register_adapter( lora_a: LoRA A module (down-projection) lora_b: LoRA B module (up-projection) rank: Rank of the LoRA decomposition + scale: Scale factor for the LoRA output """ # Add as submodules for proper parameter registration self.add_module(f"lora_a_{adapter_name}", lora_a) @@ -92,13 +93,14 @@ def _register_adapter( "lora_a": lora_a, "lora_b": lora_b, "rank": rank, # Store rank explicitly for reliability + "scale": scale, } # Automatically activate new adapters self.activate_adapter(adapter_name) @abstractmethod - def update_layer_lora(self, adapter_name: str, rank: int = 64) -> None: + def update_layer_lora(self, adapter_name: str, rank: int = 64, scale: float = 1.0) -> None: """Create and register a new LoRA adapter. 
This method must be implemented by subclasses to create the appropriate @@ -107,6 +109,7 @@ def update_layer_lora(self, adapter_name: str, rank: int = 64) -> None: Args: adapter_name: Name for the new adapter rank: Rank of the LoRA decomposition (default: 64) + scale: Scale factor for the LoRA output (default: 1.0) """ raise NotImplementedError("Subclasses must implement update_layer_lora") @@ -148,14 +151,12 @@ def get_peft_state(self) -> dict[str, Any]: "is_active": adapter_name in self._active_adapters, "lora_a_type": type(lora_a).__name__, "lora_b_type": type(lora_b).__name__, + "scale": adapter_modules.get("scale", 1.0), } modelopt_state["adapters"] = adapters_config modelopt_state["active_adapters"] = list(self._active_adapters) - # Store the base module type for validation - modelopt_state["base_module_type"] = type(self).__name__ - return modelopt_state def get_extra_state(self) -> dict[str, Any]: @@ -177,6 +178,36 @@ def get_extra_state(self) -> dict[str, Any]: return {"modelopt_peft_state": peft_state} + def set_from_peft_state(self, peft_state: dict[str, Any]) -> None: + """Restore LoRA adapters from saved PEFT state. + + This method recreates LoRA adapters based on their saved configuration. + Note: This only restores the adapter structure, not the weights. + + Args: + peft_state: Dictionary containing adapter configurations + """ + adapters_config = peft_state.get("adapters", {}) + + # Clear existing adapters first + self._lora_adapters.clear() + self._active_adapters.clear() + + # Recreate each adapter based on saved configuration + for adapter_name, config in adapters_config.items(): + rank = config.get("rank") + scale = config.get("scale", 1.0) + + if rank is not None: + # Create the adapter with saved configuration + self.update_layer_lora(adapter_name, rank=rank, scale=scale) + + # Set activation state + if config.get("is_active", False): + self.activate_adapter(adapter_name) + else: + self.deactivate_adapter(adapter_name) + def set_extra_state(self, state: dict[str, Any]) -> None: """Restore extra state for distributed checkpointing. 
@@ -245,7 +276,8 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any: if isinstance(lora_b_output, tuple): lora_b_output = lora_b_output[0] - result = result + lora_b_output + scale = adapter.get("scale", 1.0) + result = result + scale * lora_b_output # Return output in the same format as the base layer if other_outputs: diff --git a/modelopt/torch/peft/lora/tp_layer.py b/modelopt/torch/peft/lora/tp_layer.py index cceeac1a..bd80d6b5 100644 --- a/modelopt/torch/peft/lora/tp_layer.py +++ b/modelopt/torch/peft/lora/tp_layer.py @@ -9,8 +9,20 @@ from .layer import LoRAModule, LoRAModuleRegistry -# Default rank for LoRA decomposition +try: + from modelopt.torch.quantization.plugins.megatron import ( + _MegatronColumnParallelLinear as QuantColumnParallelLinear, + ) + from modelopt.torch.quantization.plugins.megatron import ( + _MegatronRowParallelLinear as QuantRowParallelLinear, + ) + + QUANT_MODULES_AVAILABLE = True +except ImportError: + QUANT_MODULES_AVAILABLE = False + DEFAULT_LORA_RANK = 64 +DEFAULT_SCALE = 1.0 class _MegatronParallelLoRABase(LoRAModule): @@ -33,7 +45,7 @@ def _get_init_methods(self) -> tuple[Callable, Callable]: return lora_a_init, lora_b_init def _register_adapter_with_device( - self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int + self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float ) -> None: """Register LoRA adapter modules and ensure correct device placement. @@ -43,23 +55,29 @@ def _register_adapter_with_device( lora_b: LoRA B module (up-projection) rank: Rank of the LoRA decomposition """ - # Move LoRA modules to the same device as the parent module - # Try to get device from parent module's parameters or buffers + # Move LoRA modules to the same device and dtype as the parent module + # Try to get device and dtype from parent module's parameters or buffers device = None + dtype = None for p in self.parameters(): device = p.device + dtype = p.dtype break if device is None: for b in self.buffers(): device = b.device + dtype = b.dtype break - # If we found a device, move LoRA modules to it + # If we found a device and dtype, move LoRA modules to match if device is not None: lora_a = lora_a.to(device) lora_b = lora_b.to(device) + if dtype is not None: + lora_a = lora_a.to(dtype) + lora_b = lora_b.to(dtype) - super()._register_adapter(adapter_name, lora_a, lora_b, rank) + super()._register_adapter(adapter_name, lora_a, lora_b, rank, scale) @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"}) @@ -70,7 +88,9 @@ class _MegatronColumnParallelLinear(_MegatronParallelLoRABase): the parallelization scheme of the base layer. """ - def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> None: + def update_layer_lora( + self, adapter_name: str, rank: int = DEFAULT_LORA_RANK, scale: float = DEFAULT_SCALE + ) -> None: """Create and register a new LoRA adapter for ColumnParallelLinear. Args: @@ -100,7 +120,7 @@ def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> init_method=lora_b_init, ) - self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank) + self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank, scale) @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"}) @@ -111,7 +131,9 @@ class _MegatronRowParallelLinear(_MegatronParallelLoRABase): the parallelization scheme of the base layer. 
""" - def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> None: + def update_layer_lora( + self, adapter_name: str, rank: int = DEFAULT_LORA_RANK, scale: float = DEFAULT_SCALE + ) -> None: """Create and register a new LoRA adapter for RowParallelLinear. Args: @@ -141,4 +163,15 @@ def update_layer_lora(self, adapter_name: str, rank: int = DEFAULT_LORA_RANK) -> init_method=lora_b_init, ) - self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank) + self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank, scale) + + +# Register quantized versions if available +if QUANT_MODULES_AVAILABLE: + # Register the same LoRA implementations for quantized modules + LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})( + _MegatronColumnParallelLinear + ) + LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})( + _MegatronRowParallelLinear + ) diff --git a/modelopt/torch/peft/plugins/megatron.py b/modelopt/torch/peft/plugins/megatron.py index 6e349cb2..6390f7b3 100644 --- a/modelopt/torch/peft/plugins/megatron.py +++ b/modelopt/torch/peft/plugins/megatron.py @@ -15,13 +15,8 @@ """Megatron-Core specific PEFT/LoRA plugins.""" -from typing import Any - import torch -from modelopt.torch.opt.plugins.megatron import register_modelopt_extra_state_callbacks -from modelopt.torch.peft.lora.layer import LoRAModuleRegistry - # Import MegatronModule if available try: from megatron.core.transformer.module import MegatronModule @@ -36,48 +31,19 @@ __all__ = [] -def lora_module_get_extra_state(self) -> dict: - """Get extra state for LoRA modules. - - This is called by the modelopt extra state framework to gather - PEFT/LoRA state for distributed checkpointing. - """ - # LoRAModule already has get_extra_state method - return self.get_extra_state() - - -def lora_module_set_extra_state(self, state: Any): - """Set extra state for LoRA modules. - - This is called by the modelopt extra state framework to restore - PEFT/LoRA state from distributed checkpoints. - """ - # LoRAModule already has set_extra_state method - self.set_extra_state(state) - - def megatron_replace_lora_module_hook(model: torch.nn.Module): """Configure Megatron-Core model PEFT/LoRA support. This callback is called before the LoRAModule replacement to configure distributed checkpointing support. For each MegatronModule: 1. We enable heterogeneous distributed checkpointing - 2. We register extra_state callbacks for all LoRAModule submodules + + Note: LoRAModule already has built-in get_extra_state and set_extra_state methods, + so we don't need to register callbacks for them. 
""" if not MEGATRON_AVAILABLE: return - def _register_extra_state_callbacks(model: torch.nn.Module): - """Register extra state callbacks for LoRA modules.""" - for name, module in model.named_modules(): - if type(module) in LoRAModuleRegistry: - # This module will be replaced as a LoRAModule - register_modelopt_extra_state_callbacks( - module, - lora_module_get_extra_state, - lora_module_set_extra_state, - ) - for name, module in model.named_modules(): if isinstance(module, MegatronModule): # Enable heterogeneous distributed checkpointing @@ -85,7 +51,6 @@ def _register_extra_state_callbacks(model: torch.nn.Module): module.config, "hetereogenous_dist_checkpoint" ): module.config.hetereogenous_dist_checkpoint = True - _register_extra_state_callbacks(module) # Register the hook diff --git a/tests/gpu/torch/peft/test_forward_megatron.py b/tests/gpu/torch/peft/test_forward_megatron.py index 3b6b2be3..9a7ce837 100644 --- a/tests/gpu/torch/peft/test_forward_megatron.py +++ b/tests/gpu/torch/peft/test_forward_megatron.py @@ -17,10 +17,13 @@ # from modelopt.torch.speculative.eagle.default_config import default_eagle_config # from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel # from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel +import copy +import os from warnings import warn import torch import torch.nn.functional as F +from megatron.core import dist_checkpointing from megatron.core.models.gpt import GPTModel from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_local_spec, @@ -34,6 +37,10 @@ from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig +from modelopt.torch.opt.plugins.mcore_dist_checkpointing import ( + restore_sharded_modelopt_state, + save_sharded_modelopt_state, +) from modelopt.torch.utils.plugins import megatron_prefill try: @@ -62,6 +69,7 @@ warn(f"Apex not installed: {e}") HAS_APEX = False import modelopt.torch.peft as mtp +import modelopt.torch.quantization as mtq lora_config = { "adapter_type": "lora", @@ -191,8 +199,6 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic return gpt_model.eval() -import os - from megatron.core import parallel_state from tests.gpu.torch.peft.test_forward_megatron import ( @@ -225,12 +231,79 @@ def _test_lora_forward(): lora_config = { "adapter_type": "lora", "adapter_name": "default", - "adapter_cfg": {"*attention*": {"rank": 32}, "*mlp*": {"rank": 64}}, + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + + # # Apply LoRA + model = mtp.update_model(model, lora_config) + print("LoRA adapters added successfully!") + + # Test forward pass with LoRA + output_lora = megatron_prefill(model, prompt_tokens) + print( + f"LoRA forward pass successful! 
Output shape: {output_lora.shape if hasattr(output_lora, 'shape') else 'N/A'}" + ) + + # Check if LoRA modules were added + lora_count = 0 + for name, module in model.named_modules(): + if hasattr(module, "_lora_adapters"): + lora_count += 1 + print(f"LoRA module found: {name}") + + print(f"\nTotal LoRA modules: {lora_count}") + print("Test passed!") + + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def _test_lora_add_2nd_lora(): + """Test LoRA forward pass with Megatron model.""" + # Initialize model parallel groups with proper CUDA RNG seed + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + # Create model + model = _gpt_model_provider(tp_size=1) + + # Create input tokens + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + # Run forward pass + output = megatron_prefill(model, prompt_tokens) + print( + f"Forward pass successful! Output shape: {output.shape if hasattr(output, 'shape') else 'N/A'}" + ) + + # Now test with LoRA + + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + + lora_2d_config = { + "adapter_type": "lora", + "adapter_name": "2nd", + "adapter_cfg": { + "*attention*": {"rank": 128, "scale": 1}, + "*mlp*": {"rank": 128, "scale": 1}, + }, } # # Apply LoRA model = mtp.update_model(model, lora_config) print("LoRA adapters added successfully!") + model = mtp.update_model(model, lora_2d_config) # Test forward pass with LoRA output_lora = megatron_prefill(model, prompt_tokens) @@ -253,6 +326,185 @@ def _test_lora_forward(): parallel_state.destroy_model_parallel() +def _test_lora_save_and_restore(): + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + model_ref = _gpt_model_provider(tp_size=1) + model_test = _gpt_model_provider(tp_size=1) + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + model_ref = mtp.update_model(model_ref, lora_config) + state_dict = copy.deepcopy(model_ref.state_dict()) + tmp_path = "./model_ref" + save_distributed_checkpoint(tmp_path, model_ref) + save_sharded_modelopt_state([model_ref], tmp_path) + restore_sharded_modelopt_state([model_test], tmp_path) + model_test = load_distributed_checkpoint(tmp_path, model_test) + + prompt_tokens = torch.randint( + 0, model_test.vocab_size, (2, model_test.max_sequence_length) + ).cuda() + + # Run forward pass + output_test = megatron_prefill(model_test, prompt_tokens) + output_ref = megatron_prefill(model_ref, prompt_tokens) + print( + f"Forward pass successful! 
Output shape: {output_test.shape if hasattr(output_test, 'shape') else 'N/A'}" + ) + print(output_test) + print(output_ref) + + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def _test_lora_save_and_restore_with2loras(): + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + model_ref = _gpt_model_provider(tp_size=1) + model_test = _gpt_model_provider(tp_size=1) + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 10}, + }, + } + lora_2d_config = { + "adapter_type": "lora", + "adapter_name": "2nd", + "adapter_cfg": { + "*attention*": {"rank": 128, "scale": 1}, + "*mlp*": {"rank": 128, "scale": 1}, + }, + } + lora_3d_config = { + "adapter_type": "lora", + "adapter_name": "3rd", + "adapter_cfg": { + "*attention*": {"rank": 128, "scale": 1}, + "*mlp*": {"rank": 128, "scale": 1}, + }, + } + model_ref = mtp.update_model(model_ref, lora_config) + model_ref = mtp.update_model(model_ref, lora_2d_config) + tmp_path = "./model_ref" + save_distributed_checkpoint(tmp_path, model_ref) + save_sharded_modelopt_state([model_ref], tmp_path) + restore_sharded_modelopt_state([model_test], tmp_path) + model_test = load_distributed_checkpoint(tmp_path, model_test) + # model_test = mtp.update_model(model_test, lora_3d_config) + + # Debug: Check active adapters + print("\n=== Active Adapters ===") + for name, module in model_test.named_modules(): + if hasattr(module, "_lora_adapters") and module._lora_adapters: + print( + f"{name}: adapters={list(module._lora_adapters.keys())}, active={list(module._active_adapters)}" + ) + break # Just show one module as example + + prompt_tokens = torch.randint( + 0, model_test.vocab_size, (2, model_test.max_sequence_length) + ).cuda() + + # Run forward pass + output_test = megatron_prefill(model_test, prompt_tokens) + output_ref = megatron_prefill(model_ref, prompt_tokens) + print( + f"Forward pass successful! Output shape: {output_test.shape if hasattr(output_test, 'shape') else 'N/A'}" + ) + print(output_test) + print(output_ref) + + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def _test_quantize_then_lora(): + """Test LoRA forward pass with Megatron model.""" + # Initialize model parallel groups with proper CUDA RNG seed + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + model = _gpt_model_provider(tp_size=1) + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + + def forward_func(mod): + output = megatron_prefill(model, prompt_tokens) + + mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_func) + model = mtp.update_model(model, lora_config) + lora_count = 0 + for name, module in model.named_modules(): + if hasattr(module, "_lora_adapters"): + lora_count += 1 + print(f"LoRA module found: {name}") + print(f"\nTotal LoRA modules: {lora_count}") + output_lora_quant = megatron_prefill(model, prompt_tokens) + print( + f"LoRA forward pass successful! 
Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}" + ) + print("Test passed!") + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def _test_quantize_then_lora_save_restore(): + pass + + +def _test_lora_then_quantize(): + pass + + +def _test_lora_then_quantize_save_restore(): + pass + + +def _test_disable_lora(): + pass + + +def _test_disable_lora_restore(): + pass + + +def save_distributed_checkpoint(checkpoint_path, gpt_model): + os.makedirs(checkpoint_path, exist_ok=True) + sharded_state_dict = gpt_model.sharded_state_dict(prefix="") + dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path) + + +def load_distributed_checkpoint(checkpoint_path, gpt_model): + sharded_state_dict = gpt_model.sharded_state_dict(prefix="") + checkpoint = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path + ) + gpt_model.load_state_dict(checkpoint) + return gpt_model + + def main(): """Main function to setup distributed and run test.""" # Setup distributed environment @@ -264,7 +516,11 @@ def main(): torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1) try: - _test_lora_forward() + # _test_lora_forward() + # _test_lora_save_and_restore() + # _test_lora_add_2nd_lora() + # _test_lora_save_and_restore_with2loras() + _test_quantize_then_lora() finally: if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() From bf33ae45038829227b50c5febe10eaaad7ddac76 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 18 Sep 2025 21:27:40 +0000 Subject: [PATCH 06/15] Update: to support quantize the lora layers Signed-off-by: Jingyu Xin --- modelopt/torch/peft/__init__.py | 1 + modelopt/torch/peft/lora/tp_layer.py | 48 ++++++++++++++++++++----- modelopt/torch/peft/plugins/__init__.py | 9 +++-- 3 files changed, 45 insertions(+), 13 deletions(-) diff --git a/modelopt/torch/peft/__init__.py b/modelopt/torch/peft/__init__.py index 3a500359..4874cbeb 100644 --- a/modelopt/torch/peft/__init__.py +++ b/modelopt/torch/peft/__init__.py @@ -17,6 +17,7 @@ from . import mode from .config import * +from .conversion import * from .convert import * # isort: off diff --git a/modelopt/torch/peft/lora/tp_layer.py b/modelopt/torch/peft/lora/tp_layer.py index bd80d6b5..8f466cfa 100644 --- a/modelopt/torch/peft/lora/tp_layer.py +++ b/modelopt/torch/peft/lora/tp_layer.py @@ -38,10 +38,8 @@ def _get_init_methods(self) -> tuple[Callable, Callable]: Returns: Tuple of (lora_a_init, lora_b_init) initialization functions """ - # LoRA A uses Kaiming uniform initialization - lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5)) - # LoRA B is initialized to zero for stable training start - lora_b_init = lambda weight: init.zeros_(weight) + lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5)) # noqa: E731 # LoRA A: Kaiming uniform + lora_b_init = lambda weight: init.zeros_(weight) # noqa: E731 # LoRA B: zeros return lora_a_init, lora_b_init def _register_adapter_with_device( @@ -81,7 +79,7 @@ def _register_adapter_with_device( @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"}) -class _MegatronColumnParallelLinear(_MegatronParallelLoRABase): +class _LoRAMegatronColumnParallelLinear(_MegatronParallelLoRABase): """LoRA implementation for Megatron ColumnParallelLinear layers. 
This implementation creates column-parallel LoRA adapters that match @@ -124,7 +122,7 @@ def update_layer_lora( @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"}) -class _MegatronRowParallelLinear(_MegatronParallelLoRABase): +class _LoRAMegatronRowParallelLinear(_MegatronParallelLoRABase): """LoRA implementation for Megatron RowParallelLinear layers. This implementation creates row-parallel LoRA adapters that match @@ -170,8 +168,42 @@ def update_layer_lora( if QUANT_MODULES_AVAILABLE: # Register the same LoRA implementations for quantized modules LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})( - _MegatronColumnParallelLinear + _LoRAMegatronColumnParallelLinear ) LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})( - _MegatronRowParallelLinear + _LoRAMegatronRowParallelLinear ) + + from modelopt.torch.quantization.nn import QuantModuleRegistry + + class _QuantLoRAMegatronColumnParallelLinear( + _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear + ): + """Quantized LoRA ColumnParallelLinear that combines LoRA and quantization. + + This class ensures that the base layer functionality is quantized while + preserving LoRA adapter functionality. + """ + + def _setup(self): + QuantColumnParallelLinear._setup(self) + + class _QuantLoRAMegatronRowParallelLinear( + _LoRAMegatronRowParallelLinear, QuantRowParallelLinear + ): + """Quantized LoRA RowParallelLinear that combines LoRA and quantization. + + This class ensures that the base layer functionality is quantized while + preserving LoRA adapter functionality. + """ + + def _setup(self): + QuantRowParallelLinear._setup(self) + + # Register LoRA modules in QuantModuleRegistry so they can be quantized + QuantModuleRegistry.register( + {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"} + )(_QuantLoRAMegatronColumnParallelLinear) + QuantModuleRegistry.register( + {_LoRAMegatronRowParallelLinear: "lora_megatron_RowParallelLinear"} + )(_QuantLoRAMegatronRowParallelLinear) diff --git a/modelopt/torch/peft/plugins/__init__.py b/modelopt/torch/peft/plugins/__init__.py index d760b877..03cd81fe 100644 --- a/modelopt/torch/peft/plugins/__init__.py +++ b/modelopt/torch/peft/plugins/__init__.py @@ -15,8 +15,7 @@ """PEFT/LoRA plugins for various frameworks.""" -# Import plugins to register them -try: - from . import megatron -except ImportError: - pass # Megatron not available +from contextlib import suppress + +with suppress(ImportError): + from . import megatron as _megatron From 2b09c92c95fc351870921e0594f138cf334a9f3a Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 18 Sep 2025 21:28:30 +0000 Subject: [PATCH 07/15] Update test cases Signed-off-by: Jingyu Xin --- tests/gpu/torch/peft/test_forward_megatron.py | 81 ++++++++++++++++++- 1 file changed, 78 insertions(+), 3 deletions(-) diff --git a/tests/gpu/torch/peft/test_forward_megatron.py b/tests/gpu/torch/peft/test_forward_megatron.py index 9a7ce837..a4b96286 100644 --- a/tests/gpu/torch/peft/test_forward_megatron.py +++ b/tests/gpu/torch/peft/test_forward_megatron.py @@ -464,6 +464,7 @@ def forward_func(mod): print( f"LoRA forward pass successful! 
Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}" ) + print(model) print("Test passed!") finally: # Clean up model parallel groups @@ -471,11 +472,83 @@ def forward_func(mod): def _test_quantize_then_lora_save_restore(): - pass + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + model_ref = _gpt_model_provider(tp_size=1) + model_test = _gpt_model_provider(tp_size=1) + prompt_tokens = torch.randint( + 0, model_test.vocab_size, (2, model_test.max_sequence_length) + ).cuda() + + def forward_func(mod): + output = megatron_prefill(model_ref, prompt_tokens) + + mtq.quantize(model_ref, mtq.FP8_DEFAULT_CFG, forward_func) + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + model_ref = mtp.update_model(model_ref, lora_config) + tmp_path = "./model_ref" + save_distributed_checkpoint(tmp_path, model_ref) + save_sharded_modelopt_state([model_ref], tmp_path) + restore_sharded_modelopt_state([model_test], tmp_path) + model_test = load_distributed_checkpoint(tmp_path, model_test) + # Run forward pass + output_test = megatron_prefill(model_test, prompt_tokens) + output_ref = megatron_prefill(model_ref, prompt_tokens) + print( + f"Forward pass successful! Output shape: {output_test.shape if hasattr(output_test, 'shape') else 'N/A'}" + ) + print(model_ref) + print(f"output_test: {output_test}") + print(f"output_ref: {output_ref}") + + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() def _test_lora_then_quantize(): - pass + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + model = _gpt_model_provider(tp_size=1) + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + + def forward_func(mod): + output = megatron_prefill(model, prompt_tokens) + + model = mtp.update_model(model, lora_config) + mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_func) + lora_count = 0 + for name, module in model.named_modules(): + if hasattr(module, "_lora_adapters"): + lora_count += 1 + print(f"LoRA module found: {name}") + print(f"\nTotal LoRA modules: {lora_count}") + output_lora_quant = megatron_prefill(model, prompt_tokens) + print( + f"LoRA forward pass successful! 
Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}" + ) + print("Test passed!") + print(model) + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() def _test_lora_then_quantize_save_restore(): @@ -520,7 +593,9 @@ def main(): # _test_lora_save_and_restore() # _test_lora_add_2nd_lora() # _test_lora_save_and_restore_with2loras() - _test_quantize_then_lora() + # _test_quantize_then_lora() + # _test_quantize_then_lora_save_restore() + _test_lora_then_quantize() finally: if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() From 913e535b53f87977062dec2491c4e770ceab6ed7 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 18 Sep 2025 21:40:17 +0000 Subject: [PATCH 08/15] Clean up code Signed-off-by: Jingyu Xin --- modelopt/torch/peft/config.py | 397 +-------------------------- modelopt/torch/peft/conversion.py | 27 +- modelopt/torch/peft/convert.py | 19 +- modelopt/torch/peft/lora/__init__.py | 6 +- 4 files changed, 25 insertions(+), 424 deletions(-) diff --git a/modelopt/torch/peft/config.py b/modelopt/torch/peft/config.py index f30a20c4..838f1c4e 100644 --- a/modelopt/torch/peft/config.py +++ b/modelopt/torch/peft/config.py @@ -13,354 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal - -from pydantic import ValidationInfo, field_validator, model_validator +"""Configuration classes for PEFT methods.""" from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField -from modelopt.torch.utils.network import ConstructorLike - -BiasType = Literal["static", "dynamic"] -BiasMethod = Literal["mean", "max_min"] - - -class QuantizerAttributeConfig(ModeloptBaseConfig): - """Quantizer attribute type.""" - - enable: bool = ModeloptField( - default=True, - title="Enable quantizer.", - description="""If True, enables the quantizer. If False, by-pass the quantizer and returns the input tensor.""", - ) - - num_bits: int | tuple[int, int] = ModeloptField( - default=8, - title="An integer or a tuple of two integers specifying the number of quantization bits.", - description="""`num_bits` can be: - - #. A positive integer argument for integer quantization. `num_bits` specify - the number of bits used for integer quantization. - - #. Constant integer tuple (E,M) for floating point quantization emulating - Nvidia's FPx quantization. E is the number of exponent bits and M is the number - of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).""", - ) - - @model_validator(mode="before") - @classmethod - def validate_config(cls, values): - """Validate quantizer config.""" - - def _validate_recursive(value): - """Recursively validate config structure.""" - if value is None: - return - - if isinstance(value, list): - for item in value: - _validate_recursive(item) - elif isinstance(value, dict): - if len(value) == 1 and "enable" in value and value["enable"] is True: - raise ValueError( - "Invalid quantizer config: Cannot specify only {'enable': True}. " - "Additional parameters are required when enabling quantization." 
- ) - # Recurse into nested dicts - for v in value.values(): - _validate_recursive(v) - - _validate_recursive(values) - return values - - @model_validator(mode="after") - def validate_num_bits(self): - """Validate `num_bits`.""" - num_bits = self.num_bits - - if isinstance(num_bits, int) and num_bits < 1: - raise ValueError("num_bits must be a positive integer or a tuple of positive integers.") - - if not isinstance(num_bits, tuple): - return self - - if not all(x > 0 for x in num_bits): - raise ValueError("num_bits must be a positive integer or a tuple of positive integers.") - - block_sizes = self.block_sizes - if num_bits not in [ - (4, 3), - (5, 2), - (2, 1), - (1, 2), - (0, 3), - (3, 0), - (3, 2), - (2, 3), - ]: - raise ValueError( - "Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)." - ) - elif num_bits != (4, 3) and ( - block_sizes is None or block_sizes.get("type", None) != "dynamic" - ): - raise ValueError( - "Only blockwise dynamic quantization is supported with quantization " - "formats E{num_bis[0]}M{num_bits[1]}." - ) - return self - - axis: int | tuple[int, ...] | None = ModeloptField( - default=None, - title="None, integer or a tuple of integers specifying the axis to quantize.", - description="""This field is for static per-channel quantization. *It cannot coexist with `block_sizes`*. - You should set axis if you want a fixed shape of scale factor. - - For example, if axis is set to None, the scale factor will be a scalar (per-tensor quantization) - if the axis is set to 0, the scale factor will be a vector of shape (dim0, ) (per-channel quantization). - if the axis is set to (-2, -1), the scale factor will be a vector of shape (dim-2, dim-1) - - axis value must be in the range [-rank(input_tensor), rank(input_tensor)) - """, - ) - - fake_quant: bool = ModeloptField( - default=True, - title="Enable fake quantization.", - description="""If True, enable fake quantization.""", - ) - - unsigned: bool = ModeloptField( - default=False, - title="Enable unsigned quantization.", - description="""If True, enable unsigned quantization. Used only for integer quantization.""", - ) - - narrow_range: bool = ModeloptField( - default=False, - title="Enable narrow range quantization.", - description="""If True, enable narrow range quantization. Used only for integer quantization.""", - ) - - learn_amax: bool = ModeloptField( - default=False, - title="Enable learning amax.", - description="""``learn_amax`` is deprecated and reserved for backward compatibility.""", - ) - - @field_validator("learn_amax") - @classmethod - def validate_learn_amax(cls, v): - """Validate learn_amax.""" - assert v is not True, "learn_amax is deprecated and reserved for backward compatibility." - return v - - type: str = ModeloptField( - default="static", - title="""Specify whether the quantization is static or dynamic.""", - description="""The value is a string from ``["static", "dynamic"]``. - If ``"dynamic"``, dynamic quantization will be enabled which does not collect any statistics during - calibration.""", - pattern=r"^static$|^dynamic$", - ) - - block_sizes: dict[int | str, int | tuple[int, int] | str | dict[int, int] | None] | None = ( - ModeloptField( - default=None, - title="Optional dictionary specifying block quantization parameters.", - description="""This field is for static or dynamic block quantization. *It cannot coexist with ``axis``*. - You should set block_sizes if you want fixed number of elements to share every scale factor. 
- - The keys are the axes for block quantization and the - values are block sizes for quantization along the respective axes. Keys must be in the - range ``[-tensor.dim(), tensor.dim())``. Values, which are the block sizes for quantization must be - positive integers or ``None``. A positive block size specifies the block size for quantization along that - axis. ``None`` means that the block size will be the maximum possible size in that dimension - this is - useful for specifying certain quantization formats such per-token dynamic quantization which has the `amax` - shared along the last dimension. - - In addition, there can be special string keys ``"type"``, ``"scale_bits"`` and ``"scale_block_sizes"``. - - Key ``"type"`` should map to ``"dynamic"`` or ``"static"`` where ``"dynamic"`` - indicates dynamic block quantization and "static" - indicates static calibrated block quantization. By default, the type is ``"static"``. - - Key ``"scale_bits"`` specify the quantization bits for the per-block quantization scale factor - (i.e a double quantization scheme). - - Key ``"scale_block_sizes"`` specify the block size for double quantization. - By default per-block quantization scale is not quantized. - - For example, ``block_sizes = {-1: 32}`` will quantize the last axis of the input tensor in - blocks of size 32 with static calibration, with a total of ``numel(tensor) / 32`` scale factors. - ``block_sizes = {-1: 32, "type": "dynamic"}`` will perform dynamic block quantization. - ``block_sizes = {-1: None, "type": "dynamic"}`` can be used to - specify per-token dynamic quantization. - """, - ) - ) - - bias: dict[int | str, BiasType | BiasMethod | tuple[int, ...] | bool | int | None] | None = ( - ModeloptField( - default=None, - title="Bias configuration.", - description="""Configuration for bias handling in affine quantization. The keys are: - - "enable": Boolean to enable/disable bias handling, default is False - - "type": Specify the type of bias ["static", "dynamic"], default is "static" - - "method": Specify the method of bias calibration ["mean", "max_min"], default is "mean" - - "axis": Tuple of integers specifying axes for bias computation, default is None - - Examples: - bias = {"enable": True} - bias = {"enable": True, "type": "static", "axis": -1} - bias = {"enable": True, "type": "dynamic", "axis": (-1, -3)} - """, - ) - ) - - @staticmethod - def _get_block_quant_axes_and_sizes(block_sizes): - if block_sizes is None: - return None - return { - k: v - for k, v in block_sizes.items() - if k not in ["type", "scale_bits", "scale_block_sizes"] - } - - @field_validator("block_sizes") - @classmethod - def validate_block_sizes(cls, v, info: ValidationInfo): - """Validate block sizes.""" - if v is None: - return v - assert info.data["axis"] is None, "axis must be None when block_sizes is not None." - if v.get("type", None) == "dynamic": - assert len(cls._get_block_quant_axes_and_sizes(v)) == 1, ( - "Dynamic block quantization only supports quantization last axis." 
- ) - for _k, _v in v.items(): - if isinstance(_k, str): - assert _k in ["type", "scale_bits", "scale_block_sizes"] - else: - assert isinstance(_k, int) and (_v is None or isinstance(_v, int)) - return v - @field_validator("bias") - @classmethod - def validate_bias(cls, v): - """Validate bias.""" - if v is None: - return v - if "type" in v and v["type"] not in ["static", "dynamic"]: - raise ValueError(f"Invalid bias type: {v['type']}, expected 'static' or 'dynamic'") - - if "method" in v and v["method"] not in ["mean", "max_min"]: - raise ValueError(f"Invalid bias method: {v['method']}, expected 'mean' or 'max_min'") - - axis = [k for k in v.keys() if k not in ["type", "method"]] # noqa: SIM118 - assert len(axis) > 0, "The axis for bias computation is not specified." - for x in axis: - if not isinstance(x, int): - raise ValueError(f"Invalid axis type {type(axis)}, expected int") - - return v - - trt_high_precision_dtype: str = ModeloptField( - default="Float", - title="TRT StronglyType requires all weights and amax to be in the same dtype.", - description="""The value is a string from ``["Float", "Half", "BFloat16"]``. - The QDQs will be assigned the appropriate data type, and this variable will only be - used when the user is exporting the quantized ONNX model.""", - pattern=r"^Float$|^Half$|^BFloat16$", - ) - - calibrator: str | ConstructorLike = ModeloptField( - default="max", - title="""Specify the calibrator to use.""", - description="""The calibrator can be a string from ``["max", "histogram"]`` or a constructor - to create a calibrator which subclasses :class:`_Calibrator `. - See :meth:`standardize_constructor_args ` - for more information on how to specify the constructor.""", - ) - - @field_validator("calibrator") - @classmethod - def validate_calibrator(cls, v, info: ValidationInfo): - """Validate calibrator.""" - if isinstance(v, str): - assert v in ["max", "histogram"] - return v - - rotate: bool = ModeloptField( - default=False, - title="""If rotate the input before quantization.""", - description=""""If true, the input of the quantizer will be rotated with a hadamard matrix - given by scipy.linalg.hadamard, i.e. - ``input = input @ scipy.linalg.hadamard(input.shape[-1]) / sqrt(input.shape[-1])``. - - This can be used for ratation based PTQ methods, e.g. QuaRot or SpinQuant. - See https://arxiv.org/abs/2404.00456 for example.""", - ) - - pass_through_bwd: bool = ModeloptField( - default=False, - title="If set to true, fake quantization will be a pass through for gradient computation.", - description=""" - Gradient computation where fake quantization is pass through is called - 'Straight-Through Estimator (STE)'. STE does not require saving of the input tensor for - performing backward pass and hence consumes less memory. - - If set to False, we will use STE with zeroed outlier gradients. This setting could - yield better QAT accuracy depending on the quantization format. However, this setting - requires saving of the input tensor for computing gradients which uses more memory. - - For dynamic quantization formats like MXFP4, STE with zeroed outlier gradients - is not needed since fake quantization with dynamic amax results in minimal/no clipping. - """, - ) - - -class QuantizeAlgorithmConfig(ModeloptBaseConfig): - """Calibration algorithm config base.""" - - method: Literal[None] = ModeloptField( - None, - title="This field specifies the name of the calibration algorithm. 
If None, no calibration is performed.", - ) - - -class SVDQuantConfig(QuantizeAlgorithmConfig): - """The config for SVDQuant. - - Refer to the `SVDQuant paper `_ for more details. - """ - - method: Literal["svdquant"] = ModeloptField("svdquant") - - lowrank: int | None = ModeloptField( - default=32, - title="Low-rank dimension for the SVD LoRA", - description=( - "Specifies the rank of the LoRA used in the SVDQuant method, " - "which captures outliers from the original weights." - ), - ) - - -# QuantizeQuantCfgType = dict[ -# str | Callable, -# QuantizerAttributeConfig -# | list[QuantizerAttributeConfig] -# | dict[str | Callable, QuantizerAttributeConfig | list[QuantizerAttributeConfig]], -# ] - -# _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None - -# QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None - - -# TODO Jingyu Xin class PEFTConfig(ModeloptBaseConfig): """Default configuration for ``peft`` mode.""" @@ -385,55 +42,3 @@ class PEFTConfig(ModeloptBaseConfig): class ExportPEFTConfig(ModeloptBaseConfig): """An empty config.""" - - -class CompressConfig(ModeloptBaseConfig): - """Default configuration for ``compress`` mode.""" - - compress: dict[str, bool] = ModeloptField( - default={"*": True}, - title="""Enable weight compression for the given pattern. Default is False for all weights. - Call `compress` function to compress the model weights.""", - ) - - quant_gemm: bool = ModeloptField( - default=True, - title="Enable quantized GEMM.", - description="If True, quantized GEMM compute will be enabled. Otherwise, we only do weight-only quantization.", - ) - - -CompressCfgType = dict[str, bool] | None | CompressConfig - - -class _QuantizeExportConfig(ModeloptBaseConfig): - """An empty config.""" - - -def need_calibration(config): - """Check if calibration is needed for the given config.""" - if config["algorithm"] is not None and config["algorithm"] != "max": - return True - - def _not_dynamic(cfg): - return ( - cfg.get("enable", True) - and cfg.get("type", "") != "dynamic" - and cfg.get("*", {}).get("enable", True) - ) - - for name, cfg in config.get("quant_cfg", {}).items(): - if "weight_quantizer" in name: - # We don't calibrate weight quantizer - continue - # quantization like W4A8 has a list of weight quantizers - if isinstance(cfg, list): - for _config in cfg: - if _not_dynamic(_config): - print(f"{cfg}: True") - return True - elif _not_dynamic(cfg): - print(f"{cfg}: True") - return True - - return False diff --git a/modelopt/torch/peft/conversion.py b/modelopt/torch/peft/conversion.py index 91e930ba..75bf4740 100644 --- a/modelopt/torch/peft/conversion.py +++ b/modelopt/torch/peft/conversion.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Quantization conversion/restore utilities.""" +"""PEFT conversion and restore utilities for LoRA modules.""" from typing import Any @@ -23,7 +23,7 @@ from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict from modelopt.torch.utils import get_unwrapped_name -from .config import PEFTConfig, _QuantizeExportConfig +from .config import PEFTConfig from .lora.layer import LoRAModule, LoRAModuleRegistry __all__ = [ @@ -33,13 +33,12 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> ConvertReturnType: - """Convert the model to a quantized one as per `config`.""" + """Convert the model to a peft one as per `config`.""" # initialize the true module if necessary model = model.init_modellike() if isinstance(model, ModelLikeModule) else model # TODO: Replace to LoRA module replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config) - # set_quantizer_by_cfg(model, config.get("quant_cfg", {})) metadata = {} # Should return adapaters, active_adapters @@ -57,6 +56,7 @@ def restore_peft_model( def restore_peft_state(model: ModelLikeModule, metadata: MetadataDict): """Restore PEFT state from metadata or extra_state. + For backward compatibility, we check metadata first. For distributed checkpoints (NeMo-MCore), the state will be in extra_state of each LoRAModule and will be restored automatically via set_extra_state() during load_state_dict(). @@ -115,13 +115,11 @@ def replace_lora_module( def export_peft_model(model: nn.Module, config): - """Export the quantized model to a quantized model.""" - raise NotImplementedError("Exporting a quantized model is not supported yet.") + raise NotImplementedError("Exporting a peft model is not supported yet.") def restore_export_peft_model(model: nn.Module, config, metadata: MetadataDict): - """Restores the quantized model from the given state dict.""" - raise NotImplementedError("Restoring a quantized & exported model is not supported yet.") + raise NotImplementedError("Restoring a peft & exported model is not supported yet.") def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegistry): @@ -133,20 +131,9 @@ def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegi _replace_lora_module(getattr(model, name), version=version, registry=registry) -def export_quantized_model(model: nn.Module, config: _QuantizeExportConfig) -> ConvertReturnType: - """Export the quantized model to a quantized model.""" - raise NotImplementedError("Exporting a quantized model is not supported yet.") - - -def restore_export_quantized_model( - model: nn.Module, config: _QuantizeExportConfig, metadata: MetadataDict -) -> nn.Module: - """Restores the quantized model from the given state dict.""" - raise NotImplementedError("Restoring a quantized & exported model is not supported yet.") - - def update_peft_metadata_in_model(model: nn.Module) -> None: """Update the PEFT metadata in the model's ModeloptStateManager. + This function should be called after manually modifying LoRA adapters to ensure the metadata stored in the ModeloptStateManager reflects the current state. diff --git a/modelopt/torch/peft/convert.py b/modelopt/torch/peft/convert.py index 43850234..e24db097 100644 --- a/modelopt/torch/peft/convert.py +++ b/modelopt/torch/peft/convert.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""User-facing quantization API.""" +"""User-facing PEFT API for LoRA module conversion and adapter management.""" import fnmatch from typing import Any @@ -33,12 +33,15 @@ def update_model( config: dict[str, Any | PEFTConfig], ): """Update model with PEFT/LoRA adapters. + This function handles both initial PEFT conversion and adding additional adapters: - First call: Converts modules to LoRAModules and adds the first adapter - Subsequent calls: Adds new adapters to existing LoRAModules + Args: model: The model to update config: PEFT configuration containing adapter settings + Returns: The updated model with LoRA adapters """ @@ -50,6 +53,15 @@ def update_model( def add_adapter(model, config): + """Add a new LoRA adapter to the model. + + Args: + model: Model with LoRA modules to add adapters to + config: Configuration dict containing adapter_cfg and adapter_name + + Returns: + The model with the new adapter added + """ adapter_cfg = config["adapter_cfg"] adapter_name = config["adapter_name"] @@ -112,7 +124,4 @@ def is_peft_model(model: nn.Module) -> bool: Returns: True if the model contains LoRA modules, False otherwise """ - for _, module in model.named_modules(): - if isinstance(module, LoRAModule): - return True - return False + return any(isinstance(module, LoRAModule) for _, module in model.named_modules()) diff --git a/modelopt/torch/peft/lora/__init__.py b/modelopt/torch/peft/lora/__init__.py index 2523392a..3fed2846 100644 --- a/modelopt/torch/peft/lora/__init__.py +++ b/modelopt/torch/peft/lora/__init__.py @@ -1,3 +1,3 @@ -from . import layer -from . import tp_layer -# from . import linear_layer \ No newline at end of file +"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning.""" + +from . import layer, tp_layer From 7326e3859e0e2e8638d12ae9bb49a7c5c7e02f9d Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 18 Sep 2025 21:41:16 +0000 Subject: [PATCH 09/15] Clean up code, more Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index e7aaa40b..7c2f84b8 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -176,7 +176,7 @@ def replace_quant_module(model: nn.Module, version=None, registry=QuantModuleReg """Recursively replace the module with quantized module.""" from .plugins.custom import ( register_custom_model_plugins_on_the_fly, - register_custom_post_conversion_plugins, ## not needed for lora megatron + register_custom_post_conversion_plugins, ) assert not is_quantized(model), "Model must not be quantized!" 
From 9b006f939c5cc7fd703822cb63420f156cb806be Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 18 Sep 2025 21:42:14 +0000 Subject: [PATCH 10/15] more clean up Signed-off-by: Jingyu Xin --- run_tp_test.sh | 35 ------- test.py | 255 ------------------------------------------------- test_single.py | 178 ---------------------------------- 3 files changed, 468 deletions(-) delete mode 100644 run_tp_test.sh delete mode 100644 test.py delete mode 100644 test_single.py diff --git a/run_tp_test.sh b/run_tp_test.sh deleted file mode 100644 index b38c55d5..00000000 --- a/run_tp_test.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# Script to run the test with tensor parallelism - -# Set the number of GPUs for tensor parallelism -NUM_GPUS=2 - -echo "Running Megatron model with Tensor Parallelism (TP=$NUM_GPUS)" -echo "This will use $NUM_GPUS GPUs" - -# Check if torchrun is available -if command -v torchrun &> /dev/null; then - echo "Using torchrun to launch the distributed job..." - torchrun --nproc_per_node=$NUM_GPUS test.py -else - echo "torchrun not found, using manual distributed launch..." - - # Set environment variables - export MASTER_ADDR=localhost - export MASTER_PORT=6001 - export WORLD_SIZE=$NUM_GPUS - - # Launch processes - for ((rank=0; rank<$NUM_GPUS; rank++)); do - echo "Launching rank $rank..." - RANK=$rank python test.py & - pids[$rank]=$! - done - - # Wait for all processes to complete - for pid in ${pids[*]}; do - wait $pid - done -fi - -echo "Test completed!" diff --git a/test.py b/test.py deleted file mode 100644 index 538f84fa..00000000 --- a/test.py +++ /dev/null @@ -1,255 +0,0 @@ -"""Megatron Tensor Parallel Model Test Script - -This script demonstrates: -1. Creating a Megatron model with tensor parallelism (TP=2) -2. Applying LoRA adapters to tensor parallel layers -3. Testing the model with proper distributed initialization - -To run with tensor parallelism: - torchrun --nproc_per_node=2 test.py - or - bash run_tp_test.sh - -The model uses ColumnParallelLinear and RowParallelLinear layers which -automatically handle weight sharding across GPUs when TP > 1. -""" - -import os - -import torch -import torch.nn.init as init -from megatron.core import parallel_state, tensor_parallel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig - -import modelopt.torch.peft as mtp - - -class DummyMegatronModel(MegatronModule): - """A simple dummy Megatron model with parallel linear layers for testing. - Uses larger dimensions to better demonstrate tensor parallelism. 
- """ - - def __init__(self, config: TransformerConfig): - super().__init__(config) - - # Larger dimensions for better tensor parallel demonstration - hidden_size = 1024 # Divisible by 2 for TP=2 - intermediate_size = 4096 # 4x hidden size, typical for transformers - - # Column parallel linear layer (splits output dimension) - self.linear_0 = tensor_parallel.ColumnParallelLinear( - input_size=hidden_size, - output_size=intermediate_size, - config=config, - init_method=init.xavier_normal_, - bias=False, - gather_output=False, - ) - self.linear_1 = tensor_parallel.RowParallelLinear( - input_size=intermediate_size, - output_size=hidden_size, - config=config, - init_method=init.xavier_normal_, - bias=False, - input_is_parallel=True, - skip_bias_add=True, - ) - # Row parallel linear layer (splits input dimension) - self.lm_head_0 = tensor_parallel.ColumnParallelLinear( - input_size=hidden_size, - output_size=intermediate_size, - config=config, - init_method=init.xavier_normal_, - bias=False, - gather_output=False, - ) - self.lm_head_1 = tensor_parallel.RowParallelLinear( - input_size=intermediate_size, - output_size=hidden_size, - config=config, - init_method=init.xavier_normal_, - bias=False, - input_is_parallel=True, - skip_bias_add=True, - ) - - def forward(self, input): - x = self.linear_0(input)[0] - x = self.linear_1(x)[0] - x = self.lm_head_0(x)[0] - x = self.lm_head_1(x)[0] - return x - - -def initialize_distributed(rank=0, world_size=1): - """Initialize torch distributed for parallel training.""" - if torch.distributed.is_initialized(): - return - - print(f"Initializing torch.distributed with rank: {rank}, world_size: {world_size}") - torch.cuda.set_device(rank) - - init_method = "tcp://" - master_ip = os.getenv("MASTER_ADDR", "localhost") - master_port = os.getenv("MASTER_PORT", "6001") - init_method += master_ip + ":" + master_port - - torch.distributed.init_process_group( - backend="nccl", world_size=world_size, rank=rank, init_method=init_method - ) - - -def initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - virtual_pipeline_model_parallel_size=None, - pipeline_model_parallel_split_rank=None, -): - """Initialize Megatron's model parallel groups.""" - # Destroy existing model parallel if any - parallel_state.destroy_model_parallel() - - # Initialize distributed if not already done - if not torch.distributed.is_initialized(): - initialize_distributed() - - # Initialize model parallel groups - parallel_state.initialize_model_parallel( - tensor_model_parallel_size, - pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size, - pipeline_model_parallel_split_rank, - ) - - -def create_dummy_megatron_model(tensor_model_parallel_size=2): - """Create and return a dummy Megatron model with tensor parallelism. 
- - Args: - tensor_model_parallel_size: Size of tensor model parallelism (default: 2) - - Returns: - DummyMegatronModel: The initialized model on CUDA - """ - # Get rank from environment or default to 0 - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", str(tensor_model_parallel_size))) - - # Initialize distributed and model parallel - initialize_distributed(rank=rank, world_size=world_size) - initialize_model_parallel( - tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=1 - ) - - # Set random seed for reproducibility - model_parallel_cuda_manual_seed(123) - - # Configure the transformer with larger dimensions - transformer_config = { - "num_layers": 4, - "hidden_size": 1024, # Must match model dimensions - "num_attention_heads": 16, - "use_cpu_initialization": True, - "sequence_parallel": False, # Set to True for sequence parallelism - } - config = TransformerConfig(**transformer_config) - - # Create and return the model - model = DummyMegatronModel(config=config) - - if torch.cuda.is_available(): - model = model.cuda() - - print(f"Model created on rank {rank} with TP size {tensor_model_parallel_size}") - - return model - - -def cleanup(): - """Clean up distributed and model parallel groups.""" - parallel_state.destroy_model_parallel() - if torch.distributed.is_initialized(): - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - -if __name__ == "__main__": - """ - To run with tensor parallelism size 2, use: - torchrun --nproc_per_node=2 test.py - - Or manually with: - RANK=0 WORLD_SIZE=2 MASTER_ADDR=localhost MASTER_PORT=6001 python test.py & - RANK=1 WORLD_SIZE=2 MASTER_ADDR=localhost MASTER_PORT=6001 python test.py - """ - try: - # Create the model with TP=2 - tensor_parallel_size = 2 - model = create_dummy_megatron_model(tensor_model_parallel_size=tensor_parallel_size) - - # Get rank for printing - rank = int(os.environ.get("RANK", "0")) - - if rank == 0: - print(f"\nCreated dummy Megatron model with TP={tensor_parallel_size}") - print("Model structure:") - for name, module in model.named_modules(): - if hasattr(module, "__class__"): - print(f" {name}: {module.__class__.__name__}") - - # Test forward pass - if torch.cuda.is_available(): - batch_size = 2 - seq_length = 512 - hidden_size = 1024 # Must match model hidden size - - # Create input tensor - x = torch.randn(batch_size, seq_length, hidden_size).cuda() - - # Synchronize before forward pass - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - output = model(x) - - if rank == 0: - print("\nForward pass successful!") - print(f"Input shape: {x.shape}") - print(f"Output shape: {output.shape}") - - # Test LoRA with tensor parallel model - lora_config = { - "adapter_type": "lora", - "adapter_name": "default", - "adapter_cfg": {"*linear*": {"rank": 64}, "*lm_head*": {"rank": 128}}, - } - - if rank == 0: - print("\nApplying LoRA configuration...") - - model = mtp.update_model(model, lora_config) - - # Test forward pass with LoRA - if torch.cuda.is_available(): - output_lora = model(x) - if rank == 0: - print("LoRA forward pass successful!") - print(model) - print(model.linear_0.lora_a_default) - print(f"Output shape with LoRA: {output_lora.shape}") - - # Optional: Test quantization (commented out) - # if rank == 0: - # print(f"\nApplying quantization...") - # mtq.quantize(model, mtq.INT8_DEFAULT_CFG) - - except Exception as e: - print(f"Error on rank {os.environ.get('RANK', '0')}: {e}") - raise - finally: - # Clean up - 
cleanup() - if int(os.environ.get("RANK", "0")) == 0: - print("\nCleaned up distributed environment") diff --git a/test_single.py b/test_single.py deleted file mode 100644 index 729eab87..00000000 --- a/test_single.py +++ /dev/null @@ -1,178 +0,0 @@ -# dummy_megatron_model.py -import os - -import torch -import torch.nn.init as init -from megatron.core import parallel_state, tensor_parallel -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig - - -class DummyMegatronModel(MegatronModule): - """A simple dummy Megatron model with parallel linear layers for testing.""" - - def __init__(self, config: TransformerConfig): - super().__init__(config) - - # Column parallel linear layer (splits output dimension) - self.linear_0 = tensor_parallel.ColumnParallelLinear( - input_size=10, - output_size=10, - config=config, - init_method=init.xavier_normal_, - bias=False, - gather_output=False, - ) - self.linear_1 = tensor_parallel.RowParallelLinear( - input_size=10, - output_size=10, - config=config, - init_method=init.xavier_normal_, - bias=False, - input_is_parallel=True, - skip_bias_add=True, - ) - # Row parallel linear layer (splits input dimension) - self.lm_head_0 = tensor_parallel.ColumnParallelLinear( - input_size=10, - output_size=10, - config=config, - init_method=init.xavier_normal_, - bias=False, - gather_output=False, - ) - self.lm_head_1 = tensor_parallel.RowParallelLinear( - input_size=10, - output_size=10, - config=config, - init_method=init.xavier_normal_, - bias=False, - input_is_parallel=True, - skip_bias_add=True, - ) - - def forward(self, input): - x = self.linear_0(input)[0] - x = self.linear_1(x)[0] - x = self.lm_head_0(x)[0] - x = self.lm_head_1(x)[0] - return x - - -def initialize_distributed(rank=0, world_size=1): - """Initialize torch distributed for parallel training.""" - if torch.distributed.is_initialized(): - return - - print(f"Initializing torch.distributed with rank: {rank}, world_size: {world_size}") - torch.cuda.set_device(rank) - - init_method = "tcp://" - master_ip = os.getenv("MASTER_ADDR", "localhost") - master_port = os.getenv("MASTER_PORT", "6001") - init_method += master_ip + ":" + master_port - - torch.distributed.init_process_group( - backend="nccl", world_size=world_size, rank=rank, init_method=init_method - ) - - -def initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - virtual_pipeline_model_parallel_size=None, - pipeline_model_parallel_split_rank=None, -): - """Initialize Megatron's model parallel groups.""" - # Destroy existing model parallel if any - parallel_state.destroy_model_parallel() - - # Initialize distributed if not already done - if not torch.distributed.is_initialized(): - initialize_distributed() - - # Initialize model parallel groups - parallel_state.initialize_model_parallel( - tensor_model_parallel_size, - pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size, - pipeline_model_parallel_split_rank, - ) - - -def create_dummy_megatron_model(): - """Create and return a dummy Megatron model. 
- - Returns: - DummyMegatronModel: The initialized model on CUDA - """ - # Initialize model parallel (single GPU by default) - initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) - - # Set random seed for reproducibility - model_parallel_cuda_manual_seed(123) - - # Configure the transformer - transformer_config = { - "num_layers": 2, - "hidden_size": 12, - "num_attention_heads": 4, - "use_cpu_initialization": True, - } - config = TransformerConfig(**transformer_config) - - # Create and return the model - model = DummyMegatronModel(config=config) - - if torch.cuda.is_available(): - model = model.cuda() - - return model - - -def cleanup(): - """Clean up distributed and model parallel groups.""" - parallel_state.destroy_model_parallel() - if torch.distributed.is_initialized(): - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - -if __name__ == "__main__": - # Example usage - try: - # Create the model - model = create_dummy_megatron_model() - print(f"Created dummy Megatron model: {model}") - # Test forward pass - if torch.cuda.is_available(): - x = torch.randn(2, 4, 10).cuda() - output = model(x) - print(f"Input shape: {x.shape}") - print(f"Output shape: {output.shape}") - - # # Print model structure - # print("\nModel structure:") - # for name, module in model.named_modules(): - # print(f" {name}: {module.__class__.__name__}") - lora_config = { - "adapter_type": "lora", - "adapter_name": "default", - "adapter_cfg": { - "*transformer*qkv*": {"rank": 64}, - "*ffn*": {"rank": 128}, - "*linear*": {"rank": 128}, - }, - } - # model = mtp.update_model(model, lora_config) - # if torch.cuda.is_available(): - # x = torch.randn(2, 4, 10).cuda() - # output = model(x) - # print(f"Input shape: {x.shape}") - # print(f"Output shape: {output.shape}") - # mtq.quantize(model, mtq.MXFP4_DEFAULT_CFG) - finally: - # Clean up - cleanup() - print("\nCleaned up distributed environment") From ca9698d1f18acc876ce76a60a5a66b51ffea2fbb Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 19 Sep 2025 18:36:59 +0000 Subject: [PATCH 11/15] Update more, config + conversation Signed-off-by: Jingyu Xin --- modelopt/torch/peft/config.py | 110 ++++++++++++++++++++++++++- modelopt/torch/peft/conversion.py | 59 ++++++++++++++ modelopt/torch/peft/convert.py | 73 +++--------------- modelopt/torch/peft/lora/layer.py | 31 ++++---- modelopt/torch/peft/lora/tp_layer.py | 57 +++++++------- 5 files changed, 219 insertions(+), 111 deletions(-) diff --git a/modelopt/torch/peft/config.py b/modelopt/torch/peft/config.py index 838f1c4e..877bf3f6 100644 --- a/modelopt/torch/peft/config.py +++ b/modelopt/torch/peft/config.py @@ -15,30 +15,132 @@ """Configuration classes for PEFT methods.""" +import math +from collections.abc import Callable +from collections.abc import Callable as CallableType + +import torch.nn.init as init +from pydantic import field_validator, model_validator + from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +__all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"] + + +class PEFTAttributeConfig(ModeloptBaseConfig): + """Configuration for PEFT adapter attributes.""" + + enable: bool = ModeloptField( + default=True, + title="Enable adapter", + description="If True, enables the adapter. If False, by-passes the adapter.", + ) + + rank: int = ModeloptField( + default=64, + title="LoRA rank", + description=( + "The rank (dimension) of the LoRA matrices. " + "Higher rank allows more expressiveness but uses more memory." 
+ ), + ) + + scale: float = ModeloptField( + default=1.0, + title="LoRA scaling factor", + description="Scaling factor for the LoRA output. Controls the magnitude of the adaptation.", + ) + + lora_a_init: Callable[[object], None] | None = ModeloptField( + default=lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5)), + title="LoRA A matrix initializer", + description="Custom initialization function for LoRA A matrix. Default to Kaiming uniform initialization.", + ) + + lora_b_init: Callable[[object], None] | None = ModeloptField( + default=lambda weight: init.zeros_(weight), + title="LoRA B matrix initializer", + description="Custom initialization function for LoRA B matrix. Default to zero initialization.", + ) + + @field_validator("rank") + @classmethod + def validate_rank(cls, v): + """Validate rank is positive.""" + if v < 1: + raise ValueError("rank must be a positive integer") + return v + + @field_validator("scale") + @classmethod + def validate_scale(cls, v): + """Validate scale is positive.""" + if v <= 0: + raise ValueError("scale must be a positive number") + return v + + @model_validator(mode="after") + def validate_init_functions(self): + """Validate initialization functions are callable.""" + if self.lora_a_init is not None and not callable(self.lora_a_init): + raise ValueError("lora_a_init must be callable") + if self.lora_b_init is not None and not callable(self.lora_b_init): + raise ValueError("lora_b_init must be callable") + return self + + +# Type alias for adapter configuration +PEFTAdapterCfgType = dict[str | CallableType, PEFTAttributeConfig | dict] + class PEFTConfig(ModeloptBaseConfig): """Default configuration for ``peft`` mode.""" adapter_name: str = ModeloptField( default="default", - title="Placeholder", + title="Adapter name", + description="Name of the adapter to create or update.", validate_default=True, ) - adapter_cfg: dict = ModeloptField( + adapter_cfg: PEFTAdapterCfgType = ModeloptField( default={"default": {"rank": 128}}, - title="Placeholder", + title="Adapter configuration", + description="Configuration for adapters. Maps module patterns to PEFTAttributeConfig or dict.", validate_default=True, ) adapter_type: str = ModeloptField( default="lora", - title="Placeholder", + title="Adapter type", + description="Type of PEFT adapter to use. Currently only 'lora' is supported.", validate_default=True, ) + @field_validator("adapter_type") + @classmethod + def validate_adapter_type(cls, v): + """Validate adapter type.""" + if v not in ["lora"]: + raise ValueError(f"Unsupported adapter type: {v}. 
Only 'lora' is currently supported.") + return v + + @field_validator("adapter_cfg") + @classmethod + def validate_adapter_cfg(cls, v): + """Validate and convert adapter configurations.""" + validated_cfg = {} + for key, value in v.items(): + if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig): + # Convert dict to PEFTAttributeConfig to trigger validation + try: + validated_cfg[key] = PEFTAttributeConfig(**value) + except Exception as e: + raise ValueError(f"Invalid adapter configuration for '{key}': {e}") + else: + validated_cfg[key] = value + return validated_cfg + class ExportPEFTConfig(ModeloptBaseConfig): """An empty config.""" diff --git a/modelopt/torch/peft/conversion.py b/modelopt/torch/peft/conversion.py index 75bf4740..7a4fba1a 100644 --- a/modelopt/torch/peft/conversion.py +++ b/modelopt/torch/peft/conversion.py @@ -15,6 +15,7 @@ """PEFT conversion and restore utilities for LoRA modules.""" +import fnmatch from typing import Any import torch.nn as nn @@ -41,6 +42,7 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config) metadata = {} + add_adapter(model, config) # Should return adapaters, active_adapters update_peft_metadata(model, config, metadata) @@ -157,3 +159,60 @@ def update_peft_metadata_in_model(model: nn.Module) -> None: # Update the metadata with current PEFT state if manager._state and manager._last_metadata is not None: manager._last_metadata["peft_state"] = peft_state(model) + + +def add_adapter(model, config: PEFTConfig): + """Add a new LoRA adapter to the model. + + Args: + model: Model with LoRA modules to add adapters to + config: PEFTConfig instance containing adapter_cfg and adapter_name + + Returns: + The model with the new adapter added + """ + adapter_cfg = config.adapter_cfg + adapter_name = config.adapter_name + + for name, module in model.named_modules(): + if isinstance(module, LoRAModule): + for wildcard_or_filter_func, adapter_setting in adapter_cfg.items(): + if isinstance(wildcard_or_filter_func, str): + if not fnmatch.fnmatch(name, wildcard_or_filter_func): + continue + elif callable(wildcard_or_filter_func): + if not wildcard_or_filter_func(name): + continue + else: + raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") + if adapter_setting.enable: # type: ignore[union-attr] + module.update_layer_lora( + adapter_name, + adapter_setting, + ) + + _update_peft_metadata_in_state(model) + return model + + +def _update_peft_metadata_in_state(model: nn.Module) -> None: + """Update the PEFT metadata in the ModeloptStateManager. + + This function updates the metadata to reflect the current state of LoRA adapters + after they have been added or modified. 
+ """ + if not ModeloptStateManager.is_converted(model): + return + + manager = ModeloptStateManager(model) + + current_peft_state = {} + for name, module in model.named_modules(): + if isinstance(module, LoRAModule): + from modelopt.torch.utils import get_unwrapped_name + + unwrapped_name = get_unwrapped_name(name) + current_peft_state[unwrapped_name] = module.get_peft_state() + + if manager._state and manager._last_metadata is not None: + manager._last_metadata["peft_state"] = current_peft_state diff --git a/modelopt/torch/peft/convert.py b/modelopt/torch/peft/convert.py index e24db097..c73042b6 100644 --- a/modelopt/torch/peft/convert.py +++ b/modelopt/torch/peft/convert.py @@ -15,14 +15,13 @@ """User-facing PEFT API for LoRA module conversion and adapter management.""" -import fnmatch from typing import Any import torch.nn as nn from modelopt.torch.opt import apply_mode -from modelopt.torch.opt.conversion import ModeloptStateManager from modelopt.torch.peft.config import PEFTConfig +from modelopt.torch.peft.conversion import add_adapter from .lora.layer import LoRAModule from .mode import PEFTModeRegistry @@ -30,7 +29,7 @@ def update_model( model: nn.Module, - config: dict[str, Any | PEFTConfig], + config: dict[str, Any] | PEFTConfig, ): """Update model with PEFT/LoRA adapters. @@ -40,78 +39,24 @@ def update_model( Args: model: The model to update - config: PEFT configuration containing adapter settings + config: PEFT configuration dict or PEFTConfig instance Returns: The updated model with LoRA adapters """ + # Validate config by converting to PEFTConfig if needed + # Check if model is already in PEFT mode by looking for LoRA modules if not is_peft_model(model): # First time - need to convert to PEFT mode apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry) - return add_adapter(model, config) - - -def add_adapter(model, config): - """Add a new LoRA adapter to the model. - - Args: - model: Model with LoRA modules to add adapters to - config: Configuration dict containing adapter_cfg and adapter_name - - Returns: - The model with the new adapter added - """ - adapter_cfg = config["adapter_cfg"] - adapter_name = config["adapter_name"] - - for name, module in model.named_modules(): - if isinstance(module, LoRAModule): - for wildcard_or_filter_func, adapter_setting in adapter_cfg.items(): - if isinstance(wildcard_or_filter_func, str): - if not fnmatch.fnmatch(name, wildcard_or_filter_func): - continue - elif callable(wildcard_or_filter_func): - if not wildcard_or_filter_func(name): - continue - else: - raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") - module.update_layer_lora( - adapter_name, adapter_setting["rank"], adapter_setting.get("scale", 1.0) - ) - - # Update the metadata in ModeloptStateManager after adding adapters - _update_peft_metadata_in_state(model) + else: + if not isinstance(config, PEFTConfig): + config = PEFTConfig(**config) + add_adapter(model, config) return model -def _update_peft_metadata_in_state(model: nn.Module) -> None: - """Update the PEFT metadata in the ModeloptStateManager. - - This function updates the metadata to reflect the current state of LoRA adapters - after they have been added or modified. 
- """ - # Check if model has ModeloptStateManager (has been converted with peft mode) - if not ModeloptStateManager.is_converted(model): - return - - # Get the state manager - manager = ModeloptStateManager(model) - - # Get current PEFT state from all LoRA modules - current_peft_state = {} - for name, module in model.named_modules(): - if isinstance(module, LoRAModule): - from modelopt.torch.utils import get_unwrapped_name - - unwrapped_name = get_unwrapped_name(name) - current_peft_state[unwrapped_name] = module.get_peft_state() - - # Update the metadata in the last mode state (which should be 'peft') - if manager._state and manager._last_metadata is not None: - manager._last_metadata["peft_state"] = current_peft_state - - def is_peft_model(model: nn.Module) -> bool: """Check if the model has been converted to PEFT/LoRA model. diff --git a/modelopt/torch/peft/lora/layer.py b/modelopt/torch/peft/lora/layer.py index 470e6484..a795598f 100644 --- a/modelopt/torch/peft/lora/layer.py +++ b/modelopt/torch/peft/lora/layer.py @@ -9,6 +9,8 @@ from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls +from ..config import PEFTAttributeConfig + __all__ = [ "LoRAModule", "LoRAModuleRegistry", @@ -100,7 +102,11 @@ def _register_adapter( self.activate_adapter(adapter_name) @abstractmethod - def update_layer_lora(self, adapter_name: str, rank: int = 64, scale: float = 1.0) -> None: + def update_layer_lora( + self, + adapter_name: str, + attr_config: PEFTAttributeConfig, + ) -> None: """Create and register a new LoRA adapter. This method must be implemented by subclasses to create the appropriate @@ -110,6 +116,8 @@ def update_layer_lora(self, adapter_name: str, rank: int = 64, scale: float = 1. adapter_name: Name for the new adapter rank: Rank of the LoRA decomposition (default: 64) scale: Scale factor for the LoRA output (default: 1.0) + lora_a_init: Optional initialization function for LoRA A matrix + lora_b_init: Optional initialization function for LoRA B matrix """ raise NotImplementedError("Subclasses must implement update_layer_lora") @@ -189,24 +197,17 @@ def set_from_peft_state(self, peft_state: dict[str, Any]) -> None: """ adapters_config = peft_state.get("adapters", {}) - # Clear existing adapters first self._lora_adapters.clear() self._active_adapters.clear() - # Recreate each adapter based on saved configuration for adapter_name, config in adapters_config.items(): - rank = config.get("rank") - scale = config.get("scale", 1.0) - - if rank is not None: - # Create the adapter with saved configuration - self.update_layer_lora(adapter_name, rank=rank, scale=scale) + self.update_layer_lora(adapter_name, config) - # Set activation state - if config.get("is_active", False): - self.activate_adapter(adapter_name) - else: - self.deactivate_adapter(adapter_name) + # Set activation state + if config.get("is_active", False): + self.activate_adapter(adapter_name) + else: + self.deactivate_adapter(adapter_name) def set_extra_state(self, state: dict[str, Any]) -> None: """Restore extra state for distributed checkpointing. 
@@ -281,7 +282,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
 
         # Return output in the same format as the base layer
         if other_outputs:
-            return (result,) + other_outputs
+            return (result, *other_outputs)
         else:
             return result
 
diff --git a/modelopt/torch/peft/lora/tp_layer.py b/modelopt/torch/peft/lora/tp_layer.py
index 8f466cfa..9e4c588d 100644
--- a/modelopt/torch/peft/lora/tp_layer.py
+++ b/modelopt/torch/peft/lora/tp_layer.py
@@ -7,6 +7,7 @@
 import torch.nn.init as init
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 
+from ..config import PEFTAttributeConfig
 from .layer import LoRAModule, LoRAModuleRegistry
 
 try:
@@ -32,14 +33,16 @@ class _MegatronParallelLoRABase(LoRAModule):
     LoRA implementations, reducing code duplication.
     """
 
-    def _get_init_methods(self) -> tuple[Callable, Callable]:
+    def _get_init_methods(self, lora_a_init, lora_b_init) -> tuple[Callable, Callable]:
         """Get initialization methods for LoRA A and B matrices.
 
         Returns:
             Tuple of (lora_a_init, lora_b_init) initialization functions
         """
-        lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))  # noqa: E731  # LoRA A: Kaiming uniform
-        lora_b_init = lambda weight: init.zeros_(weight)  # noqa: E731  # LoRA B: zeros
+        if lora_a_init is None:
+            lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))  # noqa: E731  # LoRA A: Kaiming uniform
+        if lora_b_init is None:
+            lora_b_init = lambda weight: init.zeros_(weight)  # noqa: E731  # LoRA B: zeros
         return lora_a_init, lora_b_init
 
     def _register_adapter_with_device(
@@ -87,7 +90,9 @@ class _LoRAMegatronColumnParallelLinear(_MegatronParallelLoRABase):
     """
 
     def update_layer_lora(
-        self, adapter_name: str, rank: int = DEFAULT_LORA_RANK, scale: float = DEFAULT_SCALE
+        self,
+        adapter_name: str,
+        attr_config: PEFTAttributeConfig,
     ) -> None:
         """Create and register a new LoRA adapter for ColumnParallelLinear.
 
@@ -95,30 +100,28 @@
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition
         """
-        lora_a_init, lora_b_init = self._get_init_methods()
-
-        # Create LoRA A: input_size -> rank (with gather for full reduction)
         lora_a = ColumnParallelLinear(
             self.input_size,
-            rank,
+            attr_config.rank,
             config=self.config,
             bias=False,
-            gather_output=True,  # Gather outputs for complete transformation
-            init_method=lora_a_init,
+            gather_output=True,
+            init_method=attr_config.lora_a_init,
             disable_grad_reduce=getattr(self.config, "sequence_parallel", False),
         )
 
-        # Create LoRA B: rank -> output_size (no gather, stays distributed)
         lora_b = ColumnParallelLinear(
-            rank,
+            attr_config.rank,
             self.output_size,
             config=self.config,
             bias=False,
             gather_output=False,  # Keep output distributed like base layer
-            init_method=lora_b_init,
+            init_method=attr_config.lora_b_init,
         )
 
-        self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank, scale)
+        self._register_adapter_with_device(
+            adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale
+        )
 
 
 @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"})
@@ -130,7 +133,9 @@ class _LoRAMegatronRowParallelLinear(_MegatronParallelLoRABase):
     """
 
     def update_layer_lora(
-        self, adapter_name: str, rank: int = DEFAULT_LORA_RANK, scale: float = DEFAULT_SCALE
+        self,
+        adapter_name: str,
+        attr_config: PEFTAttributeConfig,
     ) -> None:
         """Create and register a new LoRA adapter for RowParallelLinear.
 
@@ -138,35 +143,32 @@ def update_layer_lora( adapter_name: Name for the new adapter rank: Rank of the LoRA decomposition """ - lora_a_init, lora_b_init = self._get_init_methods() - - # Create LoRA A: input_size -> rank (row parallel, input already distributed) lora_a = RowParallelLinear( self.input_size, - rank, + attr_config.rank, config=self.config, - input_is_parallel=True, # Input is already distributed + input_is_parallel=True, skip_bias_add=True, bias=False, - init_method=lora_a_init, + init_method=attr_config.lora_a_init, ) - # Create LoRA B: rank -> output_size (column parallel with gather) lora_b = ColumnParallelLinear( - rank, + attr_config.rank, self.output_size, config=self.config, bias=False, - gather_output=True, # Gather to match base layer output - init_method=lora_b_init, + gather_output=True, + init_method=attr_config.lora_b_init, ) - self._register_adapter_with_device(adapter_name, lora_a, lora_b, rank, scale) + self._register_adapter_with_device( + adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale + ) # Register quantized versions if available if QUANT_MODULES_AVAILABLE: - # Register the same LoRA implementations for quantized modules LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})( _LoRAMegatronColumnParallelLinear ) @@ -200,7 +202,6 @@ class _QuantLoRAMegatronRowParallelLinear( def _setup(self): QuantRowParallelLinear._setup(self) - # Register LoRA modules in QuantModuleRegistry so they can be quantized QuantModuleRegistry.register( {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"} )(_QuantLoRAMegatronColumnParallelLinear) From 019efb078364d30f49ce31f5954185d8bc8f5fc8 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 19 Sep 2025 20:03:30 +0000 Subject: [PATCH 12/15] Update disable/enable logic Signed-off-by: Jingyu Xin --- modelopt/torch/peft/conversion.py | 10 ++-- modelopt/torch/peft/lora/layer.py | 90 +++++++--------------------- modelopt/torch/peft/lora/tp_layer.py | 14 +++-- 3 files changed, 37 insertions(+), 77 deletions(-) diff --git a/modelopt/torch/peft/conversion.py b/modelopt/torch/peft/conversion.py index 7a4fba1a..256cb185 100644 --- a/modelopt/torch/peft/conversion.py +++ b/modelopt/torch/peft/conversion.py @@ -185,13 +185,11 @@ def add_adapter(model, config: PEFTConfig): continue else: raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") - if adapter_setting.enable: # type: ignore[union-attr] - module.update_layer_lora( - adapter_name, - adapter_setting, - ) + module.update_layer_lora( + adapter_name, + adapter_setting, + ) - _update_peft_metadata_in_state(model) return model diff --git a/modelopt/torch/peft/lora/layer.py b/modelopt/torch/peft/lora/layer.py index a795598f..f0eadb1d 100644 --- a/modelopt/torch/peft/lora/layer.py +++ b/modelopt/torch/peft/lora/layer.py @@ -31,7 +31,6 @@ class LoRAModule(DynamicModule): def _setup(self) -> None: """Initialize LoRA-specific attributes.""" self._lora_adapters: dict[str, dict[str, Any]] = {} - self._active_adapters: set = set() @property def adapter_names(self) -> set: @@ -43,39 +42,14 @@ def active_adapters(self) -> set: """Return the set of currently active adapter names.""" return self._active_adapters.copy() - def activate_adapter(self, adapter_name: str) -> None: - """Activate a specific adapter. 
- - Args: - adapter_name: Name of the adapter to activate - - Raises: - ValueError: If adapter_name is not registered - """ - if adapter_name not in self._lora_adapters: - raise ValueError( - f"Adapter '{adapter_name}' not found. Available: {list(self._lora_adapters.keys())}" - ) - self._active_adapters.add(adapter_name) - - def deactivate_adapter(self, adapter_name: str) -> None: - """Deactivate a specific adapter. - - Args: - adapter_name: Name of the adapter to deactivate - """ - self._active_adapters.discard(adapter_name) - - def activate_all_adapters(self) -> None: - """Activate all registered adapters.""" - self._active_adapters = self.adapter_names.copy() - - def deactivate_all_adapters(self) -> None: - """Deactivate all adapters.""" - self._active_adapters.clear() - def _register_adapter( - self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float = 1.0 + self, + adapter_name: str, + lora_a: nn.Module, + lora_b: nn.Module, + rank: int, + scale: float = 1.0, + enable: bool = True, ) -> None: """Register a new LoRA adapter with explicit rank tracking. @@ -86,7 +60,6 @@ def _register_adapter( rank: Rank of the LoRA decomposition scale: Scale factor for the LoRA output """ - # Add as submodules for proper parameter registration self.add_module(f"lora_a_{adapter_name}", lora_a) self.add_module(f"lora_b_{adapter_name}", lora_b) @@ -94,13 +67,11 @@ def _register_adapter( self._lora_adapters[adapter_name] = { "lora_a": lora_a, "lora_b": lora_b, - "rank": rank, # Store rank explicitly for reliability + "rank": rank, "scale": scale, + "enable": enable, } - # Automatically activate new adapters - self.activate_adapter(adapter_name) - @abstractmethod def update_layer_lora( self, @@ -156,14 +127,11 @@ def get_peft_state(self) -> dict[str, Any]: adapters_config[adapter_name] = { "rank": rank, - "is_active": adapter_name in self._active_adapters, - "lora_a_type": type(lora_a).__name__, - "lora_b_type": type(lora_b).__name__, + "enable": adapter_modules.get("enable", True), "scale": adapter_modules.get("scale", 1.0), } modelopt_state["adapters"] = adapters_config - modelopt_state["active_adapters"] = list(self._active_adapters) return modelopt_state @@ -246,41 +214,29 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Any: Returns: Output from the base layer plus active LoRA adaptations """ - # Call the base layer's forward method output = super().forward(x, *args, **kwargs) - # Handle different output types from base layer if isinstance(output, tuple): - # If output is a tuple, assume first element is the main result result = output[0] other_outputs = output[1:] else: - # If output is a single tensor result = output other_outputs = () - # Apply active LoRA adapters - if self._active_adapters and self._lora_adapters: - for adapter_name in self._active_adapters: - if adapter_name in self._lora_adapters: - adapter = self._lora_adapters[adapter_name] - # LoRA computation: result = result + B(A(x)) - lora_a = adapter["lora_a"] - lora_b = adapter["lora_b"] - - # Handle different forward signatures - lora_a_output = lora_a(x) - if isinstance(lora_a_output, tuple): - lora_a_output = lora_a_output[0] - - lora_b_output = lora_b(lora_a_output) - if isinstance(lora_b_output, tuple): - lora_b_output = lora_b_output[0] - - scale = adapter.get("scale", 1.0) - result = result + scale * lora_b_output + for adapter_name in self._lora_adapters: + adapter = self._lora_adapters[adapter_name] + if adapter["enable"]: + lora_a = adapter["lora_a"] + lora_b = adapter["lora_b"] + 
lora_a_output = lora_a(x) + if isinstance(lora_a_output, tuple): + lora_a_output = lora_a_output[0] + lora_b_output = lora_b(lora_a_output) + if isinstance(lora_b_output, tuple): + lora_b_output = lora_b_output[0] + scale = adapter["scale"] + result = result + scale * lora_b_output - # Return output in the same format as the base layer if other_outputs: return (result, *other_outputs) else: diff --git a/modelopt/torch/peft/lora/tp_layer.py b/modelopt/torch/peft/lora/tp_layer.py index 9e4c588d..eff6054f 100644 --- a/modelopt/torch/peft/lora/tp_layer.py +++ b/modelopt/torch/peft/lora/tp_layer.py @@ -46,7 +46,13 @@ def _get_init_methods(self, lora_a_init, lora_b_init) -> tuple[Callable, Callabl return lora_a_init, lora_b_init def _register_adapter_with_device( - self, adapter_name: str, lora_a: nn.Module, lora_b: nn.Module, rank: int, scale: float + self, + adapter_name: str, + lora_a: nn.Module, + lora_b: nn.Module, + rank: int, + scale: float, + enable: bool, ) -> None: """Register LoRA adapter modules and ensure correct device placement. @@ -78,7 +84,7 @@ def _register_adapter_with_device( lora_a = lora_a.to(dtype) lora_b = lora_b.to(dtype) - super()._register_adapter(adapter_name, lora_a, lora_b, rank, scale) + super()._register_adapter(adapter_name, lora_a, lora_b, rank, scale, enable) @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"}) @@ -120,7 +126,7 @@ def update_layer_lora( ) self._register_adapter_with_device( - adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale + adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable ) @@ -163,7 +169,7 @@ def update_layer_lora( ) self._register_adapter_with_device( - adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale + adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable ) From 1d38784e8c6929e8f71b4a3e6297bba70a6dec03 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 19 Sep 2025 21:15:47 +0000 Subject: [PATCH 13/15] Update restore logic Signed-off-by: Jingyu Xin --- modelopt/torch/peft/config.py | 38 +++++++++++++++++++++++++++---- modelopt/torch/peft/lora/layer.py | 12 ++-------- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/peft/config.py b/modelopt/torch/peft/config.py index 877bf3f6..86800e72 100644 --- a/modelopt/torch/peft/config.py +++ b/modelopt/torch/peft/config.py @@ -16,8 +16,8 @@ """Configuration classes for PEFT methods.""" import math +import pickle # nosec B403 - Only checking picklability from collections.abc import Callable -from collections.abc import Callable as CallableType import torch.nn.init as init from pydantic import field_validator, model_validator @@ -27,6 +27,16 @@ __all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"] +def default_lora_a_init(weight): + """Default initialization for LoRA A matrix using Kaiming uniform.""" + return init.kaiming_uniform_(weight, a=math.sqrt(5)) + + +def default_lora_b_init(weight): + """Default initialization for LoRA B matrix using zeros.""" + return init.zeros_(weight) + + class PEFTAttributeConfig(ModeloptBaseConfig): """Configuration for PEFT adapter attributes.""" @@ -52,13 +62,13 @@ class PEFTAttributeConfig(ModeloptBaseConfig): ) lora_a_init: Callable[[object], None] | None = ModeloptField( - default=lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5)), + default=default_lora_a_init, title="LoRA A matrix initializer", description="Custom initialization function for LoRA A matrix. 
Default to Kaiming uniform initialization.", ) lora_b_init: Callable[[object], None] | None = ModeloptField( - default=lambda weight: init.zeros_(weight), + default=default_lora_b_init, title="LoRA B matrix initializer", description="Custom initialization function for LoRA B matrix. Default to zero initialization.", ) @@ -81,16 +91,34 @@ def validate_scale(cls, v): @model_validator(mode="after") def validate_init_functions(self): - """Validate initialization functions are callable.""" + """Validate initialization functions are callable and picklable.""" if self.lora_a_init is not None and not callable(self.lora_a_init): raise ValueError("lora_a_init must be callable") if self.lora_b_init is not None and not callable(self.lora_b_init): raise ValueError("lora_b_init must be callable") + if self.lora_a_init is not None: + try: + _del = pickle.dumps(self.lora_a_init) + del _del + except (pickle.PicklingError, TypeError, AttributeError) as e: + raise ValueError( + f"lora_a_init cannot be pickled: {e}. " + "Please use a module-level function instead of a lambda or nested function." + ) + if self.lora_b_init is not None: + try: + _del = pickle.dumps(self.lora_b_init) + del _del + except (pickle.PicklingError, TypeError, AttributeError) as e: + raise ValueError( + f"lora_b_init cannot be pickled: {e}. " + "Please use a module-level function instead of a lambda or nested function." + ) return self # Type alias for adapter configuration -PEFTAdapterCfgType = dict[str | CallableType, PEFTAttributeConfig | dict] +PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict] class PEFTConfig(ModeloptBaseConfig): diff --git a/modelopt/torch/peft/lora/layer.py b/modelopt/torch/peft/lora/layer.py index f0eadb1d..658017e2 100644 --- a/modelopt/torch/peft/lora/layer.py +++ b/modelopt/torch/peft/lora/layer.py @@ -165,17 +165,9 @@ def set_from_peft_state(self, peft_state: dict[str, Any]) -> None: """ adapters_config = peft_state.get("adapters", {}) - self._lora_adapters.clear() - self._active_adapters.clear() - for adapter_name, config in adapters_config.items(): - self.update_layer_lora(adapter_name, config) - - # Set activation state - if config.get("is_active", False): - self.activate_adapter(adapter_name) - else: - self.deactivate_adapter(adapter_name) + if adapter_name not in self._lora_adapters: + self.update_layer_lora(adapter_name, config) def set_extra_state(self, state: dict[str, Any]) -> None: """Restore extra state for distributed checkpointing. 
From 935b07ba340561d05c887dc02ca162bb7dabd378 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 19 Sep 2025 22:57:07 +0000 Subject: [PATCH 14/15] Update sharded axis Signed-off-by: Jingyu Xin --- modelopt/torch/peft/convert.py | 3 -- modelopt/torch/peft/lora/layer.py | 2 + modelopt/torch/peft/lora/tp_layer.py | 72 ++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/peft/convert.py b/modelopt/torch/peft/convert.py index c73042b6..b9d7789f 100644 --- a/modelopt/torch/peft/convert.py +++ b/modelopt/torch/peft/convert.py @@ -44,11 +44,8 @@ def update_model( Returns: The updated model with LoRA adapters """ - # Validate config by converting to PEFTConfig if needed - # Check if model is already in PEFT mode by looking for LoRA modules if not is_peft_model(model): - # First time - need to convert to PEFT mode apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry) else: if not isinstance(config, PEFTConfig): diff --git a/modelopt/torch/peft/lora/layer.py b/modelopt/torch/peft/lora/layer.py index 658017e2..d3f39d44 100644 --- a/modelopt/torch/peft/lora/layer.py +++ b/modelopt/torch/peft/lora/layer.py @@ -64,6 +64,8 @@ def _register_adapter( self.add_module(f"lora_b_{adapter_name}", lora_b) # Store in adapter dictionary with explicit rank + if adapter_name in self._lora_adapters: + raise ValueError(f"adapter_name: {adapter_name} is already exist..") self._lora_adapters[adapter_name] = { "lora_a": lora_a, "lora_b": lora_b, diff --git a/modelopt/torch/peft/lora/tp_layer.py b/modelopt/torch/peft/lora/tp_layer.py index eff6054f..ea694c54 100644 --- a/modelopt/torch/peft/lora/tp_layer.py +++ b/modelopt/torch/peft/lora/tp_layer.py @@ -6,6 +6,7 @@ import torch.nn as nn import torch.nn.init as init from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from ..config import PEFTAttributeConfig from .layer import LoRAModule, LoRAModuleRegistry @@ -129,6 +130,40 @@ def update_layer_lora( adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable ) + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + """Sharding along axis 0 for ColumnParallelLinear, bias not sharded. 
+ + For ColumnParallelLinear: + - lora_a weight: sharded at dim 0 + - lora_b weight: sharded at dim 0 + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + if hasattr(self, "_lora_adapters"): + lora_state_dict = {} + state_dict = self.state_dict(prefix="", keep_vars=True) + + for adapter_name in self._lora_adapters: + lora_a_key = f"lora_a_{adapter_name}.weight" + lora_b_key = f"lora_b_{adapter_name}.weight" + + if lora_a_key in state_dict: + lora_state_dict[lora_a_key] = state_dict[lora_a_key] + if lora_b_key in state_dict: + lora_state_dict[lora_b_key] = state_dict[lora_b_key] + + lora_sharding_dims = {} + for key in lora_state_dict: + lora_sharding_dims[key] = 0 + + if lora_state_dict: + lora_sharded = make_sharded_tensors_for_checkpoint( + lora_state_dict, prefix, lora_sharding_dims, sharded_offsets + ) + sharded_state_dict.update(lora_sharded) + + return sharded_state_dict + @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"}) class _LoRAMegatronRowParallelLinear(_MegatronParallelLoRABase): @@ -172,6 +207,43 @@ def update_layer_lora( adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable ) + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + """Sharding along axis 1 for RowParallelLinear, bias not sharded. + + For RowParallelLinear: + - lora_a weight: sharded at dim 1 (RowParallelLinear) + - lora_b weight: sharded at dim 0 (ColumnParallelLinear) + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + if hasattr(self, "_lora_adapters"): + lora_state_dict = {} + state_dict = self.state_dict() + + for adapter_name in self._lora_adapters: + lora_a_key = f"lora_a_{adapter_name}.weight" + lora_b_key = f"lora_b_{adapter_name}.weight" + + if lora_a_key in state_dict: + lora_state_dict[lora_a_key] = state_dict[lora_a_key] + if lora_b_key in state_dict: + lora_state_dict[lora_b_key] = state_dict[lora_b_key] + + lora_sharding_dims = {} + for key in lora_state_dict: + if "lora_a_" in key: + lora_sharding_dims[key] = 1 + elif "lora_b_" in key: + lora_sharding_dims[key] = 0 + + if lora_state_dict: + lora_sharded = make_sharded_tensors_for_checkpoint( + lora_state_dict, prefix, lora_sharding_dims, sharded_offsets + ) + sharded_state_dict.update(lora_sharded) + + return sharded_state_dict + # Register quantized versions if available if QUANT_MODULES_AVAILABLE: From d6b2e6045847238143d82f2949c0cb0fdffd684e Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 19 Sep 2025 23:28:41 +0000 Subject: [PATCH 15/15] Add disable_adapters enable_adpaters support, removed some codes Signed-off-by: Jingyu Xin --- modelopt/torch/peft/conversion.py | 23 ----- modelopt/torch/peft/convert.py | 151 ++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 23 deletions(-) diff --git a/modelopt/torch/peft/conversion.py b/modelopt/torch/peft/conversion.py index 256cb185..78b64115 100644 --- a/modelopt/torch/peft/conversion.py +++ b/modelopt/torch/peft/conversion.py @@ -191,26 +191,3 @@ def add_adapter(model, config: PEFTConfig): ) return model - - -def _update_peft_metadata_in_state(model: nn.Module) -> None: - """Update the PEFT metadata in the ModeloptStateManager. - - This function updates the metadata to reflect the current state of LoRA adapters - after they have been added or modified. 
- """ - if not ModeloptStateManager.is_converted(model): - return - - manager = ModeloptStateManager(model) - - current_peft_state = {} - for name, module in model.named_modules(): - if isinstance(module, LoRAModule): - from modelopt.torch.utils import get_unwrapped_name - - unwrapped_name = get_unwrapped_name(name) - current_peft_state[unwrapped_name] = module.get_peft_state() - - if manager._state and manager._last_metadata is not None: - manager._last_metadata["peft_state"] = current_peft_state diff --git a/modelopt/torch/peft/convert.py b/modelopt/torch/peft/convert.py index b9d7789f..710b7382 100644 --- a/modelopt/torch/peft/convert.py +++ b/modelopt/torch/peft/convert.py @@ -15,6 +15,7 @@ """User-facing PEFT API for LoRA module conversion and adapter management.""" +import fnmatch from typing import Any import torch.nn as nn @@ -26,6 +27,14 @@ from .lora.layer import LoRAModule from .mode import PEFTModeRegistry +__all__ = [ + "disable_adapters", + "enable_adapters", + "get_adapter_states", + "is_peft_model", + "update_model", +] + def update_model( model: nn.Module, @@ -67,3 +76,145 @@ def is_peft_model(model: nn.Module) -> bool: True if the model contains LoRA modules, False otherwise """ return any(isinstance(module, LoRAModule) for _, module in model.named_modules()) + + +def _set_adapter_state(model, enable_state, layer_patterns=None, adapter_patterns=None): + """Helper function to set adapter states. + + Args: + model: Model with LoRA adapters + enable_state: Boolean state to set for matching adapters + layer_patterns: Optional list of layer name patterns (wildcards or callables) + adapter_patterns: Optional list of adapter name patterns (wildcards) + """ + assert is_peft_model(model), "It's not a MO-PEFT model" + + def matches_any_pattern(name, patterns, allow_callable=True): + for pattern in patterns: + if isinstance(pattern, str): + if fnmatch.fnmatch(name, pattern): + return True + elif allow_callable and callable(pattern): + if pattern(name): + return True + else: + pattern_type = "pattern" if allow_callable else "adapter pattern" + raise TypeError(f"Unsupported {pattern_type} type: {type(pattern)}") + return False + + for module_name, module in model.named_modules(): + if isinstance(module, LoRAModule): + if layer_patterns is not None: + if not matches_any_pattern(module_name, layer_patterns, allow_callable=True): + continue + + for adapter_name, adapter_dict in module._lora_adapters.items(): + if adapter_patterns is not None: + if not matches_any_pattern( + adapter_name, adapter_patterns, allow_callable=False + ): + continue + + adapter_dict["enable"] = enable_state + + +def disable_adapters(model, layers_to_disable=None, adapters_to_disable=None): + """Disable LoRA adapters in the model. + + Args: + model: Model with LoRA adapters + layers_to_disable: Optional list of layer name patterns (wildcards or callables) + to disable adapters on. If None, disables on all layers. + adapters_to_disable: Optional list of adapter name patterns (wildcards) to disable. + If None, disables all adapters. 
+ + Examples: + # Disable all adapters + disable_adapters(model) + + # Disable adapters only on attention layers + disable_adapters(model, layers_to_disable=["*attention*"]) + + # Disable only "default" adapters + disable_adapters(model, adapters_to_disable=["*default*"]) + + # Disable "default" adapters on attention layers only + disable_adapters(model, layers_to_disable=["*attention*"], adapters_to_disable=["*default*"]) + """ + _set_adapter_state( + model, + enable_state=False, + layer_patterns=layers_to_disable, + adapter_patterns=adapters_to_disable, + ) + + +def enable_adapters(model, layers_to_enable=None, adapters_to_enable=None): + """Enable LoRA adapters in the model. + + Args: + model: Model with LoRA adapters + layers_to_enable: Optional list of layer name patterns (wildcards or callables) + to enable adapters on. If None, enables on all layers. + adapters_to_enable: Optional list of adapter name patterns (wildcards) to enable. + If None, enables all adapters. + + Examples: + # Enable all adapters + enable_adapters(model) + + # Enable adapters only on MLP layers + enable_adapters(model, layers_to_enable=["*mlp*"]) + + # Enable only "finetuned" adapters + enable_adapters(model, adapters_to_enable=["*finetuned*"]) + + # Enable "finetuned" adapters on MLP layers only + enable_adapters(model, layers_to_enable=["*mlp*"], adapters_to_enable=["*finetuned*"]) + """ + _set_adapter_state( + model, + enable_state=True, + layer_patterns=layers_to_enable, + adapter_patterns=adapters_to_enable, + ) + + +def get_adapter_states(model): + """Get the current state of all adapters in the model. + + Args: + model: Model with LoRA adapters + + Returns: + Dict mapping module names to their adapter states + + Example: + >>> states = get_adapter_states(model) + >>> print(states) + { + 'transformer.layers.0.attention': { + 'default': {'enabled': True, 'rank': 32}, + 'finetuned': {'enabled': False, 'rank': 64} + }, + 'transformer.layers.0.mlp': { + 'default': {'enabled': True, 'rank': 32} + } + } + """ + assert is_peft_model(model), "It's not a MO-PEFT model" + + adapter_states = {} + for module_name, module in model.named_modules(): + if isinstance(module, LoRAModule): + module_adapters = {} + for adapter_name, adapter_dict in module._lora_adapters.items(): + module_adapters[adapter_name] = { + "enabled": adapter_dict.get("enable", True), + "rank": adapter_dict.get("rank", "unknown"), + "scale": adapter_dict.get("scale", 1.0), + } + if module_adapters: + adapter_states[module_name] = module_adapters + + return adapter_states
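Taken together, these patches expose the user-facing PEFT API in modelopt/torch/peft/convert.py: update_model, disable_adapters, enable_adapters, and get_adapter_states. The following is a minimal usage sketch only, assuming the functions are imported from that module; build_megatron_gpt_model() and the wildcard layer patterns are illustrative placeholders and are not part of the patches.

    # Minimal sketch of the API added above; placeholders are marked in comments.
    from modelopt.torch.peft.convert import (
        disable_adapters,
        enable_adapters,
        get_adapter_states,
        update_model,
    )

    model = build_megatron_gpt_model()  # hypothetical constructor for a Megatron model

    # Convert the model and attach a LoRA adapter named "default" to every
    # module whose name matches the wildcard pattern (pattern is illustrative).
    model = update_model(
        model,
        {
            "adapter_type": "lora",
            "adapter_name": "default",
            "adapter_cfg": {"*linear_fc*": {"rank": 32, "scale": 1.0, "enable": True}},
        },
    )

    # Adapters can be toggled per layer and per adapter name after conversion.
    disable_adapters(model, layers_to_disable=["*self_attention*"])
    enable_adapters(model)

    # Inspect which adapters are attached, plus their rank, scale, and enabled state.
    for layer_name, adapters in get_adapter_states(model).items():
        print(layer_name, adapters)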