|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Configuration classes for PEFT methods.""" |
| 17 | + |
| 18 | +import importlib |
| 19 | +import inspect |
| 20 | +from collections.abc import Callable |
| 21 | +from typing import Annotated, Any |
| 22 | + |
| 23 | +import torch.nn.init as init |
| 24 | +from pydantic import PlainSerializer, WithJsonSchema, field_validator |
| 25 | + |
| 26 | +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField |
| 27 | + |
| 28 | +__all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"] |
| 29 | + |
| 30 | +InitRuntimeType = Any |
| 31 | + |
| 32 | + |
| 33 | +def _qualname(fn) -> str: |
| 34 | + m = inspect.getmodule(fn) |
| 35 | + return f"{m.__name__}.{fn.__name__}" if m else getattr(fn, "__name__", str(fn)) |
| 36 | + |
| 37 | + |
| 38 | +InitField = Annotated[ |
| 39 | + InitRuntimeType, |
| 40 | + WithJsonSchema( |
| 41 | + { |
| 42 | + "type": "string", |
| 43 | + "title": "torch initializer", |
| 44 | + "description": ( |
| 45 | + "Fully-qualified callable from ``torch.nn.init``. " |
| 46 | + "Must be in-place (name ends with ``\\_``)." |
| 47 | + ), |
| 48 | + "examples": ["torch.nn.init.zeros\\_", "torch.nn.init.kaiming_uniform\\_"], |
| 49 | + } |
| 50 | + ), |
| 51 | + PlainSerializer(lambda v: _qualname(v), return_type=str, when_used="always"), |
| 52 | +] |
| 53 | + |
| 54 | + |
| 55 | +class PEFTAttributeConfig(ModeloptBaseConfig): |
| 56 | + """Configuration for PEFT adapter attributes.""" |
| 57 | + |
| 58 | + enable: bool = ModeloptField( |
| 59 | + default=True, |
| 60 | + title="Enable adapter", |
| 61 | + description="If True, enables the adapter. If False, by-passes the adapter.", |
| 62 | + ) |
| 63 | + |
| 64 | + rank: int = ModeloptField( |
| 65 | + default=64, |
| 66 | + title="LoRA rank", |
| 67 | + description=( |
| 68 | + "The rank (dimension) of the LoRA matrices. " |
| 69 | + "Higher rank allows more expressiveness but uses more memory." |
| 70 | + ), |
| 71 | + ) |
| 72 | + |
| 73 | + scale: float = ModeloptField( |
| 74 | + default=1.0, |
| 75 | + title="LoRA scaling factor", |
| 76 | + description="Scaling factor for the LoRA output. Controls the magnitude of the adaptation.", |
| 77 | + ) |
| 78 | + |
| 79 | + lora_a_init: InitField = ModeloptField( |
| 80 | + default=init.kaiming_uniform_, |
| 81 | + title="LoRA A matrix initializer", |
| 82 | + description="Initializer from ``torch.nn.init`` (in-place; name ends with ``\\_``).", |
| 83 | + ) |
| 84 | + |
| 85 | + lora_b_init: InitField = ModeloptField( |
| 86 | + default=init.zeros_, |
| 87 | + title="LoRA B matrix initializer", |
| 88 | + description="Initializer from ``torch.nn.init`` (in-place; name ends with ``\\_``).", |
| 89 | + ) |
| 90 | + |
| 91 | + @field_validator("lora_a_init", "lora_b_init", mode="before") |
| 92 | + @classmethod |
| 93 | + def _parse_init_callable(cls, v): |
| 94 | + if isinstance(v, str): |
| 95 | + try: |
| 96 | + module_path, func_name = v.rsplit(".", 1) |
| 97 | + mod = importlib.import_module(module_path) |
| 98 | + v = getattr(mod, func_name) |
| 99 | + except Exception as e: |
| 100 | + raise ValueError( |
| 101 | + f"Could not resolve initializer '{v}' into a callable " |
| 102 | + "(expected a dotted path like 'torch.nn.init.zeros_')." |
| 103 | + ) from e |
| 104 | + return v |
| 105 | + |
| 106 | + @field_validator("lora_a_init", "lora_b_init") |
| 107 | + @classmethod |
| 108 | + def validate_init_method(cls, v): |
| 109 | + """Validate initialization method is supported.""" |
| 110 | + if callable(v): |
| 111 | + module = inspect.getmodule(v) |
| 112 | + if module is not init: |
| 113 | + raise ValueError( |
| 114 | + "Callable initialization method must be from torch.nn.init, " |
| 115 | + f"got {module.__name__ if module else 'unknown'}" |
| 116 | + ) |
| 117 | + func_name = getattr(v, "__name__", "") |
| 118 | + if not func_name.endswith("_"): |
| 119 | + raise ValueError( |
| 120 | + "Initialization method must be in-place (name ends with '_'). " |
| 121 | + "For example: ``torch.nn.init.kaiming_uniform\\_`` not " |
| 122 | + "``torch.nn.init.kaiming_uniform``." |
| 123 | + ) |
| 124 | + else: |
| 125 | + raise ValueError( |
| 126 | + f"Initialization method must be a callable function from torch.nn.init, got {type(v)}" |
| 127 | + ) |
| 128 | + return v |
| 129 | + |
| 130 | + @field_validator("rank") |
| 131 | + @classmethod |
| 132 | + def validate_rank(cls, v): |
| 133 | + """Validate rank is positive.""" |
| 134 | + if v < 1: |
| 135 | + raise ValueError("rank must be a positive integer") |
| 136 | + return v |
| 137 | + |
| 138 | + @field_validator("scale") |
| 139 | + @classmethod |
| 140 | + def validate_scale(cls, v): |
| 141 | + """Validate scale is positive.""" |
| 142 | + if v <= 0: |
| 143 | + raise ValueError("scale must be a positive number") |
| 144 | + return v |
| 145 | + |
| 146 | + |
| 147 | +# Type alias for adapter configuration |
| 148 | +PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict] |
| 149 | + |
| 150 | + |
| 151 | +class PEFTConfig(ModeloptBaseConfig): |
| 152 | + """Default configuration for ``peft`` mode. |
| 153 | +
|
| 154 | + For adapter_cfg, later patterns override earlier ones, for example:: |
| 155 | +
|
| 156 | + "adapter_cfg": { |
| 157 | + "*": { |
| 158 | + "rank": 32, |
| 159 | + "scale": 1, |
| 160 | + "enable": True, |
| 161 | + }, |
| 162 | + "*output_layer*": {"enable": False}, |
| 163 | + } |
| 164 | +
|
| 165 | + If a layer name matches ``"*output_layer*"``, the attributes will be replaced with ``{"enable": False}``. |
| 166 | + """ |
| 167 | + |
| 168 | + adapter_name: str = ModeloptField( |
| 169 | + default="default", |
| 170 | + title="Adapter name", |
| 171 | + description="Name of the adapter to create or update.", |
| 172 | + validate_default=True, |
| 173 | + ) |
| 174 | + |
| 175 | + adapter_cfg: PEFTAdapterCfgType = ModeloptField( |
| 176 | + default={"*": {"rank": 64}}, |
| 177 | + title="Adapter configuration", |
| 178 | + description="Configuration for adapters. Maps module patterns to PEFTAttributeConfig or dict.", |
| 179 | + validate_default=True, |
| 180 | + ) |
| 181 | + |
| 182 | + adapter_type: str = ModeloptField( |
| 183 | + default="lora", |
| 184 | + title="Adapter type", |
| 185 | + description="Type of PEFT adapter to use. Currently only 'lora' is supported.", |
| 186 | + validate_default=True, |
| 187 | + ) |
| 188 | + |
| 189 | + freeze_base_model: bool = ModeloptField( |
| 190 | + default=True, |
| 191 | + title="Freeze base weights during training", |
| 192 | + description="Whether to freeze the base model weights; in most cases, this should be set to True.", |
| 193 | + validate_default=True, |
| 194 | + ) |
| 195 | + |
| 196 | + freeze_lora_weights: bool = ModeloptField( |
| 197 | + default=False, |
| 198 | + title="Freeze lora weights during training", |
| 199 | + description="Whether to freeze the lora model weights; in most cases, this should be set to False.", |
| 200 | + validate_default=True, |
| 201 | + ) |
| 202 | + |
| 203 | + @field_validator("adapter_type") |
| 204 | + @classmethod |
| 205 | + def validate_adapter_type(cls, v): |
| 206 | + """Validate adapter type.""" |
| 207 | + if v not in ["lora"]: |
| 208 | + raise ValueError(f"Unsupported adapter type: {v}. Only 'lora' is currently supported.") |
| 209 | + return v |
| 210 | + |
| 211 | + @field_validator("adapter_cfg") |
| 212 | + @classmethod |
| 213 | + def validate_adapter_cfg(cls, v): |
| 214 | + """Validate and convert adapter configurations.""" |
| 215 | + validated_cfg = {} |
| 216 | + for key, value in v.items(): |
| 217 | + if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig): |
| 218 | + # Convert dict to PEFTAttributeConfig to trigger validation |
| 219 | + try: |
| 220 | + validated_cfg[key] = PEFTAttributeConfig(**value) |
| 221 | + except Exception as e: |
| 222 | + raise ValueError(f"Invalid adapter configuration for '{key}': {e}") |
| 223 | + else: |
| 224 | + validated_cfg[key] = value |
| 225 | + return validated_cfg |
| 226 | + |
| 227 | + |
| 228 | +class ExportPEFTConfig(ModeloptBaseConfig): |
| 229 | + """An empty config.""" |
0 commit comments