
Commit f99ea99

Merge pull request #13720 from velconia/fix_grad_clip
Merge selected_rows for clip_by_norm op
2 parents: afdc730 + 1456b8e

3 files changed: +96 −4 lines

paddle/fluid/operators/CMakeLists.txt

2 additions, 1 deletion

@@ -230,7 +230,7 @@ if(WITH_DISTRIBUTE)
     op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   endforeach()
-
+
   #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
   #        listen_and_serv_op sum_op executor SERIAL)
@@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
 else()
     set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
 endif()
+op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows)
 op_library(sum_op DEPS selected_rows_functor)
 op_library(sgd_op DEPS selected_rows_functor)
 op_library(print_op DEPS lod_tensor)

paddle/fluid/operators/clip_by_norm_op.h

37 additions, 3 deletions

@@ -16,12 +16,15 @@ limitations under the License. */

 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/transform.h"

 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
+using SelectedRows = framework::SelectedRows;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
@@ -31,9 +34,40 @@ class ClipByNormKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto max_norm = context.Attr<T>("max_norm");
-    auto* input = context.Input<Tensor>("X");
-    auto* output = context.Output<Tensor>("Out");
-    output->mutable_data<T>(context.GetPlace());
+    auto in_var = context.InputVar("X");
+
+    Tensor* output = nullptr;
+    const Tensor* input = nullptr;
+    if (in_var->IsType<framework::LoDTensor>()) {
+      input = context.Input<Tensor>("X");
+
+      output = context.Output<Tensor>("Out");
+      output->mutable_data<T>(context.GetPlace());
+    } else if (in_var->IsType<SelectedRows>()) {
+      auto* x = context.Input<SelectedRows>("X");
+
+      // merge ids in selected rows first
+      math::scatter::MergeAdd<DeviceContext, T> merge_func;
+      SelectedRows* merged_input =
+          const_cast<framework::Scope&>(context.scope())
+              .Var()
+              ->GetMutable<SelectedRows>();
+      merge_func(context.template device_context<DeviceContext>(), *x,
+                 merged_input);
+      input = &(merged_input->value());
+
+      SelectedRows* output_selected_rows = context.Output<SelectedRows>("Out");
+      output_selected_rows->set_rows(merged_input->rows());
+      output_selected_rows->set_height(merged_input->height());
+      output = output_selected_rows->mutable_value();
+      output->Resize(merged_input->value().dims());
+      output->mutable_data<T>(context.GetPlace());
+    } else {
+      PADDLE_THROW("Unexpected branch, input variable type is %s",
+                   in_var->Type().name());
+    }
+
+    PADDLE_ENFORCE_NOT_NULL(input);

     auto x = EigenVector<T>::Flatten(*input);
     auto out = EigenVector<T>::Flatten(*output);
python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py

57 additions, 0 deletions

@@ -18,6 +18,9 @@
 import numpy as np
 from op_test import OpTest

+import paddle.fluid as fluid
+import paddle.fluid.core as core
+

 class TestClipByNormOp(OpTest):
     def setUp(self):
@@ -62,5 +65,59 @@ def initTestCase(self):
         self.max_norm = 1.0


+class TestClipByNormOpWithSelectedRows(OpTest):
+    def check_with_place(self, place):
+        self.config_test_case()
+        scope = core.Scope()
+
+        # set input
+        x_selected_rows = scope.var('X').get_selected_rows()
+        x_selected_rows.set_rows(self.grad_rows)
+        x_tensor = x_selected_rows.get_tensor()
+        x_np = np.random.random(self.grad_shape).astype("float32")
+        x_np[np.abs(x_np) < self.max_relative_error] = 0.5
+        x_tensor.set(x_np, place)
+
+        # set output
+        out_selected_rows = scope.var('Out').get_selected_rows()
+
+        # run clip_by_norm_op
+        clip_by_norm_op = fluid.op.Operator(
+            "clip_by_norm", max_norm=self.max_norm, X='X', Out='Out')
+        clip_by_norm_op.run(scope, place)
+
+        # check output
+        self.assertEqual(out_selected_rows.rows(), self.grad_clipped_rows)
+        out_tensor = out_selected_rows.get_tensor()
+        y_np = np.zeros(self.grad_clipped_shape)
+        y_np[0] = np.sum(x_np[0:2])
+        y_np[1] = x_np[2]
+        y_np[2] = x_np[3]
+        norm = np.sqrt(np.sum(np.square(y_np)))
+        if norm > self.max_norm:
+            output = self.max_norm * y_np / norm
+        else:
+            output = y_np
+        self.assertTrue(
+            np.allclose(
+                np.array(out_tensor), output, atol=1e-5, equal_nan=False))
+
+    def test_clip_by_norm_with_selected_rows(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+
+        for place in places:
+            self.check_with_place(place)
+
+    def config_test_case(self):
+        self.max_norm = 1.0
+        self.max_relative_error = 0.006
+        self.grad_shape = (4, 1)
+        self.grad_clipped_shape = (3, 1)
+        self.grad_rows = [0, 0, 1, 2]
+        self.grad_clipped_rows = [0, 1, 2]
+
+
 if __name__ == '__main__':
     unittest.main()
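
The branch name (fix_grad_clip) points at the end-to-end use case: gradient clipping over sparse gradients. A hedged sketch of how user code reaches this kernel, assuming the fluid APIs of this era (GradientClipByNorm lowers to the clip_by_norm op during minimize(), and an is_sparse embedding produces SelectedRows gradients); treat the exact calls as illustrative:

import paddle.fluid as fluid

# An is_sparse embedding yields SelectedRows gradients, which is what the
# merged-rows branch above handles.
ids = fluid.layers.data(name='ids', shape=[1], dtype='int64')
emb = fluid.layers.embedding(input=ids, size=[10000, 64], is_sparse=True)
loss = fluid.layers.reduce_sum(emb)

# Configure clipping before minimize(); GradientClipByNorm is rewritten into
# clip_by_norm ops on each gradient, sparse or dense.
fluid.clip.set_gradient_clip(fluid.clip.GradientClipByNorm(clip_norm=1.0))
fluid.optimizer.SGD(learning_rate=0.1).minimize(loss)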
