Skip to content

Commit 5bd7c82

Browse files
authored
[Cherry-pick] Double grad for clip op #31109
Cherry-pick double grad for clip
1 parent 8177ece commit 5bd7c82

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

paddle/fluid/operators/clip_op.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,29 @@ DECLARE_INPLACE_OP_INFERER(ClipGradInplaceInferer,
109109
{framework::GradVarName("Out"),
110110
framework::GradVarName("X")});
111111

112+
template <typename T>
113+
class ClipDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
114+
public:
115+
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
116+
117+
protected:
118+
void Apply(GradOpPtr<T> op) const override {
119+
op->SetType("clip_grad");
120+
op->SetInput("X", this->Input("X"));
121+
if (this->HasInput("Min")) {
122+
op->SetInput("Min", this->Input("Min"));
123+
}
124+
if (this->HasInput("Max")) {
125+
op->SetInput("Max", this->Input("Max"));
126+
}
127+
op->SetInput(framework::GradVarName("Out"),
128+
this->OutputGrad(framework::GradVarName("X")));
129+
op->SetOutput(framework::GradVarName("X"),
130+
this->InputGrad(framework::GradVarName("Out")));
131+
op->SetAttrMap(this->Attrs());
132+
}
133+
};
134+
112135
} // namespace operators
113136
} // namespace paddle
114137

@@ -117,7 +140,9 @@ REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
117140
ops::ClipGradOpMaker<paddle::framework::OpDesc>,
118141
ops::ClipGradOpMaker<paddle::imperative::OpBase>,
119142
ops::ClipInplaceInferer);
120-
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer);
143+
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer,
144+
ops::ClipDoubleGradOpMaker<paddle::framework::OpDesc>,
145+
ops::ClipDoubleGradOpMaker<paddle::imperative::OpBase>);
121146
REGISTER_OP_CPU_KERNEL(
122147
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>,
123148
ops::ClipKernel<paddle::platform::CPUDeviceContext, double>);

python/paddle/fluid/tests/unittests/test_nn_grad.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,27 @@ def test_grad(self):
329329
self.func(p)
330330

331331

332+
class TestClipDoubleGradCheck(unittest.TestCase):
333+
@prog_scope()
334+
def func(self, place):
335+
x_shape = [2, 4, 10]
336+
dtype = np.float64
337+
338+
x = layers.data('x', x_shape, False, dtype)
339+
x.persistable = True
340+
out = paddle.clip(x, min=-1., max=1.)
341+
x_arr = np.random.uniform(-5., 5., x_shape).astype(dtype)
342+
343+
gradient_checker.double_grad_check([x], out, x_init=x_arr, place=place)
344+
345+
def test_grad(self):
346+
places = [fluid.CPUPlace()]
347+
if core.is_compiled_with_cuda():
348+
places.append(fluid.CUDAPlace(0))
349+
for p in places:
350+
self.func(p)
351+
352+
332353
class TestTransposeDoubleGradCheck(unittest.TestCase):
333354
@prog_scope()
334355
def func(self, place):

0 commit comments

Comments
 (0)