@@ -109,6 +109,29 @@ DECLARE_INPLACE_OP_INFERER(ClipGradInplaceInferer,
109
109
{framework::GradVarName (" Out" ),
110
110
framework::GradVarName (" X" )});
111
111
112
+ template <typename T>
113
+ class ClipDoubleGradOpMaker : public framework ::SingleGradOpMaker<T> {
114
+ public:
115
+ using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
116
+
117
+ protected:
118
+ void Apply (GradOpPtr<T> op) const override {
119
+ op->SetType (" clip_grad" );
120
+ op->SetInput (" X" , this ->Input (" X" ));
121
+ if (this ->HasInput (" Min" )) {
122
+ op->SetInput (" Min" , this ->Input (" Min" ));
123
+ }
124
+ if (this ->HasInput (" Max" )) {
125
+ op->SetInput (" Max" , this ->Input (" Max" ));
126
+ }
127
+ op->SetInput (framework::GradVarName (" Out" ),
128
+ this ->OutputGrad (framework::GradVarName (" X" )));
129
+ op->SetOutput (framework::GradVarName (" X" ),
130
+ this ->InputGrad (framework::GradVarName (" Out" )));
131
+ op->SetAttrMap (this ->Attrs ());
132
+ }
133
+ };
134
+
112
135
} // namespace operators
113
136
} // namespace paddle
114
137
@@ -117,7 +140,9 @@ REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
117
140
ops::ClipGradOpMaker<paddle::framework::OpDesc>,
118
141
ops::ClipGradOpMaker<paddle::imperative::OpBase>,
119
142
ops::ClipInplaceInferer);
120
- REGISTER_OPERATOR (clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer);
143
+ REGISTER_OPERATOR (clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer,
144
+ ops::ClipDoubleGradOpMaker<paddle::framework::OpDesc>,
145
+ ops::ClipDoubleGradOpMaker<paddle::imperative::OpBase>);
121
146
REGISTER_OP_CPU_KERNEL (
122
147
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float >,
123
148
ops::ClipKernel<paddle::platform::CPUDeviceContext, double >);
0 commit comments