@@ -16,6 +16,7 @@ limitations under the License. */
16
16
17
17
#include " paddle/fluid/framework/eigen.h"
18
18
#include " paddle/fluid/framework/op_registry.h"
19
+ #include " paddle/fluid/operators/math/selected_rows_functor.h"
19
20
#include " paddle/fluid/platform/transform.h"
20
21
21
22
namespace paddle {
@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
61
62
// Forward kernel entry point: reads the `max`/`min` attributes and applies
// ClipFunctor<T>(min, max) to the data of input X, writing the result to Out.
// In this diff view, `-` lines are the removed implementation that read X/Out
// as Tensor only; `+` lines are the replacement that dispatches on the runtime
// variable type of X (LoDTensor or SelectedRows) and throws for anything else.
void Compute (const framework::ExecutionContext& context) const override {
62
63
auto max = context.Attr <T>(" max" );
63
64
auto min = context.Attr <T>(" min" );
64
// (removed) Old path: X and Out fetched as Tensor, functor run over
// [x_data, x_data + numel) into out_data.
- auto * x = context.Input <Tensor>(" X" );
65
- auto * out = context.Output <Tensor>(" Out" );
66
- T* out_data = out->mutable_data <T>(context.GetPlace ());
67
- const T* x_data = x->data <T>();
68
- int64_t numel = x->numel ();
69
- Transform<DeviceContext> trans;
70
- trans (context.template device_context <DeviceContext>(), x_data,
71
- x_data + numel, out_data, ClipFunctor<T>(min, max));
65
// (added) Fetch X as an untyped variable so its held type can be inspected.
+ auto * x_var = context.InputVar (" X" );
66
// Case 1: X holds a LoDTensor. Same sequence of calls as the removed path,
// with LoDTensor in place of Tensor.
+ if (x_var->IsType <framework::LoDTensor>()) {
67
+ auto * x = context.Input <framework::LoDTensor>(" X" );
68
+ auto * out = context.Output <framework::LoDTensor>(" Out" );
69
// Output buffer is requested for the place reported by the context.
+ T* out_data = out->mutable_data <T>(context.GetPlace ());
70
+ const T* x_data = x->data <T>();
71
+ int64_t numel = x->numel ();
72
// NOTE(review): Transform is defined outside this view; presumably it
// behaves like std::transform over [first, last) -> result. Confirm in
// paddle/fluid/platform/transform.h.
+ Transform<DeviceContext> trans;
73
+ trans (context.template device_context <DeviceContext>(), x_data,
74
+ x_data + numel, out_data, ClipFunctor<T>(min, max));
75
// Case 2: X holds a SelectedRows.
+ } else if (x_var->IsType <framework::SelectedRows>()) {
76
+ auto * x = context.Input <framework::SelectedRows>(" X" );
77
+ auto * out = context.Output <framework::SelectedRows>(" Out" );
78
// Input and output must be distinct objects: the enforce fails when the
// two pointers compare equal (in-place use).
+ PADDLE_ENFORCE_NE (x, out,
79
+ " Inplace clip is not allowed when x is SelectedRows" );
80
// Step 1: run MergeAdd from *x into out.
// NOTE(review): MergeAdd is declared in selected_rows_functor.h (not
// visible here); presumably it sums entries that share a row index so the
// functor below sees merged values -- confirm.
+ math::scatter::MergeAdd<DeviceContext, T> merge_func;
81
+ merge_func (context.template device_context <DeviceContext>(), *x, out);
82
// Step 2: apply the functor in place on out's value tensor; out_data is
// passed as both the input range and the destination.
+ auto * out_tensor = out->mutable_value ();
83
+ auto * out_data = out_tensor->data <T>();
84
+ int64_t numel = out_tensor->numel ();
85
+ Transform<DeviceContext> trans;
86
+ trans (context.template device_context <DeviceContext>(), out_data,
87
+ out_data + numel, out_data, ClipFunctor<T>(min, max));
88
// Any other variable type held by X is rejected.
+ } else {
89
+ PADDLE_THROW (" ClipOp only supports LoDTensor and SelectedRows" );
90
+ }
72
91
}
73
92
};
74
93
@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
78
97
void Compute (const framework::ExecutionContext& context) const override {
79
98
auto max = context.Attr <T>(" max" );
80
99
auto min = context.Attr <T>(" min" );
81
- auto * d_out = context.Input <Tensor>(framework::GradVarName (" Out" ));
82
- auto * d_x = context.Output <Tensor>(framework::GradVarName (" X" ));
100
+ auto * d_out =
101
+ context.Input <framework::LoDTensor>(framework::GradVarName (" Out" ));
102
+ auto * d_x =
103
+ context.Output <framework::LoDTensor>(framework::GradVarName (" X" ));
83
104
if (d_x != nullptr ) {
84
- auto * x = context.Input <Tensor >(" X" );
105
+ auto * x = context.Input <framework::LoDTensor >(" X" );
85
106
int64_t numel = d_out->numel ();
86
107
auto * d_x_data = d_x->mutable_data <T>(context.GetPlace ());
87
108
const T* d_out_data = d_out->data <T>();
0 commit comments