25 changes: 25 additions & 0 deletions modelopt/torch/peft/__init__.py
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Distillation API subpackage for torch."""

from . import mode
from .config import *
from .conversion import *
from .convert import *

# isort: off
# Import plugins last to avoid circular imports
from . import plugins
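
A minimal sketch of the import surface this `__init__.py` yields, assuming the package is installed as laid out in this PR (only names that appear in the star-imported submodules' `__all__` lists below are used):

import modelopt.torch.peft as mtpeft

# Re-exported from .config via `from .config import *`
cfg = mtpeft.PEFTConfig(adapter_name="default")

# Re-exported from .conversion via `from .conversion import *`
refresh = mtpeft.update_peft_metadata_in_model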
174 changes: 174 additions & 0 deletions modelopt/torch/peft/config.py
@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration classes for PEFT methods."""

import math
import pickle # nosec B403 - Only checking picklability
from collections.abc import Callable

import torch.nn.init as init
from pydantic import field_validator, model_validator

from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField

__all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"]


def default_lora_a_init(weight):
"""Default initialization for LoRA A matrix using Kaiming uniform."""
return init.kaiming_uniform_(weight, a=math.sqrt(5))


def default_lora_b_init(weight):
"""Default initialization for LoRA B matrix using zeros."""
return init.zeros_(weight)


class PEFTAttributeConfig(ModeloptBaseConfig):
"""Configuration for PEFT adapter attributes."""

enable: bool = ModeloptField(
default=True,
title="Enable adapter",
description="If True, enables the adapter. If False, by-passes the adapter.",
)

rank: int = ModeloptField(
default=64,
title="LoRA rank",
description=(
"The rank (dimension) of the LoRA matrices. "
"Higher rank allows more expressiveness but uses more memory."
),
)

scale: float = ModeloptField(
default=1.0,
title="LoRA scaling factor",
description="Scaling factor for the LoRA output. Controls the magnitude of the adaptation.",
)

lora_a_init: Callable[[object], None] | None = ModeloptField(
default=default_lora_a_init,
title="LoRA A matrix initializer",
description="Custom initialization function for LoRA A matrix. Default to Kaiming uniform initialization.",
)

lora_b_init: Callable[[object], None] | None = ModeloptField(
default=default_lora_b_init,
title="LoRA B matrix initializer",
description="Custom initialization function for LoRA B matrix. Default to zero initialization.",
)

@field_validator("rank")
@classmethod
def validate_rank(cls, v):
"""Validate rank is positive."""
if v < 1:
raise ValueError("rank must be a positive integer")
return v

@field_validator("scale")
@classmethod
def validate_scale(cls, v):
"""Validate scale is positive."""
if v <= 0:
raise ValueError("scale must be a positive number")
return v

@model_validator(mode="after")
def validate_init_functions(self):
"""Validate initialization functions are callable and picklable."""
if self.lora_a_init is not None and not callable(self.lora_a_init):
raise ValueError("lora_a_init must be callable")
if self.lora_b_init is not None and not callable(self.lora_b_init):
raise ValueError("lora_b_init must be callable")
if self.lora_a_init is not None:
try:
_del = pickle.dumps(self.lora_a_init)
del _del
except (pickle.PicklingError, TypeError, AttributeError) as e:
raise ValueError(
f"lora_a_init cannot be pickled: {e}. "
"Please use a module-level function instead of a lambda or nested function."
)
if self.lora_b_init is not None:
try:
_del = pickle.dumps(self.lora_b_init)
del _del
except (pickle.PicklingError, TypeError, AttributeError) as e:
raise ValueError(
f"lora_b_init cannot be pickled: {e}. "
"Please use a module-level function instead of a lambda or nested function."
)
return self


# Type alias for adapter configuration
PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict]


class PEFTConfig(ModeloptBaseConfig):
"""Default configuration for ``peft`` mode."""

adapter_name: str = ModeloptField(
default="default",
title="Adapter name",
description="Name of the adapter to create or update.",
validate_default=True,
)

adapter_cfg: PEFTAdapterCfgType = ModeloptField(
default={"default": {"rank": 128}},
title="Adapter configuration",
description="Configuration for adapters. Maps module patterns to PEFTAttributeConfig or dict.",
validate_default=True,
)

adapter_type: str = ModeloptField(
default="lora",
title="Adapter type",
description="Type of PEFT adapter to use. Currently only 'lora' is supported.",
validate_default=True,
)

@field_validator("adapter_type")
@classmethod
def validate_adapter_type(cls, v):
"""Validate adapter type."""
if v not in ["lora"]:
raise ValueError(f"Unsupported adapter type: {v}. Only 'lora' is currently supported.")
return v

@field_validator("adapter_cfg")
@classmethod
def validate_adapter_cfg(cls, v):
"""Validate and convert adapter configurations."""
validated_cfg = {}
for key, value in v.items():
if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig):
# Convert dict to PEFTAttributeConfig to trigger validation
                try:
                    validated_cfg[key] = PEFTAttributeConfig(**value)
                except Exception as e:
                    raise ValueError(f"Invalid adapter configuration for '{key}': {e}") from e
else:
validated_cfg[key] = value
return validated_cfg


class ExportPEFTConfig(ModeloptBaseConfig):
"""An empty config."""
193 changes: 193 additions & 0 deletions modelopt/torch/peft/conversion.py
@@ -0,0 +1,193 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PEFT conversion and restore utilities for LoRA modules."""

import fnmatch
from typing import Any

import torch.nn as nn

from modelopt.torch.opt.conversion import ApplyModeError, ModelLikeModule, ModeloptStateManager
from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
from modelopt.torch.utils import get_unwrapped_name

from .config import PEFTConfig
from .lora.layer import LoRAModule, LoRAModuleRegistry

__all__ = [
"replace_lora_module",
"update_peft_metadata_in_model",
]


def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> ConvertReturnType:
"""Convert the model to a peft one as per `config`."""
# initialize the true module if necessary
model = model.init_modellike() if isinstance(model, ModelLikeModule) else model

    # Replace supported modules with LoRA modules
replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)

metadata = {}
add_adapter(model, config)
    # Record the adapter state (adapters, active adapters) in the metadata
update_peft_metadata(model, config, metadata)

return model, metadata


def restore_peft_model(
    model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict
) -> nn.Module:
    """Convert the model to a PEFT model, then restore its adapter state from ``metadata``."""
    model, _ = convert_to_peft_model(model, config)
    return restore_peft_state(model, metadata)


def restore_peft_state(model: ModelLikeModule, metadata: MetadataDict):
"""Restore PEFT state from metadata or extra_state.

For backward compatibility, we check metadata first. For distributed
checkpoints (NeMo-MCore), the state will be in extra_state of each LoRAModule
and will be restored automatically via set_extra_state() during load_state_dict().

Args:
model: Model with LoRA modules to restore
        metadata: Metadata dictionary that may contain peft_state

    Returns:
The model with restored PEFT state
"""
if "peft_state" not in metadata:
# For distributed checkpoints (NeMo-MCore), peft_state is stored
# in each LoRAModule's extra_state and will be restored via
# set_extra_state() during load_state_dict()
return model

# Legacy path: restore from metadata
peft_state_dict = metadata["peft_state"]
for name, module in model.named_modules():
if isinstance(module, LoRAModule):
unwrapped_name = get_unwrapped_name(name)
if unwrapped_name in peft_state_dict:
                try:
                    module.set_from_peft_state(peft_state_dict[unwrapped_name])
                except Exception as e:
                    raise ApplyModeError(
                        f"Failed to restore PEFT state for module {name}: {e}"
                    ) from e

return model


def update_peft_metadata(model: nn.Module, config: PEFTConfig, metadata: MetadataDict) -> None:
"""Update the PEFT/LoRA state in the metadata dict."""
metadata["peft_state"] = peft_state(model)


def peft_state(model: nn.Module) -> dict[str, Any]:
    """Collect the PEFT state of every LoRAModule, keyed by its unwrapped module name."""
return {
get_unwrapped_name(n): m.get_peft_state()
for n, m in model.named_modules()
if isinstance(m, LoRAModule)
}


def replace_lora_module(
    model: nn.Module, version=None, config: PEFTConfig | None = None, registry=LoRAModuleRegistry
):
"""Recursively replace the module with LoRA module."""
# Register custom plugins (e.g., for Megatron distributed checkpointing)
from .custom import register_custom_model_plugins_on_the_fly

register_custom_model_plugins_on_the_fly(model)

if type(model) in registry:
model = registry.convert(model)
_replace_lora_module(model, version=version, registry=registry)


def export_peft_model(model: nn.Module, config):
raise NotImplementedError("Exporting a peft model is not supported yet.")


def restore_export_peft_model(model: nn.Module, config, metadata: MetadataDict):
raise NotImplementedError("Restoring a peft & exported model is not supported yet.")


def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegistry):
for name, child in model.named_children():
if type(child) in registry:
lora_module = registry.convert(child)
setattr(model, name, lora_module)

_replace_lora_module(getattr(model, name), version=version, registry=registry)


def update_peft_metadata_in_model(model: nn.Module) -> None:
"""Update the PEFT metadata in the model's ModeloptStateManager.

This function should be called after manually modifying LoRA adapters to ensure
the metadata stored in the ModeloptStateManager reflects the current state.

Args:
        model: Model with LoRA modules whose metadata needs updating

    Example:
>>> # After manually adding/modifying adapters
>>> for module in model.modules():
... if isinstance(module, LoRAModule):
... module.update_layer_lora("custom_adapter", rank=32)
>>> # Update metadata to reflect changes
>>> update_peft_metadata_in_model(model)
"""
# Check if model has ModeloptStateManager (has been converted with peft mode)
if not ModeloptStateManager.is_converted(model):
return

# Get the state manager
manager = ModeloptStateManager(model)

# Update the metadata with current PEFT state
if manager._state and manager._last_metadata is not None:
manager._last_metadata["peft_state"] = peft_state(model)


def add_adapter(model, config: PEFTConfig):
"""Add a new LoRA adapter to the model.

Args:
model: Model with LoRA modules to add adapters to
config: PEFTConfig instance containing adapter_cfg and adapter_name

Returns:
The model with the new adapter added
"""
adapter_cfg = config.adapter_cfg
adapter_name = config.adapter_name

for name, module in model.named_modules():
if isinstance(module, LoRAModule):
for wildcard_or_filter_func, adapter_setting in adapter_cfg.items():
if isinstance(wildcard_or_filter_func, str):
if not fnmatch.fnmatch(name, wildcard_or_filter_func):
continue
elif callable(wildcard_or_filter_func):
if not wildcard_or_filter_func(name):
continue
else:
raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
module.update_layer_lora(
adapter_name,
adapter_setting,
)

return model
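
To illustrate the two key types `add_adapter` accepts in `adapter_cfg`: string keys are matched against module names with `fnmatch`, while callables act as name predicates. A hedged sketch; the module-name patterns and the predicate are illustrative, not part of this PR:

from modelopt.torch.peft.config import PEFTConfig


def first_layer_only(name: str) -> bool:
    """Hypothetical predicate: adapt only modules in the first decoder layer."""
    return ".layers.0." in name


config = PEFTConfig(
    adapter_name="layer0_boost",
    adapter_cfg={
        # fnmatch wildcard: any LoRAModule whose name contains ".attention."
        "*.attention.*": {"rank": 64},
        # Callable filter: receives the module name, returns True to match
        first_layer_only: {"rank": 16, "scale": 0.5},
    },
)

# add_adapter(model, config) then calls update_layer_lora(adapter_name, setting)
# on every LoRAModule whose name satisfies at least one entry above.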