19
19
from accelerate import init_empty_weights
20
20
21
21
# Assuming the module is named "module_matching" - adjust import as needed
22
- from compressed_tensors .utils .match import (
22
+ from compressed_tensors .utils import (
23
+ InternalModule ,
23
24
is_match ,
24
- match_class ,
25
25
match_modules_set ,
26
- match_name ,
27
26
match_named_modules ,
28
27
match_named_parameters ,
29
28
)
29
+ from compressed_tensors .utils .match import _match_class , _match_name
30
30
31
31
32
32
class DummyModel (nn .Module ):
@@ -66,14 +66,14 @@ def __init__(self):
66
66
67
67
68
68
class TestMatchName :
69
- """Test cases for match_name function"""
69
+ """Test cases for _match_name function"""
70
70
71
71
def test_exact_match (self ):
72
72
"""Test exact string matching"""
73
- assert match_name ("layer1" , "layer1" ) == True
74
- assert match_name ("layer1" , "layer2" ) == False
73
+ assert _match_name ("layer1" , "layer1" ) == True
74
+ assert _match_name ("layer1" , "layer2" ) == False
75
75
assert (
76
- match_name (
76
+ _match_name (
77
77
"transformer.layers.0.self_attn.q_proj" ,
78
78
"transformer.layers.0.self_attn.q_proj" ,
79
79
)
@@ -82,14 +82,14 @@ def test_exact_match(self):
82
82
83
83
def test_regex_match (self ):
84
84
"""Test regex matching with "re:" prefix"""
85
- assert match_name ("layer1" , "re:layer.*" ) == True
86
- assert match_name ("layer1" , "re:^layer1$" ) == True
87
- assert match_name ("layer1" , "re:layer2" ) == False
85
+ assert _match_name ("layer1" , "re:layer.*" ) == True
86
+ assert _match_name ("layer1" , "re:^layer1$" ) == True
87
+ assert _match_name ("layer1" , "re:layer2" ) == False
88
88
assert (
89
- match_name ("transformer.layers.0.self_attn.q_proj" , "re:.*q_proj" ) == True
89
+ _match_name ("transformer.layers.0.self_attn.q_proj" , "re:.*q_proj" ) == True
90
90
)
91
91
assert (
92
- match_name (
92
+ _match_name (
93
93
"transformer.layers.0.self_attn.q_proj" ,
94
94
"re:transformer\\ .layers\\ .\\ d+\\ .self_attn\\ ..*_proj$" ,
95
95
)
@@ -98,49 +98,49 @@ def test_regex_match(self):
98
98
99
99
def test_empty_strings (self ):
100
100
"""Test edge cases with empty strings"""
101
- assert match_name ("" , "" ) == True
102
- assert match_name ("layer1" , "" ) == False
103
- assert match_name ("" , "layer1" ) == False
101
+ assert _match_name ("" , "" ) == True
102
+ assert _match_name ("layer1" , "" ) == False
103
+ assert _match_name ("" , "layer1" ) == False
104
104
105
105
def test_regex_special_characters (self ):
106
106
"""Test regex with special characters"""
107
- assert match_name ("layer.1" , "re:layer\\ .1" ) == True
108
- assert match_name ("layer.1" , "re:layer.1" ) == True # . matches any char
109
- assert match_name ("layer_1" , "re:layer_1" ) == True
107
+ assert _match_name ("layer.1" , "re:layer\\ .1" ) == True
108
+ assert _match_name ("layer.1" , "re:layer.1" ) == True # . matches any char
109
+ assert _match_name ("layer_1" , "re:layer_1" ) == True
110
110
111
111
112
112
class TestMatchClass :
113
- """Test cases for match_class function"""
113
+ """Test cases for _match_class function"""
114
114
115
115
def test_direct_class_match (self ):
116
116
"""Test matching direct class names"""
117
117
linear = nn .Linear (10 , 20 )
118
- assert match_class (linear , "Linear" ) == True
119
- assert match_class (linear , "Conv2d" ) == False
118
+ assert _match_class (linear , "Linear" ) == True
119
+ assert _match_class (linear , "Conv2d" ) == False
120
120
121
121
norm = nn .LayerNorm (10 )
122
- assert match_class (norm , "LayerNorm" ) == True
123
- assert match_class (norm , "BatchNorm1d" ) == False
122
+ assert _match_class (norm , "LayerNorm" ) == True
123
+ assert _match_class (norm , "BatchNorm1d" ) == False
124
124
125
125
def test_parent_class_match (self ):
126
126
"""Test matching parent class names"""
127
127
linear = nn .Linear (10 , 20 )
128
- assert match_class (linear , "Module" ) == True
128
+ assert _match_class (linear , "Module" ) == True
129
129
130
130
conv = nn .Conv2d (3 , 16 , 3 )
131
- assert match_class (conv , "Module" ) == True
132
- assert match_class (conv , "_ConvNd" ) == True
131
+ assert _match_class (conv , "Module" ) == True
132
+ assert _match_class (conv , "_ConvNd" ) == True
133
133
134
134
def test_non_torch_module (self ):
135
135
"""Test with non-torch modules"""
136
136
regular_object = object ()
137
- assert match_class (regular_object , "object" ) == False # not a torch.nn.Module
137
+ assert _match_class (regular_object , "object" ) == False # not a torch.nn.Module
138
138
139
139
def test_custom_module (self ):
140
140
"""Test with custom module classes"""
141
141
model = DummyModel ()
142
- assert match_class (model , "DummyModel" ) == True
143
- assert match_class (model , "Module" ) == True
142
+ assert _match_class (model , "DummyModel" ) == True
143
+ assert _match_class (model , "Module" ) == True
144
144
145
145
146
146
class TestIsMatch :
@@ -171,6 +171,15 @@ def test_regex_in_name_match(self):
171
171
assert is_match ("layer1" , linear , "re:layer.*" ) == True
172
172
assert is_match ("layer1" , linear , "re:conv.*" ) == False
173
173
174
+ def test_internal_module_match (self ):
175
+ """Test not matching internal modules"""
176
+
177
+ class InternalLinear (InternalModule , nn .Linear ):
178
+ pass
179
+
180
+ linear = InternalLinear (10 , 20 )
181
+ assert is_match ("layer1" , linear , "re:layer.*" ) == False
182
+
174
183
175
184
class TestMatchNamedModules :
176
185
"""Test cases for match_named_modules function"""
@@ -236,6 +245,16 @@ def test_warn_on_fail(self, mock_logger):
236
245
assert "Could not match" in warning_msg
237
246
assert "nonexistent_module" in warning_msg
238
247
248
+ def test_internal_match (self ):
249
+ """Test not matching internal modules"""
250
+
251
+ class InternalLinear (InternalModule , nn .Linear ):
252
+ pass
253
+
254
+ linear = InternalLinear (10 , 20 )
255
+ matches = list (match_named_modules (linear , ["re:.*" ]))
256
+ assert len (matches ) == 0
257
+
239
258
240
259
class TestMatchNamedParameters :
241
260
"""Test cases for match_named_parameters function"""
@@ -298,6 +317,16 @@ def test_warn_on_fail_parameters(self, mock_logger):
298
317
assert "Could not match" in warning_msg
299
318
assert "nonexistent.param" in warning_msg
300
319
320
+ def test_internal_match (self ):
321
+ """Test not matching internal modules"""
322
+
323
+ class InternalLinear (InternalModule , nn .Linear ):
324
+ pass
325
+
326
+ linear = InternalLinear (10 , 20 )
327
+ matches = list (match_named_parameters (linear , ["re:.*" ]))
328
+ assert len (matches ) == 0
329
+
301
330
302
331
class TestMatchModulesSet :
303
332
"""Test cases for match_modules_set function"""
@@ -377,6 +406,16 @@ def test_module_set_with_ignore(self):
377
406
# Should have 2 sets (layers 1 and 2, but not 0)
378
407
assert len (matches ) == 2
379
408
409
+ def test_internal_match (self ):
410
+ """Test not matching internal modules"""
411
+
412
+ class InternalLinear (InternalModule , nn .Linear ):
413
+ pass
414
+
415
+ linear = InternalLinear (10 , 20 )
416
+ matches = list (match_modules_set (linear , ["re:.*" ]))
417
+ assert len (matches ) == 0
418
+
380
419
381
420
class TestIntegration :
382
421
"""Integration tests combining multiple functions"""
0 commit comments