diff --git a/modelopt/torch/peft/__init__.py b/modelopt/torch/peft/__init__.py
new file mode 100644
index 00000000..4874cbeb
--- /dev/null
+++ b/modelopt/torch/peft/__init__.py
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PEFT API subpackage for torch."""
+
+from . import mode
+from .config import *
+from .conversion import *
+from .convert import *
+
+# isort: off
+# Import plugins last to avoid circular imports
+from . import plugins
diff --git a/modelopt/torch/peft/config.py b/modelopt/torch/peft/config.py
new file mode 100644
index 00000000..42d12e9f
--- /dev/null
+++ b/modelopt/torch/peft/config.py
@@ -0,0 +1,174 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration classes for PEFT methods."""
+
+import math
+import pickle  # nosec B403 - Only checking picklability
+from collections.abc import Callable
+
+import torch.nn.init as init
+from pydantic import field_validator, model_validator
+
+from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
+
+__all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"]
+
+
+def kaiming_init(weight):
+    """Default initialization for LoRA A matrix using Kaiming uniform."""
+    return init.kaiming_uniform_(weight, a=math.sqrt(5))
+
+
+def zero_init(weight):
+    """Default initialization for LoRA B matrix using zeros."""
+    return init.zeros_(weight)
+
+
+class PEFTAttributeConfig(ModeloptBaseConfig):
+    """Configuration for PEFT adapter attributes."""
+
+    enable: bool = ModeloptField(
+        default=True,
+        title="Enable adapter",
+        description="If True, enables the adapter. If False, bypasses the adapter.",
+    )
+
+    rank: int = ModeloptField(
+        default=64,
+        title="LoRA rank",
+        description=(
+            "The rank (dimension) of the LoRA matrices. "
+            "Higher rank allows more expressiveness but uses more memory."
+        ),
+    )
+
+    scale: float = ModeloptField(
+        default=1.0,
+        title="LoRA scaling factor",
+        description="Scaling factor for the LoRA output. Controls the magnitude of the adaptation.",
+    )
+
+    lora_a_init: Callable[[object], None] | None = ModeloptField(
+        default=kaiming_init,
+        title="LoRA A matrix initializer",
+        description="Custom initialization function for LoRA A matrix.
Default to Kaiming uniform initialization.", + ) + + lora_b_init: Callable[[object], None] | None = ModeloptField( + default=zero_init, + title="LoRA B matrix initializer", + description="Custom initialization function for LoRA B matrix. Default to zero initialization.", + ) + + @field_validator("rank") + @classmethod + def validate_rank(cls, v): + """Validate rank is positive.""" + if v < 1: + raise ValueError("rank must be a positive integer") + return v + + @field_validator("scale") + @classmethod + def validate_scale(cls, v): + """Validate scale is positive.""" + if v <= 0: + raise ValueError("scale must be a positive number") + return v + + @model_validator(mode="after") + def validate_init_functions(self): + """Validate initialization functions are callable and picklable.""" + if self.lora_a_init is not None and not callable(self.lora_a_init): + raise ValueError("lora_a_init must be callable") + if self.lora_b_init is not None and not callable(self.lora_b_init): + raise ValueError("lora_b_init must be callable") + if self.lora_a_init is not None: + try: + _del = pickle.dumps(self.lora_a_init) + del _del + except (pickle.PicklingError, TypeError, AttributeError) as e: + raise ValueError( + f"lora_a_init cannot be pickled: {e}. " + "Please use a module-level function instead of a lambda or nested function." + ) + if self.lora_b_init is not None: + try: + _del = pickle.dumps(self.lora_b_init) + del _del + except (pickle.PicklingError, TypeError, AttributeError) as e: + raise ValueError( + f"lora_b_init cannot be pickled: {e}. " + "Please use a module-level function instead of a lambda or nested function." + ) + return self + + +# Type alias for adapter configuration +PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict] + + +class PEFTConfig(ModeloptBaseConfig): + """Default configuration for ``peft`` mode.""" + + adapter_name: str = ModeloptField( + default="default", + title="Adapter name", + description="Name of the adapter to create or update.", + validate_default=True, + ) + + adapter_cfg: PEFTAdapterCfgType = ModeloptField( + default={"default": {"rank": 128}}, + title="Adapter configuration", + description="Configuration for adapters. Maps module patterns to PEFTAttributeConfig or dict.", + validate_default=True, + ) + + adapter_type: str = ModeloptField( + default="lora", + title="Adapter type", + description="Type of PEFT adapter to use. Currently only 'lora' is supported.", + validate_default=True, + ) + + @field_validator("adapter_type") + @classmethod + def validate_adapter_type(cls, v): + """Validate adapter type.""" + if v not in ["lora"]: + raise ValueError(f"Unsupported adapter type: {v}. 
Only 'lora' is currently supported.") + return v + + @field_validator("adapter_cfg") + @classmethod + def validate_adapter_cfg(cls, v): + """Validate and convert adapter configurations.""" + validated_cfg = {} + for key, value in v.items(): + if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig): + # Convert dict to PEFTAttributeConfig to trigger validation + try: + validated_cfg[key] = PEFTAttributeConfig(**value) + except Exception as e: + raise ValueError(f"Invalid adapter configuration for '{key}': {e}") + else: + validated_cfg[key] = value + return validated_cfg + + +class ExportPEFTConfig(ModeloptBaseConfig): + """An empty config.""" diff --git a/modelopt/torch/peft/conversion.py b/modelopt/torch/peft/conversion.py new file mode 100644 index 00000000..78b64115 --- /dev/null +++ b/modelopt/torch/peft/conversion.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PEFT conversion and restore utilities for LoRA modules.""" + +import fnmatch +from typing import Any + +import torch.nn as nn + +from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule, ModeloptStateManager +from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict +from modelopt.torch.utils import get_unwrapped_name + +from .config import PEFTConfig +from .lora.layer import LoRAModule, LoRAModuleRegistry + +__all__ = [ + "replace_lora_module", + "update_peft_metadata_in_model", +] + + +def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> ConvertReturnType: + """Convert the model to a peft one as per `config`.""" + # initialize the true module if necessary + model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + + # TODO: Replace to LoRA module + replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config) + + metadata = {} + add_adapter(model, config) + # Should return adapaters, active_adapters + update_peft_metadata(model, config, metadata) + + return model, metadata + + +def restore_peft_model( + model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict +) -> nn.Module: + convert_to_peft_model(model, config) + return restore_peft_state(model, metadata) + + +def restore_peft_state(model: ModelLikeModule, metadata: MetadataDict): + """Restore PEFT state from metadata or extra_state. + + For backward compatibility, we check metadata first. For distributed + checkpoints (NeMo-MCore), the state will be in extra_state of each LoRAModule + and will be restored automatically via set_extra_state() during load_state_dict(). 
+ + Args: + model: Model with LoRA modules to restore + metadata: Metadata dictionary that may contain peft_state + Returns: + The model with restored PEFT state + """ + if "peft_state" not in metadata: + # For distributed checkpoints (NeMo-MCore), peft_state is stored + # in each LoRAModule's extra_state and will be restored via + # set_extra_state() during load_state_dict() + return model + + # Legacy path: restore from metadata + peft_state_dict = metadata["peft_state"] + for name, module in model.named_modules(): + if isinstance(module, LoRAModule): + unwrapped_name = get_unwrapped_name(name) + if unwrapped_name in peft_state_dict: + try: + module.set_from_peft_state(peft_state_dict[unwrapped_name]) + except Exception as e: + raise ApplyModeError(f"Failed to restore PEFT state for module {name}: {e}") + + return model + + +def update_peft_metadata(model: nn.Module, config: PEFTConfig, metadata: MetadataDict) -> None: + """Update the PEFT/LoRA state in the metadata dict.""" + metadata["peft_state"] = peft_state(model) + + +def peft_state(model: nn.Module) -> dict[str, Any]: + return { + get_unwrapped_name(n): m.get_peft_state() + for n, m in model.named_modules() + if isinstance(m, LoRAModule) + } + + +def replace_lora_module( + model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry +): + """Recursively replace the module with LoRA module.""" + # Register custom plugins (e.g., for Megatron distributed checkpointing) + from .custom import register_custom_model_plugins_on_the_fly + + register_custom_model_plugins_on_the_fly(model) + + if type(model) in registry: + model = registry.convert(model) + _replace_lora_module(model, version=version, registry=registry) + + +def export_peft_model(model: nn.Module, config): + raise NotImplementedError("Exporting a peft model is not supported yet.") + + +def restore_export_peft_model(model: nn.Module, config, metadata: MetadataDict): + raise NotImplementedError("Restoring a peft & exported model is not supported yet.") + + +def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegistry): + for name, child in model.named_children(): + if type(child) in registry: + lora_module = registry.convert(child) + setattr(model, name, lora_module) + + _replace_lora_module(getattr(model, name), version=version, registry=registry) + + +def update_peft_metadata_in_model(model: nn.Module) -> None: + """Update the PEFT metadata in the model's ModeloptStateManager. + + This function should be called after manually modifying LoRA adapters to ensure + the metadata stored in the ModeloptStateManager reflects the current state. + + Args: + model: Model with LoRA modules whose metadata needs updating + Example: + >>> # After manually adding/modifying adapters + >>> for module in model.modules(): + ... if isinstance(module, LoRAModule): + ... module.update_layer_lora("custom_adapter", rank=32) + >>> # Update metadata to reflect changes + >>> update_peft_metadata_in_model(model) + """ + # Check if model has ModeloptStateManager (has been converted with peft mode) + if not ModeloptStateManager.is_converted(model): + return + + # Get the state manager + manager = ModeloptStateManager(model) + + # Update the metadata with current PEFT state + if manager._state and manager._last_metadata is not None: + manager._last_metadata["peft_state"] = peft_state(model) + + +def add_adapter(model, config: PEFTConfig): + """Add a new LoRA adapter to the model. 
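+
+    Keys of ``adapter_cfg`` may be wildcard strings (matched against module
+    names with ``fnmatch``) or callables that take the module name and return
+    a bool; every matching LoRAModule receives the corresponding adapter
+    settings.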
+ + Args: + model: Model with LoRA modules to add adapters to + config: PEFTConfig instance containing adapter_cfg and adapter_name + + Returns: + The model with the new adapter added + """ + adapter_cfg = config.adapter_cfg + adapter_name = config.adapter_name + + for name, module in model.named_modules(): + if isinstance(module, LoRAModule): + for wildcard_or_filter_func, adapter_setting in adapter_cfg.items(): + if isinstance(wildcard_or_filter_func, str): + if not fnmatch.fnmatch(name, wildcard_or_filter_func): + continue + elif callable(wildcard_or_filter_func): + if not wildcard_or_filter_func(name): + continue + else: + raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") + module.update_layer_lora( + adapter_name, + adapter_setting, + ) + + return model diff --git a/modelopt/torch/peft/convert.py b/modelopt/torch/peft/convert.py new file mode 100644 index 00000000..710b7382 --- /dev/null +++ b/modelopt/torch/peft/convert.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""User-facing PEFT API for LoRA module conversion and adapter management.""" + +import fnmatch +from typing import Any + +import torch.nn as nn + +from modelopt.torch.opt import apply_mode +from modelopt.torch.peft.config import PEFTConfig +from modelopt.torch.peft.conversion import add_adapter + +from .lora.layer import LoRAModule +from .mode import PEFTModeRegistry + +__all__ = [ + "disable_adapters", + "enable_adapters", + "get_adapter_states", + "is_peft_model", + "update_model", +] + + +def update_model( + model: nn.Module, + config: dict[str, Any] | PEFTConfig, +): + """Update model with PEFT/LoRA adapters. + + This function handles both initial PEFT conversion and adding additional adapters: + - First call: Converts modules to LoRAModules and adds the first adapter + - Subsequent calls: Adds new adapters to existing LoRAModules + + Args: + model: The model to update + config: PEFT configuration dict or PEFTConfig instance + + Returns: + The updated model with LoRA adapters + """ + # Check if model is already in PEFT mode by looking for LoRA modules + if not is_peft_model(model): + apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry) + else: + if not isinstance(config, PEFTConfig): + config = PEFTConfig(**config) + add_adapter(model, config) + return model + + +def is_peft_model(model: nn.Module) -> bool: + """Check if the model has been converted to PEFT/LoRA model. + + This function checks if any modules in the model are LoRAModule instances, + which indicates the model has already been converted to PEFT mode. 
+ + Args: + model: The model to check + + Returns: + True if the model contains LoRA modules, False otherwise + """ + return any(isinstance(module, LoRAModule) for _, module in model.named_modules()) + + +def _set_adapter_state(model, enable_state, layer_patterns=None, adapter_patterns=None): + """Helper function to set adapter states. + + Args: + model: Model with LoRA adapters + enable_state: Boolean state to set for matching adapters + layer_patterns: Optional list of layer name patterns (wildcards or callables) + adapter_patterns: Optional list of adapter name patterns (wildcards) + """ + assert is_peft_model(model), "It's not a MO-PEFT model" + + def matches_any_pattern(name, patterns, allow_callable=True): + for pattern in patterns: + if isinstance(pattern, str): + if fnmatch.fnmatch(name, pattern): + return True + elif allow_callable and callable(pattern): + if pattern(name): + return True + else: + pattern_type = "pattern" if allow_callable else "adapter pattern" + raise TypeError(f"Unsupported {pattern_type} type: {type(pattern)}") + return False + + for module_name, module in model.named_modules(): + if isinstance(module, LoRAModule): + if layer_patterns is not None: + if not matches_any_pattern(module_name, layer_patterns, allow_callable=True): + continue + + for adapter_name, adapter_dict in module._lora_adapters.items(): + if adapter_patterns is not None: + if not matches_any_pattern( + adapter_name, adapter_patterns, allow_callable=False + ): + continue + + adapter_dict["enable"] = enable_state + + +def disable_adapters(model, layers_to_disable=None, adapters_to_disable=None): + """Disable LoRA adapters in the model. + + Args: + model: Model with LoRA adapters + layers_to_disable: Optional list of layer name patterns (wildcards or callables) + to disable adapters on. If None, disables on all layers. + adapters_to_disable: Optional list of adapter name patterns (wildcards) to disable. + If None, disables all adapters. + + Examples: + # Disable all adapters + disable_adapters(model) + + # Disable adapters only on attention layers + disable_adapters(model, layers_to_disable=["*attention*"]) + + # Disable only "default" adapters + disable_adapters(model, adapters_to_disable=["*default*"]) + + # Disable "default" adapters on attention layers only + disable_adapters(model, layers_to_disable=["*attention*"], adapters_to_disable=["*default*"]) + """ + _set_adapter_state( + model, + enable_state=False, + layer_patterns=layers_to_disable, + adapter_patterns=adapters_to_disable, + ) + + +def enable_adapters(model, layers_to_enable=None, adapters_to_enable=None): + """Enable LoRA adapters in the model. + + Args: + model: Model with LoRA adapters + layers_to_enable: Optional list of layer name patterns (wildcards or callables) + to enable adapters on. If None, enables on all layers. + adapters_to_enable: Optional list of adapter name patterns (wildcards) to enable. + If None, enables all adapters. 
+ + Examples: + # Enable all adapters + enable_adapters(model) + + # Enable adapters only on MLP layers + enable_adapters(model, layers_to_enable=["*mlp*"]) + + # Enable only "finetuned" adapters + enable_adapters(model, adapters_to_enable=["*finetuned*"]) + + # Enable "finetuned" adapters on MLP layers only + enable_adapters(model, layers_to_enable=["*mlp*"], adapters_to_enable=["*finetuned*"]) + """ + _set_adapter_state( + model, + enable_state=True, + layer_patterns=layers_to_enable, + adapter_patterns=adapters_to_enable, + ) + + +def get_adapter_states(model): + """Get the current state of all adapters in the model. + + Args: + model: Model with LoRA adapters + + Returns: + Dict mapping module names to their adapter states + + Example: + >>> states = get_adapter_states(model) + >>> print(states) + { + 'transformer.layers.0.attention': { + 'default': {'enabled': True, 'rank': 32}, + 'finetuned': {'enabled': False, 'rank': 64} + }, + 'transformer.layers.0.mlp': { + 'default': {'enabled': True, 'rank': 32} + } + } + """ + assert is_peft_model(model), "It's not a MO-PEFT model" + + adapter_states = {} + for module_name, module in model.named_modules(): + if isinstance(module, LoRAModule): + module_adapters = {} + for adapter_name, adapter_dict in module._lora_adapters.items(): + module_adapters[adapter_name] = { + "enabled": adapter_dict.get("enable", True), + "rank": adapter_dict.get("rank", "unknown"), + "scale": adapter_dict.get("scale", 1.0), + } + if module_adapters: + adapter_states[module_name] = module_adapters + + return adapter_states diff --git a/modelopt/torch/peft/custom.py b/modelopt/torch/peft/custom.py new file mode 100644 index 00000000..580efcf5 --- /dev/null +++ b/modelopt/torch/peft/custom.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom PEFT/LoRA plugins registry.""" + +# Registry for custom model plugins +CUSTOM_MODEL_PLUGINS = set() + + +def register_custom_model_plugins_on_the_fly(model): + """Registers custom PEFT/LoRA plugins on the fly. + + This is called before LoRAModule replacement to allow plugins + to configure the model (e.g., for distributed checkpointing). + """ + for callback in CUSTOM_MODEL_PLUGINS: + callback(model) diff --git a/modelopt/torch/peft/lora/__init__.py b/modelopt/torch/peft/lora/__init__.py new file mode 100644 index 00000000..3fed2846 --- /dev/null +++ b/modelopt/torch/peft/lora/__init__.py @@ -0,0 +1,3 @@ +"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning.""" + +from . 
import layer, tp_layer diff --git a/modelopt/torch/peft/lora/layer.py b/modelopt/torch/peft/lora/layer.py new file mode 100644 index 00000000..d3f39d44 --- /dev/null +++ b/modelopt/torch/peft/lora/layer.py @@ -0,0 +1,240 @@ +"""LoRA (Low-Rank Adaptation) module implementation.""" + +import warnings +from abc import abstractmethod +from typing import Any + +import torch +import torch.nn as nn + +from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls + +from ..config import PEFTAttributeConfig + +__all__ = [ + "LoRAModule", + "LoRAModuleRegistry", +] + + +class LoRAModule(DynamicModule): + """Base class for LoRA (Low-Rank Adaptation) modules. + + This module wraps existing layers and adds trainable low-rank decomposition + matrices (LoRA adapters) that are added to the original layer's output. + + Attributes: + _lora_adapters: Dictionary mapping adapter names to their LoRA A and B matrices + _active_adapters: Set of currently active adapter names + """ + + def _setup(self) -> None: + """Initialize LoRA-specific attributes.""" + self._lora_adapters: dict[str, dict[str, Any]] = {} + + @property + def adapter_names(self) -> set: + """Return the set of all registered adapter names.""" + return set(self._lora_adapters.keys()) + + @property + def active_adapters(self) -> set: + """Return the set of currently active adapter names.""" + return self._active_adapters.copy() + + def _register_adapter( + self, + adapter_name: str, + lora_a: nn.Module, + lora_b: nn.Module, + rank: int, + scale: float = 1.0, + enable: bool = True, + ) -> None: + """Register a new LoRA adapter with explicit rank tracking. + + Args: + adapter_name: Name of the adapter + lora_a: LoRA A module (down-projection) + lora_b: LoRA B module (up-projection) + rank: Rank of the LoRA decomposition + scale: Scale factor for the LoRA output + """ + self.add_module(f"lora_a_{adapter_name}", lora_a) + self.add_module(f"lora_b_{adapter_name}", lora_b) + + # Store in adapter dictionary with explicit rank + if adapter_name in self._lora_adapters: + raise ValueError(f"adapter_name: {adapter_name} is already exist..") + self._lora_adapters[adapter_name] = { + "lora_a": lora_a, + "lora_b": lora_b, + "rank": rank, + "scale": scale, + "enable": enable, + } + + @abstractmethod + def update_layer_lora( + self, + adapter_name: str, + attr_config: PEFTAttributeConfig, + ) -> None: + """Create and register a new LoRA adapter. + + This method must be implemented by subclasses to create the appropriate + LoRA A and B matrices for the specific layer type. + + Args: + adapter_name: Name for the new adapter + rank: Rank of the LoRA decomposition (default: 64) + scale: Scale factor for the LoRA output (default: 1.0) + lora_a_init: Optional initialization function for LoRA A matrix + lora_b_init: Optional initialization function for LoRA B matrix + """ + raise NotImplementedError("Subclasses must implement update_layer_lora") + + def get_peft_state(self) -> dict[str, Any]: + """Get PEFT/LoRA state to be saved in checkpoint. + + This method returns the configuration and state of all LoRA adapters + without including the actual weight tensors. 
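+
+        A minimal sketch of the returned structure (adapter name and values
+        are illustrative): ``{"adapters": {"default": {"rank": 64,
+        "enable": True, "scale": 1.0}}}``.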
+ + Returns: + Dictionary containing: + - adapters: Dict mapping adapter names to their configuration + - active_adapters: List of currently active adapter names + """ + modelopt_state = {} + + # Store adapter configurations + adapters_config = {} + for adapter_name, adapter_modules in self._lora_adapters.items(): + lora_a = adapter_modules["lora_a"] + lora_b = adapter_modules["lora_b"] + + # Get explicitly stored rank for reliability + rank = adapter_modules.get("rank", None) + + # If rank is not stored (legacy case), try to infer it + if rank is None: + if hasattr(lora_a, "output_size"): + rank = lora_a.output_size + elif hasattr(lora_b, "input_size"): + rank = lora_b.input_size + elif hasattr(lora_a, "out_features"): + rank = lora_a.out_features + elif hasattr(lora_b, "in_features"): + rank = lora_b.in_features + + adapters_config[adapter_name] = { + "rank": rank, + "enable": adapter_modules.get("enable", True), + "scale": adapter_modules.get("scale", 1.0), + } + + modelopt_state["adapters"] = adapters_config + + return modelopt_state + + def get_extra_state(self) -> dict[str, Any]: + """Get extra state for distributed checkpointing. + + For distributed/sharded checkpoints (like NeMo-MCore), we store the PEFT state + as extra_state instead of in metadata. This handles cases where module names + change with different parallelism settings (TP, PP, EP). + + Returns: + Dictionary containing the PEFT/LoRA adapter state + """ + # Only return state if we have adapters + if not self._lora_adapters: + return {} + + # Get the current PEFT state + peft_state = self.get_peft_state() + + return {"modelopt_peft_state": peft_state} + + def set_from_peft_state(self, peft_state: dict[str, Any]) -> None: + """Restore LoRA adapters from saved PEFT state. + + This method recreates LoRA adapters based on their saved configuration. + Note: This only restores the adapter structure, not the weights. + + Args: + peft_state: Dictionary containing adapter configurations + """ + adapters_config = peft_state.get("adapters", {}) + + for adapter_name, config in adapters_config.items(): + if adapter_name not in self._lora_adapters: + self.update_layer_lora(adapter_name, config) + + def set_extra_state(self, state: dict[str, Any]) -> None: + """Restore extra state for distributed checkpointing. + + This method is called during load_state_dict() to restore the PEFT/LoRA state + from distributed checkpoints. It handles the adapter configuration but not + the actual weights (which are restored through the normal state_dict mechanism). + + Args: + state: Dictionary containing the extra state to restore + """ + if state is None: + return + + peft_state = state.get("modelopt_peft_state") + if peft_state is None: + return + + # Restore the PEFT state + try: + self.set_from_peft_state(peft_state) + except Exception as e: + warnings.warn( + f"Failed to restore PEFT state from extra_state: {e}. " + "This might happen if the model structure has changed." + ) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> Any: + """Forward pass with LoRA adaptation. 
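+
+        Conceptually, every enabled adapter contributes a delta of
+        ``scale * lora_b(lora_a(x))``, and the sum of these deltas is added to
+        the base layer's output.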
+ + Args: + x: Input tensor + *args: Additional positional arguments for the base layer + **kwargs: Additional keyword arguments for the base layer + + Returns: + Output from the base layer plus active LoRA adaptations + """ + output = super().forward(x, *args, **kwargs) + + if isinstance(output, tuple): + result = output[0] + other_outputs = output[1:] + else: + result = output + other_outputs = () + + for adapter_name in self._lora_adapters: + adapter = self._lora_adapters[adapter_name] + if adapter["enable"]: + lora_a = adapter["lora_a"] + lora_b = adapter["lora_b"] + lora_a_output = lora_a(x) + if isinstance(lora_a_output, tuple): + lora_a_output = lora_a_output[0] + lora_b_output = lora_b(lora_a_output) + if isinstance(lora_b_output, tuple): + lora_b_output = lora_b_output[0] + scale = adapter["scale"] + result = result + scale * lora_b_output + + if other_outputs: + return (result, *other_outputs) + else: + return result + + +LoRAModuleRegistry = _DMRegistryCls("LoRA", LoRAModule) diff --git a/modelopt/torch/peft/lora/tp_layer.py b/modelopt/torch/peft/lora/tp_layer.py new file mode 100644 index 00000000..ea694c54 --- /dev/null +++ b/modelopt/torch/peft/lora/tp_layer.py @@ -0,0 +1,288 @@ +"""Tensor Parallel LoRA implementations for Megatron layers.""" + +import math +from collections.abc import Callable + +import torch.nn as nn +import torch.nn.init as init +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint + +from ..config import PEFTAttributeConfig +from .layer import LoRAModule, LoRAModuleRegistry + +try: + from modelopt.torch.quantization.plugins.megatron import ( + _MegatronColumnParallelLinear as QuantColumnParallelLinear, + ) + from modelopt.torch.quantization.plugins.megatron import ( + _MegatronRowParallelLinear as QuantRowParallelLinear, + ) + + QUANT_MODULES_AVAILABLE = True +except ImportError: + QUANT_MODULES_AVAILABLE = False + +DEFAULT_LORA_RANK = 64 +DEFAULT_SCALE = 1.0 + + +class _MegatronParallelLoRABase(LoRAModule): + """Base class for Megatron tensor parallel LoRA implementations. + + This class provides common functionality for both ColumnParallel and RowParallel + LoRA implementations, reducing code duplication. + """ + + def _get_init_methods(self, lora_a_init, lora_b_init) -> tuple[Callable, Callable]: + """Get initialization methods for LoRA A and B matrices. + + Returns: + Tuple of (lora_a_init, lora_b_init) initialization functions + """ + if lora_a_init is None: + lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5)) # noqa: E731 # LoRA A: Kaiming uniform + if lora_b_init is None: + lora_b_init = lambda weight: init.zeros_(weight) # noqa: E731 # LoRA B: zeros + return lora_a_init, lora_b_init + + def _register_adapter_with_device( + self, + adapter_name: str, + lora_a: nn.Module, + lora_b: nn.Module, + rank: int, + scale: float, + enable: bool, + ) -> None: + """Register LoRA adapter modules and ensure correct device placement. 
+ + Args: + adapter_name: Name of the adapter + lora_a: LoRA A module (down-projection) + lora_b: LoRA B module (up-projection) + rank: Rank of the LoRA decomposition + """ + # Move LoRA modules to the same device and dtype as the parent module + # Try to get device and dtype from parent module's parameters or buffers + device = None + dtype = None + for p in self.parameters(): + device = p.device + dtype = p.dtype + break + if device is None: + for b in self.buffers(): + device = b.device + dtype = b.dtype + break + + # If we found a device and dtype, move LoRA modules to match + if device is not None: + lora_a = lora_a.to(device) + lora_b = lora_b.to(device) + if dtype is not None: + lora_a = lora_a.to(dtype) + lora_b = lora_b.to(dtype) + + super()._register_adapter(adapter_name, lora_a, lora_b, rank, scale, enable) + + +@LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"}) +class _LoRAMegatronColumnParallelLinear(_MegatronParallelLoRABase): + """LoRA implementation for Megatron ColumnParallelLinear layers. + + This implementation creates column-parallel LoRA adapters that match + the parallelization scheme of the base layer. + """ + + def update_layer_lora( + self, + adapter_name: str, + attr_config: PEFTAttributeConfig, + ) -> None: + """Create and register a new LoRA adapter for ColumnParallelLinear. + + Args: + adapter_name: Name for the new adapter + rank: Rank of the LoRA decomposition + """ + lora_a = ColumnParallelLinear( + self.input_size, + attr_config.rank, + config=self.config, + bias=False, + gather_output=True, + init_method=attr_config.lora_a_init, + disable_grad_reduce=getattr(self.config, "sequence_parallel", False), + ) + + lora_b = ColumnParallelLinear( + attr_config.rank, + self.output_size, + config=self.config, + bias=False, + gather_output=False, # Keep output distributed like base layer + init_method=attr_config.lora_a_init, + ) + + self._register_adapter_with_device( + adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable + ) + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + """Sharding along axis 0 for ColumnParallelLinear, bias not sharded. + + For ColumnParallelLinear: + - lora_a weight: sharded at dim 0 + - lora_b weight: sharded at dim 0 + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + if hasattr(self, "_lora_adapters"): + lora_state_dict = {} + state_dict = self.state_dict(prefix="", keep_vars=True) + + for adapter_name in self._lora_adapters: + lora_a_key = f"lora_a_{adapter_name}.weight" + lora_b_key = f"lora_b_{adapter_name}.weight" + + if lora_a_key in state_dict: + lora_state_dict[lora_a_key] = state_dict[lora_a_key] + if lora_b_key in state_dict: + lora_state_dict[lora_b_key] = state_dict[lora_b_key] + + lora_sharding_dims = {} + for key in lora_state_dict: + lora_sharding_dims[key] = 0 + + if lora_state_dict: + lora_sharded = make_sharded_tensors_for_checkpoint( + lora_state_dict, prefix, lora_sharding_dims, sharded_offsets + ) + sharded_state_dict.update(lora_sharded) + + return sharded_state_dict + + +@LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"}) +class _LoRAMegatronRowParallelLinear(_MegatronParallelLoRABase): + """LoRA implementation for Megatron RowParallelLinear layers. + + This implementation creates row-parallel LoRA adapters that match + the parallelization scheme of the base layer. 
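+
+    Sketch of the decomposition used below: ``lora_a`` is a RowParallelLinear
+    consuming the already-split input (weight sharded along dim 1), while
+    ``lora_b`` is a ColumnParallelLinear with ``gather_output=True`` (weight
+    sharded along dim 0), so the adapter delta is produced in full and added
+    to the base layer's output.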
+ """ + + def update_layer_lora( + self, + adapter_name: str, + attr_config: PEFTAttributeConfig, + ) -> None: + """Create and register a new LoRA adapter for RowParallelLinear. + + Args: + adapter_name: Name for the new adapter + rank: Rank of the LoRA decomposition + """ + lora_a = RowParallelLinear( + self.input_size, + attr_config.rank, + config=self.config, + input_is_parallel=True, + skip_bias_add=True, + bias=False, + init_method=attr_config.lora_a_init, + ) + + lora_b = ColumnParallelLinear( + attr_config.rank, + self.output_size, + config=self.config, + bias=False, + gather_output=True, + init_method=attr_config.lora_b_init, + ) + + self._register_adapter_with_device( + adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable + ) + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + """Sharding along axis 1 for RowParallelLinear, bias not sharded. + + For RowParallelLinear: + - lora_a weight: sharded at dim 1 (RowParallelLinear) + - lora_b weight: sharded at dim 0 (ColumnParallelLinear) + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + if hasattr(self, "_lora_adapters"): + lora_state_dict = {} + state_dict = self.state_dict() + + for adapter_name in self._lora_adapters: + lora_a_key = f"lora_a_{adapter_name}.weight" + lora_b_key = f"lora_b_{adapter_name}.weight" + + if lora_a_key in state_dict: + lora_state_dict[lora_a_key] = state_dict[lora_a_key] + if lora_b_key in state_dict: + lora_state_dict[lora_b_key] = state_dict[lora_b_key] + + lora_sharding_dims = {} + for key in lora_state_dict: + if "lora_a_" in key: + lora_sharding_dims[key] = 1 + elif "lora_b_" in key: + lora_sharding_dims[key] = 0 + + if lora_state_dict: + lora_sharded = make_sharded_tensors_for_checkpoint( + lora_state_dict, prefix, lora_sharding_dims, sharded_offsets + ) + sharded_state_dict.update(lora_sharded) + + return sharded_state_dict + + +# Register quantized versions if available +if QUANT_MODULES_AVAILABLE: + LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})( + _LoRAMegatronColumnParallelLinear + ) + LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})( + _LoRAMegatronRowParallelLinear + ) + + from modelopt.torch.quantization.nn import QuantModuleRegistry + + class _QuantLoRAMegatronColumnParallelLinear( + _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear + ): + """Quantized LoRA ColumnParallelLinear that combines LoRA and quantization. + + This class ensures that the base layer functionality is quantized while + preserving LoRA adapter functionality. + """ + + def _setup(self): + QuantColumnParallelLinear._setup(self) + + class _QuantLoRAMegatronRowParallelLinear( + _LoRAMegatronRowParallelLinear, QuantRowParallelLinear + ): + """Quantized LoRA RowParallelLinear that combines LoRA and quantization. + + This class ensures that the base layer functionality is quantized while + preserving LoRA adapter functionality. 
+ """ + + def _setup(self): + QuantRowParallelLinear._setup(self) + + QuantModuleRegistry.register( + {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"} + )(_QuantLoRAMegatronColumnParallelLinear) + QuantModuleRegistry.register( + {_LoRAMegatronRowParallelLinear: "lora_megatron_RowParallelLinear"} + )(_QuantLoRAMegatronRowParallelLinear) diff --git a/modelopt/torch/peft/mode.py b/modelopt/torch/peft/mode.py new file mode 100644 index 00000000..8aefb211 --- /dev/null +++ b/modelopt/torch/peft/mode.py @@ -0,0 +1,73 @@ +from modelopt.torch.opt.config import ModeloptBaseConfig +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ConvertReturnType, + ModeConfigList, + ModeDescriptor, + RestoreEntrypoint, + UpdateEntrypoint, + _ModeRegistryCls, +) +from .config import PEFTConfig, ExportPEFTConfig +from .conversion import convert_to_peft_model, restore_peft_model, update_peft_metadata, export_peft_model, restore_export_peft_model + +PEFTModeRegistry = _ModeRegistryCls("PEFT") + +@PEFTModeRegistry.register_mode +class PEFTModeDescriptor(ModeDescriptor): + @property + def name(self) -> str: + return "peft" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + return PEFTConfig + + @property + def export_mode(self) -> str | None: + return "export_peft" + + @property + def convert(self) -> ConvertEntrypoint: + return convert_to_peft_model + + @property + def restore(self) -> RestoreEntrypoint: + return restore_peft_model + + @property + def update_for_save(self) -> UpdateEntrypoint: + return update_peft_metadata + + @property + def update_for_new_mode(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the models state before new mode.""" + return update_peft_metadata + +@PEFTModeRegistry.register_mode +class ExportPEFTModeDescriptor(ModeDescriptor): + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "export_peft" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return ExportPEFTConfig + + @property + def is_export_mode(self) -> bool: + """Specifies whether the mode is an export mode.""" + return True + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return export_peft_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_export_peft_model \ No newline at end of file diff --git a/modelopt/torch/peft/plugins/__init__.py b/modelopt/torch/peft/plugins/__init__.py new file mode 100644 index 00000000..03cd81fe --- /dev/null +++ b/modelopt/torch/peft/plugins/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PEFT/LoRA plugins for various frameworks.""" + +from contextlib import suppress + +with suppress(ImportError): + from . import megatron as _megatron diff --git a/modelopt/torch/peft/plugins/megatron.py b/modelopt/torch/peft/plugins/megatron.py new file mode 100644 index 00000000..6390f7b3 --- /dev/null +++ b/modelopt/torch/peft/plugins/megatron.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron-Core specific PEFT/LoRA plugins.""" + +import torch + +# Import MegatronModule if available +try: + from megatron.core.transformer.module import MegatronModule + + MEGATRON_AVAILABLE = True +except ImportError: + MegatronModule = None + MEGATRON_AVAILABLE = False + +from ..custom import CUSTOM_MODEL_PLUGINS + +__all__ = [] + + +def megatron_replace_lora_module_hook(model: torch.nn.Module): + """Configure Megatron-Core model PEFT/LoRA support. + + This callback is called before the LoRAModule replacement to configure + distributed checkpointing support. For each MegatronModule: + 1. We enable heterogeneous distributed checkpointing + + Note: LoRAModule already has built-in get_extra_state and set_extra_state methods, + so we don't need to register callbacks for them. 
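+
+    Additional plugins can be registered the same way (illustrative sketch;
+    ``my_hook`` is a hypothetical callable that takes the model):
+
+        >>> from modelopt.torch.peft.custom import CUSTOM_MODEL_PLUGINS
+        >>> CUSTOM_MODEL_PLUGINS.add(my_hook)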
+ """ + if not MEGATRON_AVAILABLE: + return + + for name, module in model.named_modules(): + if isinstance(module, MegatronModule): + # Enable heterogeneous distributed checkpointing + if hasattr(module, "config") and hasattr( + module.config, "hetereogenous_dist_checkpoint" + ): + module.config.hetereogenous_dist_checkpoint = True + + +# Register the hook +CUSTOM_MODEL_PLUGINS.add(megatron_replace_lora_module_hook) diff --git a/tests/gpu/torch/peft/test_forward_megatron.py b/tests/gpu/torch/peft/test_forward_megatron.py new file mode 100644 index 00000000..a4b96286 --- /dev/null +++ b/tests/gpu/torch/peft/test_forward_megatron.py @@ -0,0 +1,605 @@ +# import json +# from copy import deepcopy +# from functools import partial + +# import pytest +# import torch +# import transformers +# # from _test_utils.import_helper import skip_if_no_megatron +# # from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job +# from _test_utils.torch_dist.plugins.megatron_common import get_mcore_gpt_model +# # from _test_utils.torch_model.transformers_models import create_tiny_llama_dir + +# # skip_if_no_megatron(apex_or_te_required=True) + +# import modelopt.torch.speculative as mtsp +# from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +# from modelopt.torch.speculative.eagle.default_config import default_eagle_config +# from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel +# from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel +import copy +import os +from warnings import warn + +import torch +import torch.nn.functional as F +from megatron.core import dist_checkpointing +from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.parallel_state import ( + initialize_model_parallel, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_config import TransformerConfig + +from modelopt.torch.opt.plugins.mcore_dist_checkpointing import ( + restore_sharded_modelopt_state, + save_sharded_modelopt_state, +) +from modelopt.torch.utils.plugins import megatron_prefill + +try: + from megatron.core.extensions.transformer_engine import TENorm + from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec + + HAS_TE = True +except ImportError as e: + warn(f"Transformer Engine not installed: {e}") + HAS_TE = False + +try: + from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec + from megatron.core.ssm.mamba_layer import MambaLayer + + HAS_MAMBA = True +except ImportError as e: + warn(f"Mamba not installed: {e}") + HAS_MAMBA = False + +try: + import apex # noqa: F401 + + HAS_APEX = True +except ImportError as e: + warn(f"Apex not installed: {e}") + HAS_APEX = False +import modelopt.torch.peft as mtp +import modelopt.torch.quantization as mtq + +lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*": {"rank": 64}, + }, +} + + +def initialize_for_megatron( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234 +): + """Initialize Megatron model parallelism. + + NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`. 
+ """ + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size) + model_parallel_cuda_manual_seed(seed) + + +def get_mcore_gpt_model( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + initialize_megatron: bool = False, + *, + num_layers: int = 2, + num_layers_in_first_pipeline_stage: int | None = None, + num_layers_in_last_pipeline_stage: int | None = None, + hidden_size: int = 64, + num_attention_heads: int = 8, + num_query_groups: int | None = None, + ffn_hidden_size: int | None = 128, + max_sequence_length: int = 16, + vocab_size: int = 64, + activation_func: str = "swiglu", + normalization: str = "LayerNorm", + transformer_impl: str = "modelopt" if HAS_TE else "local", + use_cpu_initialization: bool = False, + bf16: bool = True, +) -> GPTModel: + assert activation_func in ["swiglu", "squared_relu"] + assert normalization in ["LayerNorm", "RMSNorm"] + assert transformer_impl in ["local", "transformer_engine", "modelopt"] + print(f"Using `{transformer_impl=}` model spec for building GPT Model.") + + if initialize_megatron: + initialize_for_megatron(tensor_model_parallel_size, pipeline_model_parallel_size) + + def squared_relu(x): + return torch.pow(F.relu(x), 2) + + config = TransformerConfig( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + sequence_parallel=False, + num_layers=num_layers, + num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, + num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + activation_func=squared_relu if activation_func == "squared_relu" else F.silu, + normalization=normalization, + gated_linear_unit=(activation_func == "swiglu"), + add_bias_linear=False, + use_cpu_initialization=use_cpu_initialization, + pipeline_dtype=torch.bfloat16 if bf16 else torch.float32, + bf16=bf16, + ) + + if transformer_impl == "local": + assert HAS_APEX, "Apex not installed" + transformer_layer_spec = get_gpt_layer_local_spec(normalization=normalization) + else: + assert HAS_TE, "Transformer Engine not installed" + transformer_layer_spec = ( + get_gpt_modelopt_spec(config, remap_te_layernorm=True) + if transformer_impl == "modelopt" + else get_gpt_layer_with_transformer_engine_spec() + ) + + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + pre_process=is_pipeline_first_stage(), + post_process=is_pipeline_last_stage(), + share_embeddings_and_output_weights=False, + position_embedding_type="rope", + ) + if bf16: + model = model.to(torch.bfloat16) + + return model + + +def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): + """Build the model.""" + + if meta_device: + with torch.device("meta"): + gpt_model = get_mcore_gpt_model( + tensor_model_parallel_size=tp_size, + num_layers=4, + ffn_hidden_size=None, + num_attention_heads=4, + activation_func="squared_relu", + transformer_impl="local", + hidden_size=hidden_size, + vocab_size=vocab_size, + use_cpu_initialization=meta_device, + ) + else: + gpt_model = get_mcore_gpt_model( + tensor_model_parallel_size=tp_size, + num_layers=4, + ffn_hidden_size=None, + num_attention_heads=4, + activation_func="squared_relu", + transformer_impl="local", + hidden_size=hidden_size, + vocab_size=vocab_size, + 
).cuda() + return gpt_model.eval() + + +from megatron.core import parallel_state + +from tests.gpu.torch.peft.test_forward_megatron import ( + _gpt_model_provider, + initialize_for_megatron, + megatron_prefill, +) + + +def _test_lora_forward(): + """Test LoRA forward pass with Megatron model.""" + # Initialize model parallel groups with proper CUDA RNG seed + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + # Create model + model = _gpt_model_provider(tp_size=1) + + # Create input tokens + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + # Run forward pass + output = megatron_prefill(model, prompt_tokens) + print( + f"Forward pass successful! Output shape: {output.shape if hasattr(output, 'shape') else 'N/A'}" + ) + + # Now test with LoRA + + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + + # # Apply LoRA + model = mtp.update_model(model, lora_config) + print("LoRA adapters added successfully!") + + # Test forward pass with LoRA + output_lora = megatron_prefill(model, prompt_tokens) + print( + f"LoRA forward pass successful! Output shape: {output_lora.shape if hasattr(output_lora, 'shape') else 'N/A'}" + ) + + # Check if LoRA modules were added + lora_count = 0 + for name, module in model.named_modules(): + if hasattr(module, "_lora_adapters"): + lora_count += 1 + print(f"LoRA module found: {name}") + + print(f"\nTotal LoRA modules: {lora_count}") + print("Test passed!") + + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def _test_lora_add_2nd_lora(): + """Test LoRA forward pass with Megatron model.""" + # Initialize model parallel groups with proper CUDA RNG seed + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + # Create model + model = _gpt_model_provider(tp_size=1) + + # Create input tokens + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + # Run forward pass + output = megatron_prefill(model, prompt_tokens) + print( + f"Forward pass successful! Output shape: {output.shape if hasattr(output, 'shape') else 'N/A'}" + ) + + # Now test with LoRA + + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + + lora_2d_config = { + "adapter_type": "lora", + "adapter_name": "2nd", + "adapter_cfg": { + "*attention*": {"rank": 128, "scale": 1}, + "*mlp*": {"rank": 128, "scale": 1}, + }, + } + + # # Apply LoRA + model = mtp.update_model(model, lora_config) + print("LoRA adapters added successfully!") + model = mtp.update_model(model, lora_2d_config) + + # Test forward pass with LoRA + output_lora = megatron_prefill(model, prompt_tokens) + print( + f"LoRA forward pass successful! 
Output shape: {output_lora.shape if hasattr(output_lora, 'shape') else 'N/A'}" + ) + + # Check if LoRA modules were added + lora_count = 0 + for name, module in model.named_modules(): + if hasattr(module, "_lora_adapters"): + lora_count += 1 + print(f"LoRA module found: {name}") + + print(f"\nTotal LoRA modules: {lora_count}") + print("Test passed!") + + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def _test_lora_save_and_restore(): + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + model_ref = _gpt_model_provider(tp_size=1) + model_test = _gpt_model_provider(tp_size=1) + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + model_ref = mtp.update_model(model_ref, lora_config) + state_dict = copy.deepcopy(model_ref.state_dict()) + tmp_path = "./model_ref" + save_distributed_checkpoint(tmp_path, model_ref) + save_sharded_modelopt_state([model_ref], tmp_path) + restore_sharded_modelopt_state([model_test], tmp_path) + model_test = load_distributed_checkpoint(tmp_path, model_test) + + prompt_tokens = torch.randint( + 0, model_test.vocab_size, (2, model_test.max_sequence_length) + ).cuda() + + # Run forward pass + output_test = megatron_prefill(model_test, prompt_tokens) + output_ref = megatron_prefill(model_ref, prompt_tokens) + print( + f"Forward pass successful! Output shape: {output_test.shape if hasattr(output_test, 'shape') else 'N/A'}" + ) + print(output_test) + print(output_ref) + + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def _test_lora_save_and_restore_with2loras(): + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + model_ref = _gpt_model_provider(tp_size=1) + model_test = _gpt_model_provider(tp_size=1) + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 10}, + }, + } + lora_2d_config = { + "adapter_type": "lora", + "adapter_name": "2nd", + "adapter_cfg": { + "*attention*": {"rank": 128, "scale": 1}, + "*mlp*": {"rank": 128, "scale": 1}, + }, + } + lora_3d_config = { + "adapter_type": "lora", + "adapter_name": "3rd", + "adapter_cfg": { + "*attention*": {"rank": 128, "scale": 1}, + "*mlp*": {"rank": 128, "scale": 1}, + }, + } + model_ref = mtp.update_model(model_ref, lora_config) + model_ref = mtp.update_model(model_ref, lora_2d_config) + tmp_path = "./model_ref" + save_distributed_checkpoint(tmp_path, model_ref) + save_sharded_modelopt_state([model_ref], tmp_path) + restore_sharded_modelopt_state([model_test], tmp_path) + model_test = load_distributed_checkpoint(tmp_path, model_test) + # model_test = mtp.update_model(model_test, lora_3d_config) + + # Debug: Check active adapters + print("\n=== Active Adapters ===") + for name, module in model_test.named_modules(): + if hasattr(module, "_lora_adapters") and module._lora_adapters: + print( + f"{name}: adapters={list(module._lora_adapters.keys())}, active={list(module._active_adapters)}" + ) + break # Just show one module as example + + prompt_tokens = torch.randint( + 0, model_test.vocab_size, (2, model_test.max_sequence_length) + ).cuda() + + # Run forward pass + output_test = megatron_prefill(model_test, prompt_tokens) + output_ref = megatron_prefill(model_ref, prompt_tokens) + print( + 
f"Forward pass successful! Output shape: {output_test.shape if hasattr(output_test, 'shape') else 'N/A'}" + ) + print(output_test) + print(output_ref) + + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def _test_quantize_then_lora(): + """Test LoRA forward pass with Megatron model.""" + # Initialize model parallel groups with proper CUDA RNG seed + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + model = _gpt_model_provider(tp_size=1) + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + + def forward_func(mod): + output = megatron_prefill(model, prompt_tokens) + + mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_func) + model = mtp.update_model(model, lora_config) + lora_count = 0 + for name, module in model.named_modules(): + if hasattr(module, "_lora_adapters"): + lora_count += 1 + print(f"LoRA module found: {name}") + print(f"\nTotal LoRA modules: {lora_count}") + output_lora_quant = megatron_prefill(model, prompt_tokens) + print( + f"LoRA forward pass successful! Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}" + ) + print(model) + print("Test passed!") + finally: + # Clean up model parallel groups + parallel_state.destroy_model_parallel() + + +def _test_quantize_then_lora_save_restore(): + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234) + + try: + model_ref = _gpt_model_provider(tp_size=1) + model_test = _gpt_model_provider(tp_size=1) + prompt_tokens = torch.randint( + 0, model_test.vocab_size, (2, model_test.max_sequence_length) + ).cuda() + + def forward_func(mod): + output = megatron_prefill(model_ref, prompt_tokens) + + mtq.quantize(model_ref, mtq.FP8_DEFAULT_CFG, forward_func) + lora_config = { + "adapter_type": "lora", + "adapter_name": "default", + "adapter_cfg": { + "*attention*": {"rank": 32, "scale": 1}, + "*mlp*": {"rank": 64, "scale": 1}, + }, + } + model_ref = mtp.update_model(model_ref, lora_config) + tmp_path = "./model_ref" + save_distributed_checkpoint(tmp_path, model_ref) + save_sharded_modelopt_state([model_ref], tmp_path) + restore_sharded_modelopt_state([model_test], tmp_path) + model_test = load_distributed_checkpoint(tmp_path, model_test) + # Run forward pass + output_test = megatron_prefill(model_test, prompt_tokens) + output_ref = megatron_prefill(model_ref, prompt_tokens) + print( + f"Forward pass successful! 
+        print(
+            f"Forward pass successful! Output shape: {output_test.shape if hasattr(output_test, 'shape') else 'N/A'}"
+        )
+        print(model_ref)
+        print(f"output_test: {output_test}")
+        print(f"output_ref: {output_ref}")
+
+    finally:
+        # Clean up model parallel groups
+        parallel_state.destroy_model_parallel()
+
+
+def _test_lora_then_quantize():
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234)
+
+    try:
+        model = _gpt_model_provider(tp_size=1)
+        prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+        lora_config = {
+            "adapter_type": "lora",
+            "adapter_name": "default",
+            "adapter_cfg": {
+                "*attention*": {"rank": 32, "scale": 1},
+                "*mlp*": {"rank": 64, "scale": 1},
+            },
+        }
+
+        def forward_func(mod):
+            output = megatron_prefill(model, prompt_tokens)
+
+        model = mtp.update_model(model, lora_config)
+        mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_func)
+        lora_count = 0
+        for name, module in model.named_modules():
+            if hasattr(module, "_lora_adapters"):
+                lora_count += 1
+                print(f"LoRA module found: {name}")
+        print(f"\nTotal LoRA modules: {lora_count}")
+        output_lora_quant = megatron_prefill(model, prompt_tokens)
+        print(
+            f"LoRA forward pass successful! Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}"
+        )
+        print("Test passed!")
+        print(model)
+    finally:
+        # Clean up model parallel groups
+        parallel_state.destroy_model_parallel()
+
+
+def _test_lora_then_quantize_save_restore():
+    pass
+
+
+def _test_disable_lora():
+    pass
+
+
+def _test_disable_lora_restore():
+    pass
+
+
+def save_distributed_checkpoint(checkpoint_path, gpt_model):
+    os.makedirs(checkpoint_path, exist_ok=True)
+    sharded_state_dict = gpt_model.sharded_state_dict(prefix="")
+    dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)
+
+
+def load_distributed_checkpoint(checkpoint_path, gpt_model):
+    sharded_state_dict = gpt_model.sharded_state_dict(prefix="")
+    checkpoint = dist_checkpointing.load(
+        sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path
+    )
+    gpt_model.load_state_dict(checkpoint)
+    return gpt_model
+
+
+def main():
+    """Set up the distributed environment and run the selected test."""
+    # Setup distributed environment
+    if not torch.distributed.is_initialized():
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+        os.environ["RANK"] = "0"
+        os.environ["WORLD_SIZE"] = "1"
+        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
+
+    try:
+        # _test_lora_forward()
+        # _test_lora_save_and_restore()
+        # _test_lora_add_2nd_lora()
+        # _test_lora_save_and_restore_with2loras()
+        # _test_quantize_then_lora()
+        # _test_quantize_then_lora_save_restore()
+        _test_lora_then_quantize()
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
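Editor's note (not part of either diff): `_test_disable_lora` and `_test_disable_lora_restore` are left as `pass` placeholders in the debug script above. A minimal sketch of what the first one might cover, reusing only helpers already defined in that script plus the `mtp.disable_adapters` call exercised in the pytest file below; names and behaviour here are assumptions, not the author's implementation:

    def _test_disable_lora_sketch():
        # Assumes the same single-rank Megatron setup as the other helpers above.
        initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234)
        try:
            model = _gpt_model_provider(tp_size=1)
            prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
            output_base = megatron_prefill(model, prompt_tokens)
            lora_config = {
                "adapter_type": "lora",
                "adapter_name": "default",
                "adapter_cfg": {"*attention*": {"rank": 32, "scale": 1}, "*mlp*": {"rank": 64, "scale": 1}},
            }
            model = mtp.update_model(model, lora_config)
            mtp.disable_adapters(model)
            output_disabled = megatron_prefill(model, prompt_tokens)
            # With adapters disabled the model should match the base model's output.
            assert torch.allclose(output_base, output_disabled)
        finally:
            parallel_state.destroy_model_parallel()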
diff --git a/tests/gpu/torch/peft/test_megatron_peft.py b/tests/gpu/torch/peft/test_megatron_peft.py
new file mode 100644
index 00000000..07120031
--- /dev/null
+++ b/tests/gpu/torch/peft/test_megatron_peft.py
@@ -0,0 +1,162 @@
+import pytest
+import torch
+from _test_utils.import_helper import skip_if_no_megatron
+from _test_utils.torch_dist.plugins.megatron_common import (
+    get_mcore_gpt_model,
+    initialize_for_megatron,
+)
+
+skip_if_no_megatron()
+
+
+import modelopt.torch.peft as mtp
+from modelopt.torch.peft.config import kaiming_init, zero_init
+from modelopt.torch.peft.lora.layer import LoRAModule
+from modelopt.torch.utils.plugins import megatron_prefill
+
+DEFAULT_LORA_CFG_TEST = {
+    "adapter_type": "lora",
+    "adapter_name": "default",
+    "adapter_cfg": {
+        "*": {
+            "rank": 32,
+            "scale": 1,
+            "lora_a_init": kaiming_init,
+            "lora_b_init": zero_init,
+            "enable": True,
+        },
+    },
+}
+
+DEFAULT_LORA_CFG_RANDOM_INIT_TEST = {
+    "adapter_type": "lora",
+    "adapter_name": "random",
+    "adapter_cfg": {
+        "*": {
+            "rank": 32,
+            "scale": 1,
+            "lora_a_init": kaiming_init,
+            "lora_b_init": kaiming_init,
+            "enable": True,
+        },
+    },
+}
+
+
+def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
+    """Build the model."""
+
+    if meta_device:
+        with torch.device("meta"):
+            gpt_model = get_mcore_gpt_model(
+                tensor_model_parallel_size=tp_size,
+                num_layers=4,
+                ffn_hidden_size=None,
+                num_attention_heads=4,
+                activation_func="squared_relu",
+                transformer_impl="local",
+                hidden_size=hidden_size,
+                vocab_size=vocab_size,
+                use_cpu_initialization=meta_device,
+            )
+    else:
+        gpt_model = get_mcore_gpt_model(
+            tensor_model_parallel_size=tp_size,
+            num_layers=4,
+            ffn_hidden_size=None,
+            num_attention_heads=4,
+            activation_func="squared_relu",
+            transformer_impl="local",
+            hidden_size=hidden_size,
+            vocab_size=vocab_size,
+        ).cuda()
+    return gpt_model.eval()
+
+
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        DEFAULT_LORA_CFG_TEST,
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_forward_with_one_lora(lora_config):
+    hidden_size = 320
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=1, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+    original_output = megatron_prefill(model, prompt_tokens)
+    mtp.update_model(model, lora_config)
+    lora_output = megatron_prefill(model, prompt_tokens)
+    assert lora_output.shape == original_output.shape
+    if lora_config == DEFAULT_LORA_CFG_TEST:
+        assert torch.allclose(lora_output, original_output)
+    else:
+        assert not torch.allclose(lora_output, original_output)
+
+    mtp.disable_adapters(model)
+    lora_disabled_output = megatron_prefill(model, prompt_tokens)
+    assert torch.allclose(lora_disabled_output, original_output)
+
+    for _, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
+            assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+
+
+@pytest.mark.parametrize(
+    "lora_config_1",
+    [
+        DEFAULT_LORA_CFG_TEST,
+    ],
+)
+@pytest.mark.parametrize(
+    "lora_config_2",
+    [
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_forward_with_two_loras(lora_config_1, lora_config_2):
+    hidden_size = 320
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=1, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+    mtp.update_model(model, lora_config_1)
+    lora_1_output = megatron_prefill(model, prompt_tokens)
+    mtp.update_model(model, lora_config_2)
+    lora_2_output = megatron_prefill(model, prompt_tokens)
+
+    assert not torch.allclose(lora_1_output, lora_2_output)
+    assert lora_1_output.shape == lora_2_output.shape
+
+    for _, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            assert hasattr(module, f"lora_a_{lora_config_1['adapter_name']}")
+            assert hasattr(module, f"lora_b_{lora_config_1['adapter_name']}")
+
f"lora_a_{lora_config_2['adapter_name']}") + assert hasattr(module, f"lora_b_{lora_config_2['adapter_name']}") + + +def test_forward_with_lora_quantize(): + pass + + +def test_forward_with_quantize_lora(): + pass + + +def test_one_lora_save_restore(): + pass + + +def test_two_loras_save_restore(): + pass + + +def test_one_lora_quantize_save_restore(): + pass + + +def test_two_loras_quantize_save_restore(): + pass