@@ -16,6 +16,7 @@ limitations under the License. */
16
16
17
17
#include " paddle/fluid/framework/eigen.h"
18
18
#include " paddle/fluid/framework/op_registry.h"
19
+ #include " paddle/fluid/operators/math/selected_rows_functor.h"
19
20
#include " paddle/fluid/platform/transform.h"
20
21
21
22
namespace paddle {
@@ -31,10 +32,31 @@ class ClipByNormKernel : public framework::OpKernel<T> {
31
32
public:
32
33
void Compute (const framework::ExecutionContext& context) const override {
33
34
auto max_norm = context.Attr <T>(" max_norm" );
34
- auto * input = context.Input <Tensor> (" X" );
35
+ auto in_var = context.InputVar (" X" );
35
36
auto * output = context.Output <Tensor>(" Out" );
36
37
output->mutable_data <T>(context.GetPlace ());
37
38
39
+ const Tensor* input = nullptr ;
40
+ if (in_var->IsType <framework::LoDTensor>()) {
41
+ input = context.Input <Tensor>(" X" );
42
+ } else if (in_var->IsType <framework::SelectedRows>()) {
43
+ auto * x = context.Input <framework::SelectedRows>(" X" );
44
+
45
+ // merge ids in selected rows first
46
+ math::scatter::MergeAdd<DeviceContext, T> merge_func;
47
+ auto * merged_input = const_cast <framework::Scope&>(context.scope ())
48
+ .Var ()
49
+ ->GetMutable <framework::SelectedRows>();
50
+ merge_func (context.template device_context <DeviceContext>(), *x,
51
+ merged_input);
52
+ input = &(merged_input->value ());
53
+ } else {
54
+ PADDLE_THROW (" Unexpected branch, input variable type is %s" ,
55
+ in_var->Type ().name ());
56
+ }
57
+
58
+ PADDLE_ENFORCE_NOT_NULL (input);
59
+
38
60
auto x = EigenVector<T>::Flatten (*input);
39
61
auto out = EigenVector<T>::Flatten (*output);
40
62
auto x_norm = x.square ().sum ().sqrt ();
0 commit comments