Skip to content

Commit a7a3186

Browse files
committed
format
1 parent ea5d119 commit a7a3186

File tree

1 file changed

+67
-66
lines changed

1 file changed

+67
-66
lines changed

tests/layers/test_drop.py

Lines changed: 67 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
from timm.layers import drop
88

9-
torch_backend = os.environ.get('TORCH_BACKEND')
9+
torch_backend = os.environ.get("TORCH_BACKEND")
1010
if torch_backend is not None:
1111
importlib.import_module(torch_backend)
12-
torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
12+
torch_device = os.environ.get("TORCH_DEVICE", "cpu")
13+
1314

1415
class Conv2dKernelMidpointMask2d(unittest.TestCase):
1516
def test_conv2d_kernel_midpoint_mask_odd(self):
@@ -21,15 +22,13 @@ def test_conv2d_kernel_midpoint_mask_odd(self):
2122
)
2223
print(mask)
2324
assert mask.device == torch.device(torch_device)
24-
assert mask.tolist() == \
25-
[
26-
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
27-
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
28-
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
29-
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
30-
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
31-
]
32-
25+
assert mask.tolist() == [
26+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
27+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
28+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
29+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
30+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
31+
]
3332

3433
def test_conv2d_kernel_midpoint_mask_even(self):
3534
mask = drop.conv2d_kernel_midpoint_mask(
@@ -40,14 +39,13 @@ def test_conv2d_kernel_midpoint_mask_even(self):
4039
)
4140
print(mask)
4241
assert mask.device == torch.device(torch_device)
43-
assert mask.tolist() == \
44-
[
45-
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
46-
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
47-
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
48-
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
49-
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
50-
]
42+
assert mask.tolist() == [
43+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
44+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
45+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
46+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
47+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
48+
]
5149

5250
def test_clip_mask_2d_kernel_too_big(self):
5351
try:
@@ -65,67 +63,70 @@ def test_clip_mask_2d_kernel_too_big(self):
6563

6664
class DropBlock2dDropFilterTest(unittest.TestCase):
6765
def test_drop_filter(self):
68-
selection = torch.tensor(
69-
[
70-
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
71-
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
72-
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
73-
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
74-
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
75-
],
76-
device=torch_device,
77-
).unsqueeze(0).unsqueeze(0)
66+
selection = (
67+
torch.tensor(
68+
[
69+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
70+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
71+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
72+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
73+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
74+
],
75+
device=torch_device,
76+
)
77+
.unsqueeze(0)
78+
.unsqueeze(0)
79+
)
7880

7981
result = drop.drop_block_2d_drop_filter_(
80-
selection=selection,
81-
kernel=(2, 3),
82-
partial_edge_blocks=False
82+
selection=selection, kernel=(2, 3), partial_edge_blocks=False
8383
).squeeze()
8484
print(result)
8585
assert result.device == torch.device(torch_device)
86-
assert result.tolist() == \
87-
[
88-
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
89-
[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
90-
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
91-
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
92-
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
93-
]
86+
assert result.tolist() == [
87+
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
88+
[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
89+
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
90+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
91+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
92+
]
9493

9594
def test_drop_filter_partial_edge_blocks(self):
96-
selection = torch.tensor(
97-
[
98-
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
99-
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
100-
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
101-
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
102-
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
103-
],
104-
device=torch_device,
105-
dtype=torch.float32,
106-
).unsqueeze(0).unsqueeze(0)
95+
selection = (
96+
torch.tensor(
97+
[
98+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
99+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
100+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
101+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
102+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
103+
],
104+
device=torch_device,
105+
dtype=torch.float32,
106+
)
107+
.unsqueeze(0)
108+
.unsqueeze(0)
109+
)
107110

108111
result = drop.drop_block_2d_drop_filter_(
109-
selection=selection,
110-
kernel=(2, 3),
111-
partial_edge_blocks=True
112+
selection=selection, kernel=(2, 3), partial_edge_blocks=True
112113
).squeeze()
113114
print(result)
114115
assert result.device == torch.device(torch_device)
115-
assert result.tolist() == \
116-
[
117-
[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
118-
[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
119-
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
120-
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
121-
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
122-
]
116+
assert result.tolist() == [
117+
[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
118+
[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
119+
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
120+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
121+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
122+
]
123+
123124

124125
class DropBlock2dTest(unittest.TestCase):
125126
def test_drop_block_2d(self):
126127
tensor = torch.ones((1, 1, 200, 300), device=torch_device)
127128

128-
drop_prob=0.1
129+
drop_prob = 0.1
129130
keep_prob = 1.0 - drop_prob
130131

131132
result = drop.drop_block_2d(
@@ -138,6 +139,6 @@ def test_drop_block_2d(self):
138139
unchanged = float(len(result[result == 1.0]))
139140
keep_ratio = unchanged / numel
140141

141-
assert abs(keep_ratio - keep_prob) < 0.05, \
142-
f"abs({keep_ratio=} - {keep_prob=}) ! < 0.05"
143-
142+
assert (
143+
abs(keep_ratio - keep_prob) < 0.05
144+
), f"abs({keep_ratio=} - {keep_prob=}) ! < 0.05"

0 commit comments

Comments
 (0)