Skip to content

Commit 6730882

Browse files
committed
Add selected_rows merge for clip_by_norm op
test=develop
1 parent 2f5a7cc commit 6730882

File tree

2 files changed: 25 additions, 2 deletions

2 files changed: 25 additions, 2 deletions

paddle/fluid/operators/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -229,7 +229,7 @@ if(WITH_DISTRIBUTE)
229229
op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS})
230230
set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
231231
endforeach()
232-
232+
233233
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
234234
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
235235
# listen_and_serv_op sum_op executor SERIAL)
@@ -267,6 +267,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
267267
else()
268268
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
269269
endif()
270+
op_library(clip_by_norm_op DEPS selected_rows_functor)
270271
op_library(sum_op DEPS selected_rows_functor)
271272
op_library(sgd_op DEPS selected_rows_functor)
272273
op_library(print_op DEPS lod_tensor)

paddle/fluid/operators/clip_by_norm_op.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616

1717
#include "paddle/fluid/framework/eigen.h"
1818
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/math/selected_rows_functor.h"
1920
#include "paddle/fluid/platform/transform.h"
2021

2122
namespace paddle {
@@ -31,10 +32,31 @@ class ClipByNormKernel : public framework::OpKernel<T> {
3132
public:
3233
void Compute(const framework::ExecutionContext& context) const override {
3334
auto max_norm = context.Attr<T>("max_norm");
34-
auto* input = context.Input<Tensor>("X");
35+
auto in_var = context.InputVar("X");
3536
auto* output = context.Output<Tensor>("Out");
3637
output->mutable_data<T>(context.GetPlace());
3738

39+
const Tensor* input = nullptr;
40+
if (in_var->IsType<framework::LoDTensor>()) {
41+
input = context.Input<Tensor>("X");
42+
} else if (in_var->IsType<framework::SelectedRows>()) {
43+
auto* x = context.Input<framework::SelectedRows>("X");
44+
45+
// merge ids in selected rows first
46+
math::scatter::MergeAdd<DeviceContext, T> merge_func;
47+
auto* merged_input = const_cast<framework::Scope&>(context.scope())
48+
.Var()
49+
->GetMutable<framework::SelectedRows>();
50+
merge_func(context.template device_context<DeviceContext>(), *x,
51+
merged_input);
52+
input = &(merged_input->value());
53+
} else {
54+
PADDLE_THROW("Unexpected branch, input variable type is %s",
55+
in_var->Type().name());
56+
}
57+
58+
PADDLE_ENFORCE_NOT_NULL(input);
59+
3860
auto x = EigenVector<T>::Flatten(*input);
3961
auto out = EigenVector<T>::Flatten(*output);
4062
auto x_norm = x.square().sum().sqrt();

0 commit comments

Comments
 (0)