15
15
16
16
"""PEFT conversion and restore utilities for LoRA modules."""
17
17
18
- import fnmatch
19
- from collections .abc import Callable , Iterable
20
-
21
18
import torch .nn as nn
22
19
23
20
from modelopt .torch .opt .conversion import ModelLikeModule , ModeloptStateManager
24
21
from modelopt .torch .opt .mode import ConvertReturnType , MetadataDict
22
+ from modelopt .torch .utils .network import matches_pattern
25
23
26
24
from .config import PEFTConfig
27
25
from .lora .layer import LoRAModule , LoRAModuleRegistry
@@ -117,7 +115,7 @@ def add_adapter(model, config: PEFTConfig):
117
115
# Later patterns override earlier ones
118
116
merged_setting = None
119
117
for wildcard_or_filter_func , adapter_setting in adapter_cfg .items ():
120
- if _matches (name , wildcard_or_filter_func ):
118
+ if matches_pattern (name , wildcard_or_filter_func ):
121
119
if merged_setting is None :
122
120
merged_setting = adapter_setting .copy ()
123
121
else :
@@ -134,44 +132,9 @@ def add_adapter(model, config: PEFTConfig):
134
132
return model
135
133
136
134
137
- def _matches (
138
- name : str ,
139
- patterns : str | Callable [[str ], bool ] | Iterable [str | Callable [[str ], bool ]] | None ,
140
- * ,
141
- allow_callable : bool = True ,
142
- ) -> bool :
143
- if patterns is None :
144
- return True
145
-
146
- if isinstance (patterns , (str , bytes )):
147
- patterns_iter : Iterable [str | Callable [[str ], bool ]] = (patterns ,)
148
- elif callable (patterns ):
149
- if not allow_callable :
150
- raise TypeError ("Callable patterns are not supported in this context." )
151
- patterns_iter = (patterns ,)
152
- elif isinstance (patterns , Iterable ):
153
- patterns_iter = tuple (patterns )
154
- else :
155
- raise TypeError (f"Unsupported pattern type: { type (patterns )} " )
156
-
157
- for pattern in patterns_iter :
158
- if isinstance (pattern , (str , bytes )):
159
- if fnmatch .fnmatch (name , pattern ):
160
- return True
161
- elif callable (pattern ):
162
- if not allow_callable :
163
- raise TypeError ("Callable patterns are not supported in this context." )
164
- if pattern (name ):
165
- return True
166
- else :
167
- raise TypeError (f"Unsupported pattern type: { type (pattern )} " )
168
-
169
- return False
170
-
171
-
172
135
def _iter_lora_modules (model , layer_patterns = None ):
173
136
for module_name , module in model .named_modules ():
174
- if isinstance (module , LoRAModule ) and _matches (module_name , layer_patterns ):
137
+ if isinstance (module , LoRAModule ) and matches_pattern (module_name , layer_patterns ):
175
138
yield module_name , module
176
139
177
140
@@ -192,14 +155,14 @@ def _set_base_requires_grad(model, *, requires_grad: bool, layer_patterns=None):
192
155
# If layer_patterns is specified, only affect matching layers
193
156
if layer_patterns is not None :
194
157
module_name = "." .join (name .split ("." )[:- 1 ]) # Get module name without param name
195
- if not _matches (module_name , layer_patterns ):
158
+ if not matches_pattern (module_name , layer_patterns ):
196
159
continue
197
160
param .requires_grad = requires_grad
198
161
199
162
200
163
def _iter_adapter_names (module , adapter_patterns = None ):
201
164
for adapter_name in module ._lora_adapters :
202
- if _matches (adapter_name , adapter_patterns , allow_callable = False ):
165
+ if matches_pattern (adapter_name , adapter_patterns , allow_callable = False ):
203
166
yield adapter_name
204
167
205
168
0 commit comments