1212torch_device = os .environ .get ('TORCH_DEVICE' , 'cpu' )
1313
1414class Conv2dKernelMidpointMask2d (unittest .TestCase ):
15- def test_conv2d_kernel_midpoint_mask_odd_bool (self ):
16- mask = drop .conv2d_kernel_midpoint_mask (
17- shape = (5 , 7 ),
18- kernel = (3 , 3 ),
19- device = torch_device ,
20- dtype = torch .bool ,
21- )
22- print (mask )
23- assert mask .device == torch .device (torch_device )
24- assert mask .tolist () == \
25- [
26- [False , False , False , False , False , False , False ],
27- [False , True , True , True , True , True , False ],
28- [False , True , True , True , True , True , False ],
29- [False , True , True , True , True , True , False ],
30- [False , False , False , False , False , False , False ],
31- ]
32-
33- def test_conv2d_kernel_midpoint_mask_odd_float (self ):
15+ def test_conv2d_kernel_midpoint_mask_odd (self ):
3416 mask = drop .conv2d_kernel_midpoint_mask (
3517 shape = (5 , 7 ),
3618 kernel = (3 , 3 ),
@@ -48,23 +30,6 @@ def test_conv2d_kernel_midpoint_mask_odd_float(self):
4830 [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
4931 ]
5032
51- def test_conv2d_kernel_midpoint_mask_odd_int (self ):
52- mask = drop .conv2d_kernel_midpoint_mask (
53- shape = (5 , 7 ),
54- kernel = (3 , 3 ),
55- device = torch_device ,
56- dtype = torch .int32 ,
57- )
58- print (mask )
59- assert mask .device == torch .device (torch_device )
60- assert mask .tolist () == \
61- [
62- [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
63- [0 , 1 , 1 , 1 , 1 , 1 , 0 ],
64- [0 , 1 , 1 , 1 , 1 , 1 , 0 ],
65- [0 , 1 , 1 , 1 , 1 , 1 , 0 ],
66- [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
67- ]
6833
6934 def test_conv2d_kernel_midpoint_mask_even (self ):
7035 mask = drop .conv2d_kernel_midpoint_mask (
@@ -77,11 +42,11 @@ def test_conv2d_kernel_midpoint_mask_even(self):
7742 assert mask .device == torch .device (torch_device )
7843 assert mask .tolist () == \
7944 [
80- [False , False , False , False , False , False , False ],
81- [False , True , True , True , True , True , True ],
82- [False , True , True , True , True , True , True ],
83- [False , True , True , True , True , True , True ],
84- [False , True , True , True , True , True , True ],
45+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
46+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
47+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
48+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
49+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
8550 ]
8651
8752 def test_clip_mask_2d_kernel_too_big (self ):
@@ -130,11 +95,11 @@ def test_drop_filter(self):
13095 def test_drop_filter_messy (self ):
13196 selection = torch .tensor (
13297 [
133- [0 , 0 , 0 , 1 , 0 , 0 , 0 ],
134- [0 , 1 , 0 , 0 , 0 , 0 , 0 ],
135- [0 , 0 , 0 , 0 , 0 , 1 , 0 ],
136- [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
137- [0 , 0 , 0 , 0 , 0 , 0 , 1 ],
98+ [0.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 0. 0 ],
99+ [0.0 , 1.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0. 0 ],
100+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0. 0 ],
101+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0. 0 ],
102+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 ],
138103 ],
139104 device = torch_device ,
140105 dtype = torch .int32 ,
@@ -149,11 +114,11 @@ def test_drop_filter_messy(self):
149114 assert result .device == torch .device (torch_device )
150115 assert result .tolist () == \
151116 [
152- [1 , 1 , 1 , 1 , 1 , 0 , 0 ],
153- [1 , 1 , 1 , 0 , 1 , 1 , 1 ],
154- [0 , 0 , 0 , 0 , 1 , 1 , 1 ],
155- [0 , 0 , 0 , 0 , 0 , 1 , 1 ],
156- [0 , 0 , 0 , 0 , 0 , 1 , 1 ],
117+ [1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 , 0. 0 ],
118+ [1.0 , 1.0 , 1.0 , 0.0 , 1.0 , 1.0 , 1.0 ],
119+ [0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 ],
120+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 ],
121+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 ],
157122 ]
158123
159124class DropBlock2dTest (unittest .TestCase ):
0 commit comments