Skip to content

Commit 03560cf

Browse files
committed
Fix explicit comparison to True/False
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent baa97ab commit 03560cf

File tree

1 file changed

+38
-46
lines changed

1 file changed

+38
-46
lines changed

tests/test_utils/test_match.py

Lines changed: 38 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -70,43 +70,35 @@ class TestMatchName:
7070

7171
def test_exact_match(self):
7272
"""Test exact string matching"""
73-
assert _match_name("layer1", "layer1") == True
74-
assert _match_name("layer1", "layer2") == False
75-
assert (
76-
_match_name(
77-
"transformer.layers.0.self_attn.q_proj",
78-
"transformer.layers.0.self_attn.q_proj",
79-
)
80-
== True
73+
assert _match_name("layer1", "layer1")
74+
assert not _match_name("layer1", "layer2")
75+
assert _match_name(
76+
"transformer.layers.0.self_attn.q_proj",
77+
"transformer.layers.0.self_attn.q_proj",
8178
)
8279

8380
def test_regex_match(self):
8481
"""Test regex matching with "re:" prefix"""
85-
assert _match_name("layer1", "re:layer.*") == True
86-
assert _match_name("layer1", "re:^layer1$") == True
87-
assert _match_name("layer1", "re:layer2") == False
88-
assert (
89-
_match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj") == True
90-
)
91-
assert (
92-
_match_name(
93-
"transformer.layers.0.self_attn.q_proj",
94-
"re:transformer\\.layers\\.\\d+\\.self_attn\\..*_proj$",
95-
)
96-
== True
82+
assert _match_name("layer1", "re:layer.*")
83+
assert _match_name("layer1", "re:^layer1$")
84+
assert not _match_name("layer1", "re:layer2")
85+
assert _match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj")
86+
assert _match_name(
87+
"transformer.layers.0.self_attn.q_proj",
88+
"re:transformer\\.layers\\.\\d+\\.self_attn\\..*_proj$",
9789
)
9890

9991
def test_empty_strings(self):
10092
"""Test edge cases with empty strings"""
101-
assert _match_name("", "") == True
102-
assert _match_name("layer1", "") == False
103-
assert _match_name("", "layer1") == False
93+
assert _match_name("", "")
94+
assert not _match_name("layer1", "")
95+
assert not _match_name("", "layer1")
10496

10597
def test_regex_special_characters(self):
10698
"""Test regex with special characters"""
107-
assert _match_name("layer.1", "re:layer\\.1") == True
108-
assert _match_name("layer.1", "re:layer.1") == True # . matches any char
109-
assert _match_name("layer_1", "re:layer_1") == True
99+
assert _match_name("layer.1", "re:layer\\.1")
100+
assert _match_name("layer.1", "re:layer.1") # . matches any char
101+
assert _match_name("layer_1", "re:layer_1")
110102

111103

112104
class TestMatchClass:
@@ -115,32 +107,32 @@ class TestMatchClass:
115107
def test_direct_class_match(self):
116108
"""Test matching direct class names"""
117109
linear = nn.Linear(10, 20)
118-
assert _match_class(linear, "Linear") == True
119-
assert _match_class(linear, "Conv2d") == False
110+
assert _match_class(linear, "Linear")
111+
assert not _match_class(linear, "Conv2d")
120112

121113
norm = nn.LayerNorm(10)
122-
assert _match_class(norm, "LayerNorm") == True
123-
assert _match_class(norm, "BatchNorm1d") == False
114+
assert _match_class(norm, "LayerNorm")
115+
assert not _match_class(norm, "BatchNorm1d")
124116

125117
def test_parent_class_match(self):
126118
"""Test matching parent class names"""
127119
linear = nn.Linear(10, 20)
128-
assert _match_class(linear, "Module") == True
120+
assert _match_class(linear, "Module")
129121

130122
conv = nn.Conv2d(3, 16, 3)
131-
assert _match_class(conv, "Module") == True
132-
assert _match_class(conv, "_ConvNd") == True
123+
assert _match_class(conv, "Module")
124+
assert _match_class(conv, "_ConvNd")
133125

134126
def test_non_torch_module(self):
135127
"""Test with non-torch modules"""
136128
regular_object = object()
137-
assert _match_class(regular_object, "object") == False # not a torch.nn.Module
129+
assert not _match_class(regular_object, "object") # not a torch.nn.Module
138130

139131
def test_custom_module(self):
140132
"""Test with custom module classes"""
141133
model = DummyModel()
142-
assert _match_class(model, "DummyModel") == True
143-
assert _match_class(model, "Module") == True
134+
assert _match_class(model, "DummyModel")
135+
assert _match_class(model, "Module")
144136

145137

146138
class TestIsMatch:
@@ -149,27 +141,27 @@ class TestIsMatch:
149141
def test_name_match(self):
150142
"""Test matching by name"""
151143
linear = nn.Linear(10, 20)
152-
assert is_match("layer1", linear, "layer1") == True
153-
assert is_match("layer1", linear, "layer2") == False
144+
assert is_match("layer1", linear, "layer1")
145+
assert not is_match("layer1", linear, "layer2")
154146

155147
def test_class_match(self):
156148
"""Test matching by class"""
157149
linear = nn.Linear(10, 20)
158-
assert is_match("layer1", linear, "Linear") == True
159-
assert is_match("layer1", linear, "Conv2d") == False
150+
assert is_match("layer1", linear, "Linear")
151+
assert not is_match("layer1", linear, "Conv2d")
160152

161153
def test_combined_match(self):
162154
"""Test that either name or class match works"""
163155
linear = nn.Linear(10, 20)
164-
assert is_match("layer1", linear, "layer1") == True # name match
165-
assert is_match("layer1", linear, "Linear") == True # class match
166-
assert is_match("layer1", linear, "layer2") == False # no match
156+
assert is_match("layer1", linear, "layer1") # name match
157+
assert is_match("layer1", linear, "Linear") # class match
158+
assert not is_match("layer1", linear, "layer2") # no match
167159

168160
def test_regex_in_name_match(self):
169161
"""Test regex matching in name"""
170162
linear = nn.Linear(10, 20)
171-
assert is_match("layer1", linear, "re:layer.*") == True
172-
assert is_match("layer1", linear, "re:conv.*") == False
163+
assert is_match("layer1", linear, "re:layer.*")
164+
assert not is_match("layer1", linear, "re:conv.*")
173165

174166
def test_internal_module_match(self):
175167
"""Test not matching internal modules"""
@@ -178,7 +170,7 @@ class InternalLinear(InternalModule, nn.Linear):
178170
pass
179171

180172
linear = InternalLinear(10, 20)
181-
assert is_match("layer1", linear, "re:layer.*") == False
173+
assert not is_match("layer1", linear, "re:layer.*")
182174

183175

184176
class TestMatchNamedModules:

0 commit comments

Comments
 (0)