Commit f7bd0b2

Add unittests for softmax_op
1 parent: b314a69

File tree

2 files changed: +45, -7 lines


paddle/fluid/operators/softmax_op.h

Lines changed: 5 additions & 5 deletions

@@ -35,8 +35,8 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
     framework::LoDTensor flattened_x;
     framework::LoDTensor flattened_out;
-    flattened_x.ShareDataWith(*X);
-    flattened_out.ShareDataWith(*Out);
+    flattened_x.ShareDataWith(*X).Resize(flattened_dims);
+    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
 
     math::SoftmaxFunctor<DeviceContext, T>()(
         context.template device_context<DeviceContext>(), &flattened_x,
@@ -60,9 +60,9 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
     framework::LoDTensor flattened_out;
     framework::LoDTensor flattened_d_out;
     framework::LoDTensor flattened_d_x;
-    flattened_out.ShareDataWith(*Out);
-    flattened_d_out.ShareDataWith(*dOut);
-    flattened_d_x.ShareDataWith(*dX);
+    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
+    flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
+    flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
 
     math::SoftmaxGradFunctor<DeviceContext, T>()(
         context.template device_context<DeviceContext>(), &flattened_out,
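
A note on the change above: ShareDataWith only aliases the other tensor's buffer (and takes its dims), so without the chained Resize(flattened_dims) the flattened tensors kept the original N-D shape, while math::SoftmaxFunctor expects a 2-D view with the softmax axis last. A minimal NumPy sketch of the equivalent flatten-to-2D view (the Python helper name here is illustrative, not Paddle's API):

    import numpy as np

    def flatten_to_2d(x):
        # Mirrors what framework::flatten_to_2d(dims, dims.size() - 1) computes:
        # fold all leading axes into rows, keep the last axis as columns.
        return x.reshape(-1, x.shape[-1])

    x = np.random.uniform(0.1, 1, (2, 3, 4, 5)).astype(np.float32)
    flat = flatten_to_2d(x)  # shape (24, 5)
    assert flat.base is x    # same buffer, new shape -- like ShareDataWith + Resize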

python/paddle/fluid/tests/unittests/test_softmax_op.py

Lines changed: 40 additions & 2 deletions

@@ -26,15 +26,22 @@ def stable_softmax(x):
 
 
 class TestSoftmaxOp(OpTest):
+    def get_x_shape(self):
+        return [10, 10]
+
     def setUp(self):
         self.op_type = "softmax"
         self.use_cudnn = False
         self.use_mkldnn = False
         self.dtype = np.float32
         self.init_kernel_type()
+        self.shape = self.get_x_shape()
+
+        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        out = np.apply_along_axis(stable_softmax, 1,
+                                  x.reshape([-1, self.shape[-1]]))
+        out = out.reshape(self.shape)
 
-        x = np.random.uniform(0.1, 1, [10, 10]).astype(self.dtype)
-        out = np.apply_along_axis(stable_softmax, 1, x)
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
         self.attrs = {
@@ -63,13 +70,25 @@ def test_check_grad(self):
         self.check_grad(["X"], "Out", max_relative_error=0.01)
 
 
+class TestSoftmaxOp2(TestSoftmaxOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestSoftmaxCUDNNOp(TestSoftmaxOp):
     def init_kernel_type(self):
         self.use_cudnn = True
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestSoftmaxFP16Op(TestSoftmaxOp):
@@ -83,6 +102,13 @@ def test_check_output(self):
         self.check_output_with_place(place, atol=1e-3)
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestSoftmaxFP16Op2(TestSoftmaxFP16Op):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp):
@@ -97,10 +123,22 @@ def test_check_output(self):
         self.check_output_with_place(place, atol=1e-3)
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+
 class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
     def init_kernel_type(self):
         self.use_mkldnn = True
 
 
+class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+
 if __name__ == "__main__":
     unittest.main()
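
The test pattern above: TestSoftmaxOp builds its reference output by flattening the input to 2-D, applying stable_softmax row by row, and reshaping back, so each *Op2 subclass only needs to override get_x_shape to push a 4-D input through the same kernels. stable_softmax is defined earlier in the file and is not shown in this hunk; assuming the usual max-shift formulation, the new setUp logic can be checked standalone:

    import numpy as np

    def stable_softmax(x):
        # Numerically stable softmax of a 1-D slice: subtract the max
        # before exponentiating so np.exp cannot overflow.
        shiftx = x - np.max(x)
        exps = np.exp(shiftx)
        return exps / np.sum(exps)

    # The setUp computation, reproduced standalone for the 4-D case.
    shape = [2, 3, 4, 5]
    x = np.random.uniform(0.1, 1, shape).astype(np.float32)
    out = np.apply_along_axis(stable_softmax, 1, x.reshape([-1, shape[-1]]))
    out = out.reshape(shape)

    # Every length-5 slice along the last axis should sum to 1.
    assert np.allclose(out.sum(axis=-1), 1.0)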
