Commit fad4982

minor

Signed-off-by: Jingyu Xin <[email protected]>

1 parent 82dc269 · commit fad4982

File tree

3 files changed: 77 additions & 60 deletions

modelopt/torch/peft/conversion.py

Lines changed: 5 additions & 42 deletions

@@ -15,13 +15,11 @@

 """PEFT conversion and restore utilities for LoRA modules."""

-import fnmatch
-from collections.abc import Callable, Iterable
-
 import torch.nn as nn

 from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager
 from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
+from modelopt.torch.utils.network import matches_pattern

 from .config import PEFTConfig
 from .lora.layer import LoRAModule, LoRAModuleRegistry

@@ -117,7 +115,7 @@ def add_adapter(model, config: PEFTConfig):
         # Later patterns override earlier ones
         merged_setting = None
         for wildcard_or_filter_func, adapter_setting in adapter_cfg.items():
-            if _matches(name, wildcard_or_filter_func):
+            if matches_pattern(name, wildcard_or_filter_func):
                 if merged_setting is None:
                     merged_setting = adapter_setting.copy()
                 else:

@@ -134,44 +132,9 @@ def add_adapter(model, config: PEFTConfig):
     return model


-def _matches(
-    name: str,
-    patterns: str | Callable[[str], bool] | Iterable[str | Callable[[str], bool]] | None,
-    *,
-    allow_callable: bool = True,
-) -> bool:
-    if patterns is None:
-        return True
-
-    if isinstance(patterns, (str, bytes)):
-        patterns_iter: Iterable[str | Callable[[str], bool]] = (patterns,)
-    elif callable(patterns):
-        if not allow_callable:
-            raise TypeError("Callable patterns are not supported in this context.")
-        patterns_iter = (patterns,)
-    elif isinstance(patterns, Iterable):
-        patterns_iter = tuple(patterns)
-    else:
-        raise TypeError(f"Unsupported pattern type: {type(patterns)}")
-
-    for pattern in patterns_iter:
-        if isinstance(pattern, (str, bytes)):
-            if fnmatch.fnmatch(name, pattern):
-                return True
-        elif callable(pattern):
-            if not allow_callable:
-                raise TypeError("Callable patterns are not supported in this context.")
-            if pattern(name):
-                return True
-        else:
-            raise TypeError(f"Unsupported pattern type: {type(pattern)}")
-
-    return False
-
-
 def _iter_lora_modules(model, layer_patterns=None):
     for module_name, module in model.named_modules():
-        if isinstance(module, LoRAModule) and _matches(module_name, layer_patterns):
+        if isinstance(module, LoRAModule) and matches_pattern(module_name, layer_patterns):
             yield module_name, module

@@ -192,14 +155,14 @@ def _set_base_requires_grad(model, *, requires_grad: bool, layer_patterns=None):
         # If layer_patterns is specified, only affect matching layers
        if layer_patterns is not None:
             module_name = ".".join(name.split(".")[:-1])  # Get module name without param name
-            if not _matches(module_name, layer_patterns):
+            if not matches_pattern(module_name, layer_patterns):
                 continue
        param.requires_grad = requires_grad


 def _iter_adapter_names(module, adapter_patterns=None):
     for adapter_name in module._lora_adapters:
-        if _matches(adapter_name, adapter_patterns, allow_callable=False):
+        if matches_pattern(adapter_name, adapter_patterns, allow_callable=False):
             yield adapter_name
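With the private _matches helper gone, all pattern filtering in this file flows through the shared matches_pattern utility. A minimal sketch of the filtering semantics, using hypothetical module names rather than anything from the repo:

from modelopt.torch.utils.network import matches_pattern

# Hypothetical module names, shaped like what model.named_modules() yields.
module_names = [
    "decoder.layers.0.self_attention.linear_qkv",
    "decoder.layers.0.mlp.linear_fc1",
    "embedding.word_embeddings",
]

# Globs and callables can be mixed in one pattern list; the first hit wins.
for name in module_names:
    if matches_pattern(name, ["*linear_qkv*", lambda n: n.endswith("fc1")]):
        print("would wrap:", name)
# would wrap: decoder.layers.0.self_attention.linear_qkv
# would wrap: decoder.layers.0.mlp.linear_fc1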

modelopt/torch/peft/convert.py

Lines changed: 3 additions & 18 deletions

@@ -15,14 +15,14 @@

 """User-facing PEFT API for LoRA module conversion and adapter management."""

-import fnmatch
 from typing import Any

 import torch.nn as nn

 from modelopt.torch.opt import apply_mode
 from modelopt.torch.peft.config import PEFTConfig
 from modelopt.torch.peft.conversion import add_adapter
+from modelopt.torch.utils.network import matches_pattern

 try:
     from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear

@@ -101,30 +101,15 @@ def _set_adapter_state(model, enable_state, layer_patterns=None, adapter_pattern
     if not is_peft_model(model):
         raise ValueError("Model must be a PEFT model to set adapter states.")

-    def matches_any_pattern(name, patterns, allow_callable=True):
-        for pattern in patterns:
-            if isinstance(pattern, str):
-                if fnmatch.fnmatch(name, pattern):
-                    return True
-            elif allow_callable and callable(pattern):
-                if pattern(name):
-                    return True
-            else:
-                pattern_type = "pattern" if allow_callable else "adapter pattern"
-                raise TypeError(f"Unsupported {pattern_type} type: {type(pattern)}")
-        return False
-
     for module_name, module in model.named_modules():
         if isinstance(module, LoRAModule):
             if layer_patterns is not None:
-                if not matches_any_pattern(module_name, layer_patterns, allow_callable=True):
+                if not matches_pattern(module_name, layer_patterns, allow_callable=True):
                     continue

             for adapter_name, adapter_dict in module._lora_adapters.items():
                 if adapter_patterns is not None:
-                    if not matches_any_pattern(
-                        adapter_name, adapter_patterns, allow_callable=False
-                    ):
+                    if not matches_pattern(adapter_name, adapter_patterns, allow_callable=False):
                         continue

                 adapter_dict["enable"] = enable_state
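One behavioral detail worth noting in this hunk: layer patterns are matched with allow_callable=True, while adapter names are matched with allow_callable=False. A small sketch of the difference, assuming this commit is applied:

from modelopt.torch.utils.network import matches_pattern

# Layer patterns accept globs and callables alike.
assert matches_pattern("decoder.layers.3.mlp.linear_fc1", "*.mlp.*")
assert matches_pattern("decoder.layers.3.mlp.linear_fc1", lambda n: ".mlp." in n)

# Adapter names are matched with allow_callable=False, so a callable
# raises TypeError instead of being evaluated.
try:
    matches_pattern("default", lambda n: True, allow_callable=False)
except TypeError as err:
    print(err)  # Callable patterns are not supported in this context.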

modelopt/torch/utils/network.py

Lines changed: 69 additions & 0 deletions

@@ -15,6 +15,7 @@

 """Utility functions for PyTorch models."""

+import fnmatch
 import inspect
 import types
 import warnings

@@ -55,6 +56,7 @@ def _convert_to_wrapped_module_name(name: str) -> str:
     "is_channels_last",
     "is_parallel",
     "make_divisible",
+    "matches_pattern",
     "model_to",
     "param_num",
     "param_num_from_forward",

@@ -87,6 +89,73 @@ def _convert_to_wrapped_module_name(name: str) -> str:
 ConstructorLike = Callable | tuple


+def matches_pattern(
+    name: str,
+    patterns: str | Callable[[str], bool] | Iterable[str | Callable[[str], bool]] | None,
+    *,
+    allow_callable: bool = True,
+) -> bool:
+    """Check if a name matches any of the given patterns.
+
+    This utility function checks if a given name (e.g., module name, layer name)
+    matches any pattern in a collection of patterns. Patterns can be:
+    - String wildcards (using fnmatch syntax, e.g., "*.attention.*")
+    - Callable predicates that take a name and return bool
+    - None (matches everything)
+
+    Args:
+        name: The name to check (e.g., "model.layer1.attention.weight")
+        patterns: A single pattern, iterable of patterns, or None.
+            If None, returns True (matches everything).
+        allow_callable: If False, raises TypeError when encountering callable patterns.
+            Useful for contexts where only string patterns are allowed.
+
+    Returns:
+        True if the name matches any pattern, False otherwise.
+
+    Raises:
+        TypeError: If pattern type is unsupported or if callable patterns are
+            provided when allow_callable=False.
+
+    Examples:
+        >>> matches_pattern("model.attention.query", "*.attention.*")
+        True
+        >>> matches_pattern("model.mlp.linear", ["*.attention.*", "*.mlp.*"])
+        True
+        >>> matches_pattern("model.layer1", lambda x: "layer" in x)
+        True
+        >>> matches_pattern("anything", None)
+        True
+    """
+    if patterns is None:
+        return True
+
+    if isinstance(patterns, (str, bytes)):
+        patterns_iter: Iterable[str | Callable[[str], bool]] = (patterns,)
+    elif callable(patterns):
+        if not allow_callable:
+            raise TypeError("Callable patterns are not supported in this context.")
+        patterns_iter = (patterns,)
+    elif isinstance(patterns, Iterable):
+        patterns_iter = tuple(patterns)
+    else:
+        raise TypeError(f"Unsupported pattern type: {type(patterns)}")
+
+    for pattern in patterns_iter:
+        if isinstance(pattern, (str, bytes)):
+            if fnmatch.fnmatch(name, pattern):
+                return True
+        elif callable(pattern):
+            if not allow_callable:
+                raise TypeError("Callable patterns are not supported in this context.")
+            if pattern(name):
+                return True
+        else:
+            raise TypeError(f"Unsupported pattern type: {type(pattern)}")
+
+    return False


 def is_parallel(model: nn.Module) -> bool:
     """Check if a PyTorch model is parallelized."""
     return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
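The new helper depends only on the standard-library fnmatch module, so its edge cases can be exercised in isolation. A short self-check, assuming a modelopt build that includes this commit:

from modelopt.torch.utils.network import matches_pattern

# None matches everything; an empty iterable matches nothing.
assert matches_pattern("anything", None)
assert not matches_pattern("anything", [])

# fnmatch globs apply to the whole name: "*.attention.*" requires a dot
# on both sides of "attention".
assert matches_pattern("model.attention.query", "*.attention.*")
assert not matches_pattern("attention", "*.attention.*")

# Mixed string/callable collections short-circuit on the first hit.
patterns = ["*.mlp.*", lambda n: n.startswith("embedding")]
assert matches_pattern("embedding.word_embeddings", patterns)

# Unsupported pattern types raise TypeError.
try:
    matches_pattern("x", 42)  # type: ignore[arg-type]
except TypeError as err:
    print(err)  # Unsupported pattern type: <class 'int'>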
