@@ -31,16 +31,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());

-    auto dims = X->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_x;
-    framework::LoDTensor flattened_out;
-    flattened_x.ShareDataWith(*X).Resize(flattened_dims);
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
+    int rank = X->dims().size();
+    Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
+    Tensor Out_2d =
+        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;

     math::SoftmaxFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &flattened_x,
-        &flattened_out);
+        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
   }
 };

@@ -55,18 +52,16 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());

-    auto dims = Out->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_out;
-    framework::LoDTensor flattened_d_out;
-    framework::LoDTensor flattened_d_x;
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
-    flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
-    flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
+    int rank = Out->dims().size();
+    Tensor Out_2d =
+        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
+    Tensor dOut_2d =
+        rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
+    Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;

     math::SoftmaxGradFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
-        &flattened_d_out, &flattened_d_x);
+        context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
+        &dX_2d);
   }
 };
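The change swaps the manual flatten_to_2d / ShareDataWith / Resize sequence for framework::ReshapeToMatrix, but the underlying idea is the same: when the input has rank greater than 2, every axis except the trailing one is folded into the row dimension, so the 2-D softmax functor runs over the last axis. Below is a minimal standalone sketch of that flattening rule and a row-wise softmax; it is not Paddle code, and the helpers flatten_to_2d_shape and softmax_rows are hypothetical names used only to illustrate the shape logic the kernels rely on.

```cpp
// Standalone illustration (hypothetical helpers, not the Paddle API):
// an N-D tensor with dims d0 x d1 x ... x d_{n-1} is viewed as a 2-D matrix
// of shape (d0*...*d_{n-2}) x d_{n-1}, and softmax is applied per row.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// 2-D shape produced by flattening all but the trailing dimension,
// mirroring what ReshapeToMatrix(*X, rank - 1) does to the view.
static std::pair<int, int> flatten_to_2d_shape(const std::vector<int>& dims) {
  int cols = dims.back();
  int rows = std::accumulate(dims.begin(), dims.end() - 1, 1,
                             std::multiplies<int>());
  return {rows, cols};
}

// Row-wise softmax on a flat buffer interpreted as a rows x cols matrix.
static void softmax_rows(std::vector<float>& data, int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    float* row = data.data() + r * cols;
    float max_v = *std::max_element(row, row + cols);
    float sum = 0.f;
    for (int c = 0; c < cols; ++c) {
      row[c] = std::exp(row[c] - max_v);  // subtract max for stability
      sum += row[c];
    }
    for (int c = 0; c < cols; ++c) row[c] /= sum;
  }
}

int main() {
  // A rank-3 input of shape 2 x 3 x 4 flattens to a 6 x 4 matrix.
  std::vector<int> dims = {2, 3, 4};
  auto shape = flatten_to_2d_shape(dims);
  std::vector<float> x(shape.first * shape.second);
  std::iota(x.begin(), x.end(), 0.f);
  softmax_rows(x, shape.first, shape.second);
  std::printf("flattened shape: %d x %d\n", shape.first, shape.second);
  return 0;
}
```

The `rank > 2 ? ... : *X` guard in the diff simply skips the reshape when the tensor is already a matrix, so the functors always see a 2-D view without copying data.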