
Commit 1d38784 (parent: 019efb0)

Update restore logic

Signed-off-by: Jingyu Xin <[email protected]>

2 files changed: +35, -15 lines

modelopt/torch/peft/config.py

Lines changed: 33 additions & 5 deletions
@@ -16,8 +16,8 @@
 """Configuration classes for PEFT methods."""
 
 import math
+import pickle  # nosec B403 - Only checking picklability
 from collections.abc import Callable
-from collections.abc import Callable as CallableType
 
 import torch.nn.init as init
 from pydantic import field_validator, model_validator
@@ -27,6 +27,16 @@
 __all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"]
 
 
+def default_lora_a_init(weight):
+    """Default initialization for LoRA A matrix using Kaiming uniform."""
+    return init.kaiming_uniform_(weight, a=math.sqrt(5))
+
+
+def default_lora_b_init(weight):
+    """Default initialization for LoRA B matrix using zeros."""
+    return init.zeros_(weight)
+
+
 class PEFTAttributeConfig(ModeloptBaseConfig):
     """Configuration for PEFT adapter attributes."""
 
@@ -52,13 +62,13 @@ class PEFTAttributeConfig(ModeloptBaseConfig):
     )
 
     lora_a_init: Callable[[object], None] | None = ModeloptField(
-        default=lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5)),
+        default=default_lora_a_init,
         title="LoRA A matrix initializer",
         description="Custom initialization function for LoRA A matrix. Default to Kaiming uniform initialization.",
     )
 
     lora_b_init: Callable[[object], None] | None = ModeloptField(
-        default=lambda weight: init.zeros_(weight),
+        default=default_lora_b_init,
         title="LoRA B matrix initializer",
         description="Custom initialization function for LoRA B matrix. Default to zero initialization.",
     )
@@ -81,16 +91,34 @@ def validate_scale(cls, v):
 
     @model_validator(mode="after")
     def validate_init_functions(self):
-        """Validate initialization functions are callable."""
+        """Validate initialization functions are callable and picklable."""
         if self.lora_a_init is not None and not callable(self.lora_a_init):
             raise ValueError("lora_a_init must be callable")
         if self.lora_b_init is not None and not callable(self.lora_b_init):
             raise ValueError("lora_b_init must be callable")
+        if self.lora_a_init is not None:
+            try:
+                _del = pickle.dumps(self.lora_a_init)
+                del _del
+            except (pickle.PicklingError, TypeError, AttributeError) as e:
+                raise ValueError(
+                    f"lora_a_init cannot be pickled: {e}. "
+                    "Please use a module-level function instead of a lambda or nested function."
+                )
+        if self.lora_b_init is not None:
+            try:
+                _del = pickle.dumps(self.lora_b_init)
+                del _del
+            except (pickle.PicklingError, TypeError, AttributeError) as e:
+                raise ValueError(
+                    f"lora_b_init cannot be pickled: {e}. "
+                    "Please use a module-level function instead of a lambda or nested function."
+                )
         return self
 
 
 # Type alias for adapter configuration
-PEFTAdapterCfgType = dict[str | CallableType, PEFTAttributeConfig | dict]
+PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict]
 
 
 class PEFTConfig(ModeloptBaseConfig):
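The config.py changes above follow a standard Python constraint: pickle serializes a function by its importable qualified name, so the old lambda defaults and other nested functions cannot round-trip through a pickled config, while the new module-level default_lora_a_init / default_lora_b_init can. A minimal standalone sketch of the behavior the new pickle.dumps probe in validate_init_functions guards against (illustration only, not code from the repository):

import math
import pickle

import torch.nn.init as init


def default_lora_a_init(weight):
    # Module-level function: pickle can store it by qualified name.
    return init.kaiming_uniform_(weight, a=math.sqrt(5))


# A module-level function pickles fine ...
pickle.dumps(default_lora_a_init)

# ... while a lambda does not, because there is no importable name
# for pickle to look up when the config is deserialized.
try:
    pickle.dumps(lambda weight: init.zeros_(weight))
except (pickle.PicklingError, TypeError, AttributeError) as e:
    print(f"lambda is not picklable: {e}")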

modelopt/torch/peft/lora/layer.py

Lines changed: 2 additions & 10 deletions
@@ -165,17 +165,9 @@ def set_from_peft_state(self, peft_state: dict[str, Any]) -> None:
         """
         adapters_config = peft_state.get("adapters", {})
 
-        self._lora_adapters.clear()
-        self._active_adapters.clear()
-
         for adapter_name, config in adapters_config.items():
-            self.update_layer_lora(adapter_name, config)
-
-            # Set activation state
-            if config.get("is_active", False):
-                self.activate_adapter(adapter_name)
-            else:
-                self.deactivate_adapter(adapter_name)
+            if adapter_name not in self._lora_adapters:
+                self.update_layer_lora(adapter_name, config)
 
     def set_extra_state(self, state: dict[str, Any]) -> None:
         """Restore extra state for distributed checkpointing.
