1111 importlib .import_module (torch_backend )
1212torch_device = os .environ .get ('TORCH_DEVICE' , 'cpu' )
1313
14- class ClipMaskTests (unittest .TestCase ):
15- def test_clip_mask_2d_odd (self ):
16- mask = drop .clip_mask_2d ( h = 5 , w = 7 , kernel = 3 , device = torch_device )
14+ class Conv2dKernelMidpointMask2d (unittest .TestCase ):
15+ def test_conv2d_kernel_midpoint_mask_odd_bool (self ):
16+ mask = drop .conv2d_kernel_midpoint_mask ( shape = ( 5 , 7 ) , kernel = ( 3 , 3 ) , device = torch_device )
1717 print (mask )
1818 assert mask .device == torch .device (torch_device )
1919 assert mask .tolist () == \
@@ -25,8 +25,44 @@ def test_clip_mask_2d_odd(self):
2525 [False , False , False , False , False , False , False ],
2626 ]
2727
28- def test_clip_mask_2d_even (self ):
29- mask = drop .clip_mask_2d (h = 5 , w = 7 , kernel = 2 , device = torch_device )
28+ def test_conv2d_kernel_midpoint_mask_odd_float (self ):
29+ mask = drop .conv2d_kernel_midpoint_mask (
30+ shape = (5 , 7 ),
31+ kernel = (3 , 3 ),
32+ device = torch_device ,
33+ dtype = torch .float32 ,
34+ )
35+ print (mask )
36+ assert mask .device == torch .device (torch_device )
37+ assert mask .tolist () == \
38+ [
39+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
40+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
41+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
42+ [0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
43+ [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ],
44+ ]
45+
46+ def test_conv2d_kernel_midpoint_mask_odd_int (self ):
47+ mask = drop .conv2d_kernel_midpoint_mask (
48+ shape = (5 , 7 ),
49+ kernel = (3 , 3 ),
50+ device = torch_device ,
51+ dtype = torch .int32 ,
52+ )
53+ print (mask )
54+ assert mask .device == torch .device (torch_device )
55+ assert mask .tolist () == \
56+ [
57+ [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
58+ [0 , 1 , 1 , 1 , 1 , 1 , 0 ],
59+ [0 , 1 , 1 , 1 , 1 , 1 , 0 ],
60+ [0 , 1 , 1 , 1 , 1 , 1 , 0 ],
61+ [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
62+ ]
63+
64+ def test_conv2d_kernel_midpoint_mask_even (self ):
65+ mask = drop .conv2d_kernel_midpoint_mask (shape = (5 , 7 ), kernel = (2 , 2 ), device = torch_device )
3066 print (mask )
3167 assert mask .device == torch .device (torch_device )
3268 # TODO: This is a suprising result; should even kernels be forbidden?
@@ -41,9 +77,9 @@ def test_clip_mask_2d_even(self):
4177
4278 def test_clip_mask_2d_kernel_too_big (self ):
4379 try :
44- drop .clip_mask_2d ( h = 4 , w = 7 , kernel = 5 , device = torch_device )
80+ drop .conv2d_kernel_midpoint_mask ( shape = ( 4 , 7 ) , kernel = ( 5 , 5 ) , device = torch_device )
4581 raise RuntimeError ("Expected throw" )
4682
4783 except AssertionError as e :
48- assert "kernel=5 > min(h= 4, w= 7)" in e .args [0 ]
84+ assert "kernel=(5, 5) ! <= shape=( 4, 7)" in e .args [0 ]
4985
0 commit comments