Skip to content

Commit c45352e

Browse files
committed
Merge remote-tracking branch 'origin' into kylesayrs/transform_save
2 parents 85419e2 + 3d49764 commit c45352e

File tree

2 files changed

+80
-36
lines changed

2 files changed

+80
-36
lines changed

src/compressed_tensors/utils/match.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Iterable, Tuple
1919

2020
import torch
21+
from compressed_tensors.utils.internal import InternalModule
2122

2223

2324
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -28,8 +29,6 @@
2829
"match_named_parameters",
2930
"match_modules_set",
3031
"is_match",
31-
"match_name",
32-
"match_class",
3332
]
3433

3534

@@ -83,13 +82,16 @@ def match_named_parameters(
8382
"""
8483
unmatched_targets = set(targets)
8584
for module_name, module in model.named_modules():
85+
if isinstance(module, InternalModule):
86+
continue
87+
8688
for param_name, param in module.named_parameters(recurse=False):
8789
param_fqn = f"{module_name}.{param_name}"
8890
for target in targets:
89-
if match_name(param_fqn, target):
91+
if _match_name(param_fqn, target):
9092
unmatched_targets -= {target}
9193

92-
if not any(match_name(param_fqn, ign) for ign in ignore):
94+
if not any(_match_name(param_fqn, ign) for ign in ignore):
9395
yield param_fqn, module, param
9496

9597
if warn_on_fail:
@@ -165,11 +167,14 @@ def match_modules_set(
165167
def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
166168
"""
167169
Returns true if either module name or module parent classes match against target
170+
and the module is not an internal module
168171
"""
169-
return match_name(name, target) or match_class(module, target)
172+
return not isinstance(module, InternalModule) and (
173+
_match_name(name, target) or _match_class(module, target)
174+
)
170175

171176

172-
def match_name(name: str, target: str) -> bool:
177+
def _match_name(name: str, target: str) -> bool:
173178
"""
174179
Returns true if target string begins with "re:" and
175180
regex matches or if target string exactly matches name
@@ -180,7 +185,7 @@ def match_name(name: str, target: str) -> bool:
180185
return target == name
181186

182187

183-
def match_class(module: torch.nn.Module, target: str) -> bool:
188+
def _match_class(module: torch.nn.Module, target: str) -> bool:
184189
"""
185190
Returns true if any torch parent class names match the target string exactly
186191
"""

tests/test_utils/test_match.py

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
from accelerate import init_empty_weights
2020

2121
# Assuming the module is named "module_matching" - adjust import as needed
22-
from compressed_tensors.utils.match import (
22+
from compressed_tensors.utils import (
23+
InternalModule,
2324
is_match,
24-
match_class,
2525
match_modules_set,
26-
match_name,
2726
match_named_modules,
2827
match_named_parameters,
2928
)
29+
from compressed_tensors.utils.match import _match_class, _match_name
3030

3131

3232
class DummyModel(nn.Module):
@@ -66,14 +66,14 @@ def __init__(self):
6666

6767

6868
class TestMatchName:
69-
"""Test cases for match_name function"""
69+
"""Test cases for _match_name function"""
7070

7171
def test_exact_match(self):
7272
"""Test exact string matching"""
73-
assert match_name("layer1", "layer1") == True
74-
assert match_name("layer1", "layer2") == False
73+
assert _match_name("layer1", "layer1") == True
74+
assert _match_name("layer1", "layer2") == False
7575
assert (
76-
match_name(
76+
_match_name(
7777
"transformer.layers.0.self_attn.q_proj",
7878
"transformer.layers.0.self_attn.q_proj",
7979
)
@@ -82,14 +82,14 @@ def test_exact_match(self):
8282

8383
def test_regex_match(self):
8484
"""Test regex matching with "re:" prefix"""
85-
assert match_name("layer1", "re:layer.*") == True
86-
assert match_name("layer1", "re:^layer1$") == True
87-
assert match_name("layer1", "re:layer2") == False
85+
assert _match_name("layer1", "re:layer.*") == True
86+
assert _match_name("layer1", "re:^layer1$") == True
87+
assert _match_name("layer1", "re:layer2") == False
8888
assert (
89-
match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj") == True
89+
_match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj") == True
9090
)
9191
assert (
92-
match_name(
92+
_match_name(
9393
"transformer.layers.0.self_attn.q_proj",
9494
"re:transformer\\.layers\\.\\d+\\.self_attn\\..*_proj$",
9595
)
@@ -98,49 +98,49 @@ def test_regex_match(self):
9898

9999
def test_empty_strings(self):
100100
"""Test edge cases with empty strings"""
101-
assert match_name("", "") == True
102-
assert match_name("layer1", "") == False
103-
assert match_name("", "layer1") == False
101+
assert _match_name("", "") == True
102+
assert _match_name("layer1", "") == False
103+
assert _match_name("", "layer1") == False
104104

105105
def test_regex_special_characters(self):
106106
"""Test regex with special characters"""
107-
assert match_name("layer.1", "re:layer\\.1") == True
108-
assert match_name("layer.1", "re:layer.1") == True # . matches any char
109-
assert match_name("layer_1", "re:layer_1") == True
107+
assert _match_name("layer.1", "re:layer\\.1") == True
108+
assert _match_name("layer.1", "re:layer.1") == True # . matches any char
109+
assert _match_name("layer_1", "re:layer_1") == True
110110

111111

112112
class TestMatchClass:
113-
"""Test cases for match_class function"""
113+
"""Test cases for _match_class function"""
114114

115115
def test_direct_class_match(self):
116116
"""Test matching direct class names"""
117117
linear = nn.Linear(10, 20)
118-
assert match_class(linear, "Linear") == True
119-
assert match_class(linear, "Conv2d") == False
118+
assert _match_class(linear, "Linear") == True
119+
assert _match_class(linear, "Conv2d") == False
120120

121121
norm = nn.LayerNorm(10)
122-
assert match_class(norm, "LayerNorm") == True
123-
assert match_class(norm, "BatchNorm1d") == False
122+
assert _match_class(norm, "LayerNorm") == True
123+
assert _match_class(norm, "BatchNorm1d") == False
124124

125125
def test_parent_class_match(self):
126126
"""Test matching parent class names"""
127127
linear = nn.Linear(10, 20)
128-
assert match_class(linear, "Module") == True
128+
assert _match_class(linear, "Module") == True
129129

130130
conv = nn.Conv2d(3, 16, 3)
131-
assert match_class(conv, "Module") == True
132-
assert match_class(conv, "_ConvNd") == True
131+
assert _match_class(conv, "Module") == True
132+
assert _match_class(conv, "_ConvNd") == True
133133

134134
def test_non_torch_module(self):
135135
"""Test with non-torch modules"""
136136
regular_object = object()
137-
assert match_class(regular_object, "object") == False # not a torch.nn.Module
137+
assert _match_class(regular_object, "object") == False # not a torch.nn.Module
138138

139139
def test_custom_module(self):
140140
"""Test with custom module classes"""
141141
model = DummyModel()
142-
assert match_class(model, "DummyModel") == True
143-
assert match_class(model, "Module") == True
142+
assert _match_class(model, "DummyModel") == True
143+
assert _match_class(model, "Module") == True
144144

145145

146146
class TestIsMatch:
@@ -171,6 +171,15 @@ def test_regex_in_name_match(self):
171171
assert is_match("layer1", linear, "re:layer.*") == True
172172
assert is_match("layer1", linear, "re:conv.*") == False
173173

174+
def test_internal_module_match(self):
175+
"""Test not matching internal modules"""
176+
177+
class InternalLinear(InternalModule, nn.Linear):
178+
pass
179+
180+
linear = InternalLinear(10, 20)
181+
assert is_match("layer1", linear, "re:layer.*") == False
182+
174183

175184
class TestMatchNamedModules:
176185
"""Test cases for match_named_modules function"""
@@ -236,6 +245,16 @@ def test_warn_on_fail(self, mock_logger):
236245
assert "Could not match" in warning_msg
237246
assert "nonexistent_module" in warning_msg
238247

248+
def test_internal_match(self):
249+
"""Test not matching internal modules"""
250+
251+
class InternalLinear(InternalModule, nn.Linear):
252+
pass
253+
254+
linear = InternalLinear(10, 20)
255+
matches = list(match_named_modules(linear, ["re:.*"]))
256+
assert len(matches) == 0
257+
239258

240259
class TestMatchNamedParameters:
241260
"""Test cases for match_named_parameters function"""
@@ -298,6 +317,16 @@ def test_warn_on_fail_parameters(self, mock_logger):
298317
assert "Could not match" in warning_msg
299318
assert "nonexistent.param" in warning_msg
300319

320+
def test_internal_match(self):
321+
"""Test not matching internal modules"""
322+
323+
class InternalLinear(InternalModule, nn.Linear):
324+
pass
325+
326+
linear = InternalLinear(10, 20)
327+
matches = list(match_named_parameters(linear, ["re:.*"]))
328+
assert len(matches) == 0
329+
301330

302331
class TestMatchModulesSet:
303332
"""Test cases for match_modules_set function"""
@@ -377,6 +406,16 @@ def test_module_set_with_ignore(self):
377406
# Should have 2 sets (layers 1 and 2, but not 0)
378407
assert len(matches) == 2
379408

409+
def test_internal_match(self):
410+
"""Test not matching internal modules"""
411+
412+
class InternalLinear(InternalModule, nn.Linear):
413+
pass
414+
415+
linear = InternalLinear(10, 20)
416+
matches = list(match_modules_set(linear, ["re:.*"]))
417+
assert len(matches) == 0
418+
380419

381420
class TestIntegration:
382421
"""Integration tests combining multiple functions"""

0 commit comments

Comments
 (0)