16
16
"""Configuration classes for PEFT methods."""
17
17
18
18
import math
19
+ import pickle # nosec B403 - Only checking picklability
19
20
from collections .abc import Callable
20
- from collections .abc import Callable as CallableType
21
21
22
22
import torch .nn .init as init
23
23
from pydantic import field_validator , model_validator
27
27
__all__ = ["ExportPEFTConfig" , "PEFTAttributeConfig" , "PEFTConfig" ]
28
28
29
29
30
+ def default_lora_a_init (weight ):
31
+ """Default initialization for LoRA A matrix using Kaiming uniform."""
32
+ return init .kaiming_uniform_ (weight , a = math .sqrt (5 ))
33
+
34
+
35
+ def default_lora_b_init (weight ):
36
+ """Default initialization for LoRA B matrix using zeros."""
37
+ return init .zeros_ (weight )
38
+
39
+
30
40
class PEFTAttributeConfig (ModeloptBaseConfig ):
31
41
"""Configuration for PEFT adapter attributes."""
32
42
@@ -52,13 +62,13 @@ class PEFTAttributeConfig(ModeloptBaseConfig):
52
62
)
53
63
54
64
lora_a_init : Callable [[object ], None ] | None = ModeloptField (
55
- default = lambda weight : init . kaiming_uniform_ ( weight , a = math . sqrt ( 5 )) ,
65
+ default = default_lora_a_init ,
56
66
title = "LoRA A matrix initializer" ,
57
67
description = "Custom initialization function for LoRA A matrix. Default to Kaiming uniform initialization." ,
58
68
)
59
69
60
70
lora_b_init : Callable [[object ], None ] | None = ModeloptField (
61
- default = lambda weight : init . zeros_ ( weight ) ,
71
+ default = default_lora_b_init ,
62
72
title = "LoRA B matrix initializer" ,
63
73
description = "Custom initialization function for LoRA B matrix. Default to zero initialization." ,
64
74
)
@@ -81,16 +91,34 @@ def validate_scale(cls, v):
81
91
82
92
@model_validator (mode = "after" )
83
93
def validate_init_functions (self ):
84
- """Validate initialization functions are callable."""
94
+ """Validate initialization functions are callable and picklable ."""
85
95
if self .lora_a_init is not None and not callable (self .lora_a_init ):
86
96
raise ValueError ("lora_a_init must be callable" )
87
97
if self .lora_b_init is not None and not callable (self .lora_b_init ):
88
98
raise ValueError ("lora_b_init must be callable" )
99
+ if self .lora_a_init is not None :
100
+ try :
101
+ _del = pickle .dumps (self .lora_a_init )
102
+ del _del
103
+ except (pickle .PicklingError , TypeError , AttributeError ) as e :
104
+ raise ValueError (
105
+ f"lora_a_init cannot be pickled: { e } . "
106
+ "Please use a module-level function instead of a lambda or nested function."
107
+ )
108
+ if self .lora_b_init is not None :
109
+ try :
110
+ _del = pickle .dumps (self .lora_b_init )
111
+ del _del
112
+ except (pickle .PicklingError , TypeError , AttributeError ) as e :
113
+ raise ValueError (
114
+ f"lora_b_init cannot be pickled: { e } . "
115
+ "Please use a module-level function instead of a lambda or nested function."
116
+ )
89
117
return self
90
118
91
119
92
120
# Type alias for adapter configuration
93
- PEFTAdapterCfgType = dict [str | CallableType , PEFTAttributeConfig | dict ]
121
+ PEFTAdapterCfgType = dict [str | Callable , PEFTAttributeConfig | dict ]
94
122
95
123
96
124
class PEFTConfig (ModeloptBaseConfig ):
0 commit comments