@@ -33,8 +33,8 @@ struct ValueClip {
33
33
}
34
34
};
35
35
36
- template <typename DeviceContext, typename T, bool is_test>
37
- void SoftmaxFunctor<DeviceContext, T, is_test>::operator ()(
36
+ template <typename DeviceContext, typename T, bool is_test, typename Enable >
37
+ void SoftmaxFunctor<DeviceContext, T, is_test, Enable >::operator ()(
38
38
const DeviceContext& context, const framework::Tensor* X,
39
39
framework::Tensor* Y) {
40
40
auto logits = EigenMatrix<T>::From (*X);
@@ -66,8 +66,12 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
66
66
.broadcast (one_by_class));
67
67
}
68
68
69
+ template <class DeviceContext >
70
+ using enable_if_CPU = typename std::enable_if<
71
+ std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type;
72
+
69
73
template <typename DeviceContext>
70
- class SoftmaxFunctor <DeviceContext, float , true > {
74
+ class SoftmaxFunctor <DeviceContext, float , true , enable_if_CPU<DeviceContext> > {
71
75
void operator ()(const DeviceContext& context, const framework::Tensor* X,
72
76
framework::Tensor* Y) {
73
77
auto in_dims = X->dims ();
0 commit comments