Skip to content

Commit a6a01c1

Browse files
author
wanghaox
committed
add test_maxout_op framework to fluis
1 parent 63f8c5f commit a6a01c1

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import unittest
2+
import numpy as np
3+
from op_test import OpTest
4+
5+
6+
def maxout_forward_naive(input, groups,num_channels):
7+
s0, s1, s2, s3 = input.shape
8+
return np.ndarray([s0, s1 / groups, groups, s2, s3], \
9+
buffer = input, dtype=input.dtype).max(axis=(2))
10+
11+
12+
class TestMaxOutOp(OpTest):
13+
def setUp(self):
14+
self.op_type = "maxout"
15+
self.init_test_case()
16+
input = np.random.random(self.shape).astype("float32")
17+
output = self.MaxOut_forward_naive(input, self.groups,
18+
self.num_channels).astype("float32")
19+
20+
self.inputs = {'X': input}
21+
self.attrs = {'groups': self.groups, 'num_channels': self.num_channels}
22+
23+
self.outputs = {'Out': output.astype('float32')}
24+
25+
def test_check_output(self):
26+
self.check_output()
27+
28+
def test_check_grad(self):
29+
self.check_grad(['X'], 'Out')
30+
31+
def init_test_case(self):
32+
self.MaxOut_forward_naive = maxout_forward_naive
33+
self.shape = [100, 6, 2, 2]
34+
self.groups=2
35+
self.num_channels=6
36+
37+
38+
39+
40+
if __name__ == '__main__':
41+
unittest.main()

0 commit comments

Comments
 (0)