@@ -16,13 +16,15 @@ limitations under the License. */
16
16
17
17
#include " paddle/fluid/framework/eigen.h"
18
18
#include " paddle/fluid/framework/op_registry.h"
19
+ #include " paddle/fluid/framework/selected_rows.h"
19
20
#include " paddle/fluid/operators/math/selected_rows_functor.h"
20
21
#include " paddle/fluid/platform/transform.h"
21
22
22
23
namespace paddle {
23
24
namespace operators {
24
25
25
26
using Tensor = framework::Tensor;
27
+ using SelectedRows = framework::SelectedRows;
26
28
template <typename T, int MajorType = Eigen::RowMajor,
27
29
typename IndexType = Eigen::DenseIndex>
28
30
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
@@ -41,22 +43,24 @@ class ClipByNormKernel : public framework::OpKernel<T> {
41
43
42
44
output = context.Output <Tensor>(" Out" );
43
45
output->mutable_data <T>(context.GetPlace ());
44
- } else if (in_var->IsType <framework:: SelectedRows>()) {
45
- auto * x = context.Input <framework:: SelectedRows>(" X" );
46
+ } else if (in_var->IsType <SelectedRows>()) {
47
+ auto * x = context.Input <SelectedRows>(" X" );
46
48
47
49
// merge ids in selected rows first
48
50
math::scatter::MergeAdd<DeviceContext, T> merge_func;
49
- auto * merged_input = const_cast <framework::Scope&>(context.scope ())
50
- .Var ()
51
- ->GetMutable <framework::SelectedRows>();
51
+ SelectedRows* merged_input =
52
+ const_cast <framework::Scope&>(context.scope ())
53
+ .Var ()
54
+ ->GetMutable <SelectedRows>();
52
55
merge_func (context.template device_context <DeviceContext>(), *x,
53
56
merged_input);
54
57
input = &(merged_input->value ());
55
58
56
- auto * output_selected_rows = context.Output <SelectedRows>(" Out" );
57
- output_selected_rows->set_rows (merged_input.rows ());
58
- output = output_selected_rows->mutable_data ();
59
- output->Resize (framework::make_ddim (merged_input.value ().dims ()));
59
+ SelectedRows* output_selected_rows = context.Output <SelectedRows>(" Out" );
60
+ output_selected_rows->set_rows (merged_input->rows ());
61
+ output_selected_rows->set_height (merged_input->height ());
62
+ output = output_selected_rows->mutable_value ();
63
+ output->Resize (merged_input->value ().dims ());
60
64
} else {
61
65
PADDLE_THROW (" Unexpected branch, input variable type is %s" ,
62
66
in_var->Type ().name ());
0 commit comments