66
77from timm .layers import drop
88
9- torch_backend = os .environ .get (' TORCH_BACKEND' )
9+ torch_backend = os .environ .get (" TORCH_BACKEND" )
1010if torch_backend is not None :
1111 importlib .import_module (torch_backend )
12- torch_device = os .environ .get ('TORCH_DEVICE' , 'cpu' )
12+ torch_device = os .environ .get ("TORCH_DEVICE" , "cpu" )
13+
1314
1415class Conv2dKernelMidpointMask2d (unittest .TestCase ):
1516 def test_conv2d_kernel_midpoint_mask_odd (self ):
@@ -21,15 +22,13 @@ def test_conv2d_kernel_midpoint_mask_odd(self):
2122 )
2223 print (mask )
2324 assert mask .device == torch .device (torch_device )
24- assert mask .tolist () == \
25- [
26- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
27- [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
28- [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
29- [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
30- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
31- ]
32-
25+ assert mask .tolist () == [
26+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
27+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
28+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
29+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
30+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
31+ ]
3332
3433 def test_conv2d_kernel_midpoint_mask_even (self ):
3534 mask = drop .conv2d_kernel_midpoint_mask (
@@ -40,14 +39,13 @@ def test_conv2d_kernel_midpoint_mask_even(self):
4039 )
4140 print (mask )
4241 assert mask .device == torch .device (torch_device )
43- assert mask .tolist () == \
44- [
45- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
46- [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
47- [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
48- [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
49- [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
50- ]
42+ assert mask .tolist () == [
43+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
44+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
45+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
46+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
47+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
48+ ]
5149
5250 def test_clip_mask_2d_kernel_too_big (self ):
5351 try :
@@ -65,67 +63,70 @@ def test_clip_mask_2d_kernel_too_big(self):
6563
6664class DropBlock2dDropFilterTest (unittest .TestCase ):
6765 def test_drop_filter (self ):
68- selection = torch .tensor (
69- [
70- [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 ],
71- [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
72- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ],
73- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
74- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 ],
75- ],
76- device = torch_device ,
77- ).unsqueeze (0 ).unsqueeze (0 )
66+ selection = (
67+ torch .tensor (
68+ [
69+ [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 ],
70+ [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
71+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ],
72+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
73+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 ],
74+ ],
75+ device = torch_device ,
76+ )
77+ .unsqueeze (0 )
78+ .unsqueeze (0 )
79+ )
7880
7981 result = drop .drop_block_2d_drop_filter_ (
80- selection = selection ,
81- kernel = (2 , 3 ),
82- partial_edge_blocks = False
82+ selection = selection , kernel = (2 , 3 ), partial_edge_blocks = False
8383 ).squeeze ()
8484 print (result )
8585 assert result .device == torch .device (torch_device )
86- assert result .tolist () == \
87- [
88- [1.0 , 1.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
89- [1.0 , 1.0 , 1.0 , 0.0 , 1.0 , 1.0 , 1.0 ],
90- [0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 ],
91- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
92- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
93- ]
86+ assert result .tolist () == [
87+ [1.0 , 1.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
88+ [1.0 , 1.0 , 1.0 , 0.0 , 1.0 , 1.0 , 1.0 ],
89+ [0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 ],
90+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
91+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
92+ ]
9493
9594 def test_drop_filter_partial_edge_blocks (self ):
96- selection = torch .tensor (
97- [
98- [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 ],
99- [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
100- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ],
101- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
102- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 ],
103- ],
104- device = torch_device ,
105- dtype = torch .float32 ,
106- ).unsqueeze (0 ).unsqueeze (0 )
95+ selection = (
96+ torch .tensor (
97+ [
98+ [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 ],
99+ [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
100+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ],
101+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
102+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 ],
103+ ],
104+ device = torch_device ,
105+ dtype = torch .float32 ,
106+ )
107+ .unsqueeze (0 )
108+ .unsqueeze (0 )
109+ )
107110
108111 result = drop .drop_block_2d_drop_filter_ (
109- selection = selection ,
110- kernel = (2 , 3 ),
111- partial_edge_blocks = True
112+ selection = selection , kernel = (2 , 3 ), partial_edge_blocks = True
112113 ).squeeze ()
113114 print (result )
114115 assert result .device == torch .device (torch_device )
115- assert result .tolist () == \
116- [
117- [1.0 , 1.0 , 1.0 , 1 .0 , 1.0 , 0 .0 , 0 .0 ],
118- [ 1 .0 , 1 .0 , 1 .0 , 0.0 , 1.0 , 1.0 , 1.0 ],
119- [0.0 , 0.0 , 0.0 , 0.0 , 1 .0 , 1.0 , 1.0 ],
120- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 ],
121- [ 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 ],
122- ]
116+ assert result .tolist () == [
117+ [1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 , 0.0 ],
118+ [1.0 , 1.0 , 1.0 , 0 .0 , 1.0 , 1 .0 , 1 .0 ],
119+ [ 0 .0 , 0 .0 , 0 .0 , 0.0 , 1.0 , 1.0 , 1.0 ],
120+ [0.0 , 0.0 , 0.0 , 0.0 , 0 .0 , 1.0 , 1.0 ],
121+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 ],
122+ ]
123+
123124
124125class DropBlock2dTest (unittest .TestCase ):
125126 def test_drop_block_2d (self ):
126127 tensor = torch .ones ((1 , 1 , 200 , 300 ), device = torch_device )
127128
128- drop_prob = 0.1
129+ drop_prob = 0.1
129130 keep_prob = 1.0 - drop_prob
130131
131132 result = drop .drop_block_2d (
@@ -138,6 +139,6 @@ def test_drop_block_2d(self):
138139 unchanged = float (len (result [result == 1.0 ]))
139140 keep_ratio = unchanged / numel
140141
141- assert abs ( keep_ratio - keep_prob ) < 0.05 , \
142- f" abs({ keep_ratio = } - { keep_prob = } ) ! < 0.05"
143-
142+ assert (
143+ abs (keep_ratio - keep_prob ) < 0.05
144+ ), f"abs( { keep_ratio = } - { keep_prob = } ) ! < 0.05"
0 commit comments