@@ -28,7 +28,7 @@ def unpack_indices(dual_object):
2828
2929@flow .unittest .skip_unless_1n1d ()
3030class TestMaxPooling (flow .unittest .TestCase ):
31- @autotest (n = 100 , auto_backward = False )
31+ @autotest (auto_backward = False )
3232 def test_maxpool1d_with_random_data (test_case ):
3333 return_indices = random ().to (bool ).value ()
3434 m = torch .nn .MaxPool1d (
@@ -50,7 +50,7 @@ def test_maxpool1d_with_random_data(test_case):
5050 else :
5151 return y , y .sum ().backward ()
5252
53- @autotest (n = 100 , auto_backward = False )
53+ @autotest (auto_backward = False )
5454 def test_maxpool2d_with_random_data (test_case ):
5555 return_indices = random ().to (bool ).value ()
5656 m = torch .nn .MaxPool2d (
@@ -74,7 +74,7 @@ def test_maxpool2d_with_random_data(test_case):
7474 else :
7575 return y , y .sum ().backward ()
7676
77- @autotest (n = 100 , auto_backward = False )
77+ @autotest (auto_backward = False )
7878 def test_maxpool3d_with_random_data (test_case ):
7979 return_indices = random ().to (bool ).value ()
8080 m = torch .nn .MaxPool3d (
@@ -99,5 +99,72 @@ def test_maxpool3d_with_random_data(test_case):
9999 return y , y .sum ().backward ()
100100
101101
102+ @flow .unittest .skip_unless_1n1d ()
103+ class TestMaxPoolingFunctional (flow .unittest .TestCase ):
104+ @autotest (auto_backward = False )
105+ def test_maxpool1d_with_random_data (test_case ):
106+ return_indices = random ().to (bool ).value ()
107+ device = random_device ()
108+ x = random_pytorch_tensor (ndim = 3 , dim2 = random (20 , 22 )).to (device )
109+ y = torch .nn .functional .max_pool1d (
110+ x ,
111+ kernel_size = random (4 , 6 ).to (int ),
112+ stride = random (1 , 3 ).to (int ) | nothing (),
113+ padding = random (1 , 3 ).to (int ) | nothing (),
114+ dilation = random (2 , 4 ).to (int ) | nothing (),
115+ ceil_mode = random ().to (bool ),
116+ return_indices = return_indices ,
117+ )
118+
119+ if return_indices :
120+ return unpack_indices (y )
121+ else :
122+ return y , y .sum ().backward ()
123+
124+ @autotest (auto_backward = False )
125+ def test_maxpool2d_with_random_data (test_case ):
126+ return_indices = random ().to (bool ).value ()
127+ device = random_device ()
128+ x = random_pytorch_tensor (ndim = 4 , dim2 = random (20 , 22 ), dim3 = random (20 , 22 )).to (
129+ device
130+ )
131+ y = torch .nn .functional .max_pool2d (
132+ x ,
133+ kernel_size = random (4 , 6 ).to (int ),
134+ stride = random (1 , 3 ).to (int ) | nothing (),
135+ padding = random (1 , 3 ).to (int ) | nothing (),
136+ dilation = random (2 , 4 ).to (int ) | nothing (),
137+ ceil_mode = random ().to (bool ),
138+ return_indices = return_indices ,
139+ )
140+
141+ if return_indices :
142+ return unpack_indices (y )
143+ else :
144+ return y , y .sum ().backward ()
145+
146+ @autotest (auto_backward = False )
147+ def test_maxpool3d_with_random_data (test_case ):
148+ return_indices = random ().to (bool ).value ()
149+ device = random_device ()
150+ x = random_pytorch_tensor (
151+ ndim = 5 , dim2 = random (20 , 22 ), dim3 = random (20 , 22 ), dim4 = random (20 , 22 )
152+ ).to (device )
153+ y = torch .nn .functional .max_pool3d (
154+ x ,
155+ kernel_size = random (4 , 6 ).to (int ),
156+ stride = random (1 , 3 ).to (int ) | nothing (),
157+ padding = random (1 , 3 ).to (int ) | nothing (),
158+ dilation = random (2 , 4 ).to (int ) | nothing (),
159+ ceil_mode = random ().to (bool ),
160+ return_indices = return_indices ,
161+ )
162+
163+ if return_indices :
164+ return unpack_indices (y )
165+ else :
166+ return y , y .sum ().backward ()
167+
168+
102169if __name__ == "__main__" :
103170 unittest .main ()
0 commit comments