Skip to content

Commit 9b0eae3

Browse files
committed
- Removing partial specialization of softmax for inference on GPU
test=develop
1 parent be80bb4 commit 9b0eae3

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

paddle/fluid/operators/math/softmax.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ namespace paddle {
1919
namespace operators {
2020
namespace math {
2121

22-
template <typename DeviceContext, typename T, bool is_test>
22+
template <typename DeviceContext, typename T, bool is_test,
23+
typename Enable = void>
2324
class SoftmaxFunctor {
2425
public:
2526
void operator()(const DeviceContext& context, const framework::Tensor* X,

paddle/fluid/operators/math/softmax_impl.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ struct ValueClip {
3333
}
3434
};
3535

36-
template <typename DeviceContext, typename T, bool is_test>
37-
void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
36+
template <typename DeviceContext, typename T, bool is_test, typename Enable>
37+
void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
3838
const DeviceContext& context, const framework::Tensor* X,
3939
framework::Tensor* Y) {
4040
auto logits = EigenMatrix<T>::From(*X);
@@ -66,8 +66,12 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
6666
.broadcast(one_by_class));
6767
}
6868

69+
template <class DeviceContext>
70+
using enable_if_CPU = typename std::enable_if<
71+
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type;
72+
6973
template <typename DeviceContext>
70-
class SoftmaxFunctor<DeviceContext, float, true> {
74+
class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
7175
void operator()(const DeviceContext& context, const framework::Tensor* X,
7276
framework::Tensor* Y) {
7377
auto in_dims = X->dims();

0 commit comments

Comments
 (0)