@@ -70,43 +70,35 @@ class TestMatchName:
70
70
71
71
def test_exact_match (self ):
72
72
"""Test exact string matching"""
73
- assert _match_name ("layer1" , "layer1" ) == True
74
- assert _match_name ("layer1" , "layer2" ) == False
75
- assert (
76
- _match_name (
77
- "transformer.layers.0.self_attn.q_proj" ,
78
- "transformer.layers.0.self_attn.q_proj" ,
79
- )
80
- == True
73
+ assert _match_name ("layer1" , "layer1" )
74
+ assert not _match_name ("layer1" , "layer2" )
75
+ assert _match_name (
76
+ "transformer.layers.0.self_attn.q_proj" ,
77
+ "transformer.layers.0.self_attn.q_proj" ,
81
78
)
82
79
83
80
def test_regex_match (self ):
84
81
"""Test regex matching with "re:" prefix"""
85
- assert _match_name ("layer1" , "re:layer.*" ) == True
86
- assert _match_name ("layer1" , "re:^layer1$" ) == True
87
- assert _match_name ("layer1" , "re:layer2" ) == False
88
- assert (
89
- _match_name ("transformer.layers.0.self_attn.q_proj" , "re:.*q_proj" ) == True
90
- )
91
- assert (
92
- _match_name (
93
- "transformer.layers.0.self_attn.q_proj" ,
94
- "re:transformer\\ .layers\\ .\\ d+\\ .self_attn\\ ..*_proj$" ,
95
- )
96
- == True
82
+ assert _match_name ("layer1" , "re:layer.*" )
83
+ assert _match_name ("layer1" , "re:^layer1$" )
84
+ assert not _match_name ("layer1" , "re:layer2" )
85
+ assert _match_name ("transformer.layers.0.self_attn.q_proj" , "re:.*q_proj" )
86
+ assert _match_name (
87
+ "transformer.layers.0.self_attn.q_proj" ,
88
+ "re:transformer\\ .layers\\ .\\ d+\\ .self_attn\\ ..*_proj$" ,
97
89
)
98
90
99
91
def test_empty_strings (self ):
100
92
"""Test edge cases with empty strings"""
101
- assert _match_name ("" , "" ) == True
102
- assert _match_name ("layer1" , "" ) == False
103
- assert _match_name ("" , "layer1" ) == False
93
+ assert _match_name ("" , "" )
94
+ assert not _match_name ("layer1" , "" )
95
+ assert not _match_name ("" , "layer1" )
104
96
105
97
def test_regex_special_characters (self ):
106
98
"""Test regex with special characters"""
107
- assert _match_name ("layer.1" , "re:layer\\ .1" ) == True
108
- assert _match_name ("layer.1" , "re:layer.1" ) == True # . matches any char
109
- assert _match_name ("layer_1" , "re:layer_1" ) == True
99
+ assert _match_name ("layer.1" , "re:layer\\ .1" )
100
+ assert _match_name ("layer.1" , "re:layer.1" ) # . matches any char
101
+ assert _match_name ("layer_1" , "re:layer_1" )
110
102
111
103
112
104
class TestMatchClass :
@@ -115,32 +107,32 @@ class TestMatchClass:
115
107
def test_direct_class_match (self ):
116
108
"""Test matching direct class names"""
117
109
linear = nn .Linear (10 , 20 )
118
- assert _match_class (linear , "Linear" ) == True
119
- assert _match_class (linear , "Conv2d" ) == False
110
+ assert _match_class (linear , "Linear" )
111
+ assert not _match_class (linear , "Conv2d" )
120
112
121
113
norm = nn .LayerNorm (10 )
122
- assert _match_class (norm , "LayerNorm" ) == True
123
- assert _match_class (norm , "BatchNorm1d" ) == False
114
+ assert _match_class (norm , "LayerNorm" )
115
+ assert not _match_class (norm , "BatchNorm1d" )
124
116
125
117
def test_parent_class_match (self ):
126
118
"""Test matching parent class names"""
127
119
linear = nn .Linear (10 , 20 )
128
- assert _match_class (linear , "Module" ) == True
120
+ assert _match_class (linear , "Module" )
129
121
130
122
conv = nn .Conv2d (3 , 16 , 3 )
131
- assert _match_class (conv , "Module" ) == True
132
- assert _match_class (conv , "_ConvNd" ) == True
123
+ assert _match_class (conv , "Module" )
124
+ assert _match_class (conv , "_ConvNd" )
133
125
134
126
def test_non_torch_module (self ):
135
127
"""Test with non-torch modules"""
136
128
regular_object = object ()
137
- assert _match_class (regular_object , "object" ) == False # not a torch.nn.Module
129
+ assert not _match_class (regular_object , "object" ) # not a torch.nn.Module
138
130
139
131
def test_custom_module (self ):
140
132
"""Test with custom module classes"""
141
133
model = DummyModel ()
142
- assert _match_class (model , "DummyModel" ) == True
143
- assert _match_class (model , "Module" ) == True
134
+ assert _match_class (model , "DummyModel" )
135
+ assert _match_class (model , "Module" )
144
136
145
137
146
138
class TestIsMatch :
@@ -149,27 +141,27 @@ class TestIsMatch:
149
141
def test_name_match (self ):
150
142
"""Test matching by name"""
151
143
linear = nn .Linear (10 , 20 )
152
- assert is_match ("layer1" , linear , "layer1" ) == True
153
- assert is_match ("layer1" , linear , "layer2" ) == False
144
+ assert is_match ("layer1" , linear , "layer1" )
145
+ assert not is_match ("layer1" , linear , "layer2" )
154
146
155
147
def test_class_match (self ):
156
148
"""Test matching by class"""
157
149
linear = nn .Linear (10 , 20 )
158
- assert is_match ("layer1" , linear , "Linear" ) == True
159
- assert is_match ("layer1" , linear , "Conv2d" ) == False
150
+ assert is_match ("layer1" , linear , "Linear" )
151
+ assert not is_match ("layer1" , linear , "Conv2d" )
160
152
161
153
def test_combined_match (self ):
162
154
"""Test that either name or class match works"""
163
155
linear = nn .Linear (10 , 20 )
164
- assert is_match ("layer1" , linear , "layer1" ) == True # name match
165
- assert is_match ("layer1" , linear , "Linear" ) == True # class match
166
- assert is_match ("layer1" , linear , "layer2" ) == False # no match
156
+ assert is_match ("layer1" , linear , "layer1" ) # name match
157
+ assert is_match ("layer1" , linear , "Linear" ) # class match
158
+ assert not is_match ("layer1" , linear , "layer2" ) # no match
167
159
168
160
def test_regex_in_name_match (self ):
169
161
"""Test regex matching in name"""
170
162
linear = nn .Linear (10 , 20 )
171
- assert is_match ("layer1" , linear , "re:layer.*" ) == True
172
- assert is_match ("layer1" , linear , "re:conv.*" ) == False
163
+ assert is_match ("layer1" , linear , "re:layer.*" )
164
+ assert not is_match ("layer1" , linear , "re:conv.*" )
173
165
174
166
def test_internal_module_match (self ):
175
167
"""Test not matching internal modules"""
@@ -178,7 +170,7 @@ class InternalLinear(InternalModule, nn.Linear):
178
170
pass
179
171
180
172
linear = InternalLinear (10 , 20 )
181
- assert is_match ("layer1" , linear , "re:layer.*" ) == False
173
+ assert not is_match ("layer1" , linear , "re:layer.*" )
182
174
183
175
184
176
class TestMatchNamedModules :
0 commit comments