1313
1414class Conv2dKernelMidpointMask2d (unittest .TestCase ):
1515 def test_conv2d_kernel_midpoint_mask_odd_bool (self ):
16- mask = drop .conv2d_kernel_midpoint_mask (shape = (5 , 7 ), kernel = (3 , 3 ), device = torch_device )
16+ mask = drop .conv2d_kernel_midpoint_mask (
17+ shape = (5 , 7 ),
18+ kernel = (3 , 3 ),
19+ device = torch_device ,
20+ dtype = torch .bool ,
21+ )
1722 print (mask )
1823 assert mask .device == torch .device (torch_device )
1924 assert mask .tolist () == \
@@ -25,32 +30,6 @@ def test_conv2d_kernel_midpoint_mask_odd_bool(self):
2530 [False , False , False , False , False , False , False ],
2631 ]
2732
28- def test_conv2d_kernel_midpoint_mask_odd_float_inplace (self ):
29- mask = torch .tensor (
30- [
31- [2.0 , 1.0 , 1.0 , 1.0 , 1.0 , 7.0 , 1.0 ],
32- [1.0 , 3.0 , 1.0 , 1.0 , 1.0 , 1.0 , 8.0 ],
33- [9.0 , 1.0 , 4.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
34- [1.0 , 1.0 , 1.0 , 5.0 , 1.0 , 1.0 , 1.0 ],
35- [1.0 , 1.0 , 1.0 , 1.0 , 6.0 , 1.0 , 1.0 ],
36- ],
37- device = torch_device ,
38- )
39- drop .conv2d_kernel_midpoint_mask (
40- kernel = (3 , 3 ),
41- inplace_mask = mask ,
42- )
43- print (mask )
44- assert mask .device == torch .device (torch_device )
45- assert mask .tolist () == \
46- [
47- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
48- [0.0 , 3.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
49- [0.0 , 1.0 , 4.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
50- [0.0 , 1.0 , 1.0 , 5.0 , 1.0 , 1.0 , 0.0 ],
51- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
52- ]
53-
5433 def test_conv2d_kernel_midpoint_mask_odd_float (self ):
5534 mask = drop .conv2d_kernel_midpoint_mask (
5635 shape = (5 , 7 ),
@@ -88,10 +67,14 @@ def test_conv2d_kernel_midpoint_mask_odd_int(self):
8867 ]
8968
9069 def test_conv2d_kernel_midpoint_mask_even (self ):
91- mask = drop .conv2d_kernel_midpoint_mask (shape = (5 , 7 ), kernel = (2 , 2 ), device = torch_device )
70+ mask = drop .conv2d_kernel_midpoint_mask (
71+ shape = (5 , 7 ),
72+ kernel = (2 , 2 ),
73+ device = torch_device ,
74+ dtype = torch .bool ,
75+ )
9276 print (mask )
9377 assert mask .device == torch .device (torch_device )
94- # TODO: This is a suprising result; should even kernels be forbidden?
9578 assert mask .tolist () == \
9679 [
9780 [False , False , False , False , False , False , False ],
@@ -103,9 +86,93 @@ def test_conv2d_kernel_midpoint_mask_even(self):
10386
10487 def test_clip_mask_2d_kernel_too_big (self ):
10588 try :
106- drop .conv2d_kernel_midpoint_mask (shape = (4 , 7 ), kernel = (5 , 5 ), device = torch_device )
89+ drop .conv2d_kernel_midpoint_mask (
90+ shape = (4 , 7 ),
91+ kernel = (5 , 5 ),
92+ device = torch_device ,
93+ dtype = torch .bool ,
94+ )
10795 raise RuntimeError ("Expected throw" )
10896
10997 except AssertionError as e :
11098 assert "kernel=(5, 5) ! <= shape=(4, 7)" in e .args [0 ]
11199
100+
101+ class DropBlock2dDropFilterTest (unittest .TestCase ):
102+ def test_drop_filter (self ):
103+ selection = torch .tensor (
104+ [
105+ [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0.0 ],
106+ [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
107+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ],
108+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
109+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 ],
110+ ],
111+ device = torch_device ,
112+ ).unsqueeze (0 ).unsqueeze (0 )
113+
114+ result = drop .drop_block_2d_drop_filter_ (
115+ selection = selection ,
116+ kernel = (2 , 3 ),
117+ messy = False
118+ ).squeeze ()
119+ print (result )
120+ assert result .device == torch .device (torch_device )
121+ assert result .tolist () == \
122+ [
123+ [1.0 , 1.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
124+ [1.0 , 1.0 , 1.0 , 0.0 , 1.0 , 1.0 , 1.0 ],
125+ [0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 ],
126+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
127+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
128+ ]
129+
130+ def test_drop_filter_messy (self ):
131+ selection = torch .tensor (
132+ [
133+ [0 , 0 , 0 , 1 , 0 , 0 , 0 ],
134+ [0 , 1 , 0 , 0 , 0 , 0 , 0 ],
135+ [0 , 0 , 0 , 0 , 0 , 1 , 0 ],
136+ [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
137+ [0 , 0 , 0 , 0 , 0 , 0 , 1 ],
138+ ],
139+ device = torch_device ,
140+ dtype = torch .int32 ,
141+ ).unsqueeze (0 ).unsqueeze (0 )
142+
143+ result = drop .drop_block_2d_drop_filter_ (
144+ selection = selection ,
145+ kernel = (2 , 3 ),
146+ messy = True
147+ ).squeeze ()
148+ print (result )
149+ assert result .device == torch .device (torch_device )
150+ assert result .tolist () == \
151+ [
152+ [1 , 1 , 1 , 1 , 1 , 0 , 0 ],
153+ [1 , 1 , 1 , 0 , 1 , 1 , 1 ],
154+ [0 , 0 , 0 , 0 , 1 , 1 , 1 ],
155+ [0 , 0 , 0 , 0 , 0 , 1 , 1 ],
156+ [0 , 0 , 0 , 0 , 0 , 1 , 1 ],
157+ ]
158+
159+ class DropBlock2dTest (unittest .TestCase ):
160+ def test_drop_block_2d (self ):
161+ tensor = torch .ones ((1 , 1 , 200 , 300 ), device = torch_device )
162+
163+ drop_prob = 0.1
164+ keep_prob = 1.0 - drop_prob
165+
166+ result = drop .drop_block_2d (
167+ tensor ,
168+ drop_prob = drop_prob ,
169+ with_noise = True ,
170+ ).squeeze ()
171+
172+ numel = float (result .numel ())
173+ unchanged = float (len (result [result == 1.0 ]))
174+ keep_ratio = unchanged / numel
175+
176+ assert abs (keep_ratio - keep_prob ) < 0.05 , \
177+ f"abs({ keep_ratio = } - { keep_prob = } ) ! < 0.05"
178+
0 commit comments