Skip to content

Commit 4143a1c

Browse files
authored
Merge pull request #16491 from sneaxiy/feature/advance_gc
Fix grad op makers
2 parents 2265d09 + 4c8254e commit 4143a1c

22 files changed

+443
-57
lines changed

paddle/fluid/operators/bpr_loss_op.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/bpr_loss_op.h"
16+
#include <memory>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -127,14 +128,31 @@ neural networks>(https://arxiv.org/abs/1511.06939)
127128
)DOC");
128129
}
129130
};
131+
132+
class BprLossGradDescMaker : public framework::SingleGradOpDescMaker {
133+
public:
134+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
135+
136+
protected:
137+
std::unique_ptr<framework::OpDesc> Apply() const override {
138+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
139+
op->SetType("bpr_loss_grad");
140+
op->SetInput("X", Input("X"));
141+
op->SetInput("Label", Input("Label"));
142+
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
143+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
144+
op->SetAttrMap(Attrs());
145+
return op;
146+
}
147+
};
130148
} // namespace operators
131149
} // namespace paddle
132150

133151
namespace ops = paddle::operators;
134152
using CPUCtx = paddle::platform::CPUDeviceContext;
135153

136154
REGISTER_OPERATOR(bpr_loss, ops::BprLossOp, ops::BprLossOpMaker,
137-
paddle::framework::DefaultGradOpDescMaker<true>);
155+
ops::BprLossGradDescMaker);
138156
REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp);
139157
REGISTER_OP_CPU_KERNEL(bpr_loss, ops::BprLossOpKernel<CPUCtx, float>,
140158
ops::BprLossOpKernel<CPUCtx, double>);

paddle/fluid/operators/detection/roi_perspective_transform_op.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include <algorithm>
16+
#include <memory>
1617
#include <vector>
1718
#include "paddle/fluid/framework/op_registry.h"
1819
#include "paddle/fluid/operators/math/math_function.h"
@@ -568,13 +569,31 @@ class ROIPerspectiveTransformOpMaker
568569
}
569570
};
570571

572+
class ROIPerspectiveTransformGradDescMaker
573+
: public framework::SingleGradOpDescMaker {
574+
public:
575+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
576+
577+
protected:
578+
std::unique_ptr<framework::OpDesc> Apply() const override {
579+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
580+
op->SetType("roi_perspective_transform_grad");
581+
op->SetInput("X", Input("X"));
582+
op->SetInput("ROIs", Input("ROIs"));
583+
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
584+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
585+
op->SetAttrMap(Attrs());
586+
return op;
587+
}
588+
};
589+
571590
} // namespace operators
572591
} // namespace paddle
573592

574593
namespace ops = paddle::operators;
575594
REGISTER_OPERATOR(roi_perspective_transform, ops::ROIPerspectiveTransformOp,
576595
ops::ROIPerspectiveTransformOpMaker,
577-
paddle::framework::DefaultGradOpDescMaker<true>);
596+
ops::ROIPerspectiveTransformGradDescMaker);
578597
REGISTER_OPERATOR(roi_perspective_transform_grad,
579598
ops::ROIPerspectiveTransformGradOp);
580599
REGISTER_OP_CPU_KERNEL(roi_perspective_transform,

paddle/fluid/operators/gaussian_random_batch_size_like_op.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,17 @@ by input arguments.
6565
}
6666
};
6767

68+
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
69+
GaussianRandomBatchSizeLikeNoNeedBufferVarsInference, "Input");
70+
6871
} // namespace operators
6972
} // namespace paddle
7073

71-
REGISTER_OP_WITHOUT_GRADIENT(
74+
REGISTER_OPERATOR(
7275
gaussian_random_batch_size_like,
7376
paddle::operators::GaussianRandomBatchSizeLikeOp,
74-
paddle::operators::GaussianRandomBatchSizeLikeOpMaker);
77+
paddle::operators::GaussianRandomBatchSizeLikeOpMaker,
78+
paddle::framework::EmptyGradOpMaker,
79+
paddle::operators::GaussianRandomBatchSizeLikeNoNeedBufferVarsInference);
80+
7581
// Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu

paddle/fluid/operators/im2sequence_op.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/im2sequence_op.h"
16+
#include <memory>
1617
#include <string>
1718
#include <vector>
1819

@@ -146,12 +147,28 @@ class Im2SequenceGradOp : public framework::OperatorWithKernel {
146147
}
147148
};
148149

150+
class Im2SequenceGradDescMaker : public framework::SingleGradOpDescMaker {
151+
public:
152+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
153+
154+
protected:
155+
std::unique_ptr<framework::OpDesc> Apply() const override {
156+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
157+
op->SetType("im2sequence_grad");
158+
op->SetInput("X", Input("X"));
159+
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
160+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
161+
op->SetAttrMap(Attrs());
162+
return op;
163+
}
164+
};
165+
149166
} // namespace operators
150167
} // namespace paddle
151168

152169
namespace ops = paddle::operators;
153170
REGISTER_OPERATOR(im2sequence, ops::Im2SequenceOp, ops::Im2SequenceOpMaker,
154-
paddle::framework::DefaultGradOpDescMaker<true>);
171+
ops::Im2SequenceGradDescMaker);
155172
REGISTER_OPERATOR(im2sequence_grad, ops::Im2SequenceGradOp);
156173
REGISTER_OP_CPU_KERNEL(
157174
im2sequence,

paddle/fluid/operators/interpolate_op.cc

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
limitations under the License. */
1111

1212
#include "paddle/fluid/operators/interpolate_op.h"
13+
#include <memory>
1314
#include <string>
1415
#include <vector>
1516
#include "paddle/fluid/framework/op_registry.h"
@@ -194,21 +195,46 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
194195

195196
framework::OpKernelType GetExpectedKernelType(
196197
const framework::ExecutionContext& ctx) const override {
197-
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
198-
ctx.GetPlace());
198+
return framework::OpKernelType(
199+
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
200+
ctx.GetPlace());
201+
}
202+
};
203+
204+
class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
205+
public:
206+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
207+
208+
protected:
209+
std::unique_ptr<framework::OpDesc> Apply() const override {
210+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
211+
op->SetType(ForwardOp().Type() + "_grad");
212+
op->SetInput("X", Input("X"));
213+
if (ForwardOp().Inputs().count("OutSize") > 0) {
214+
op->SetInput("OutSize", Input("OutSize"));
215+
}
216+
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
217+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
218+
op->SetAttrMap(Attrs());
219+
return op;
199220
}
200221
};
201222

223+
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(InterpolateGradNoNeedBufferVarsInference,
224+
"X");
225+
202226
} // namespace operators
203227
} // namespace paddle
204228

205229
namespace ops = paddle::operators;
206230
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
207-
paddle::framework::DefaultGradOpDescMaker<true>);
208-
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad);
231+
ops::InterpolateGradDescMaker);
232+
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
233+
ops::InterpolateGradNoNeedBufferVarsInference);
209234
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
210-
paddle::framework::DefaultGradOpDescMaker<true>);
211-
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad);
235+
ops::InterpolateGradDescMaker);
236+
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
237+
ops::InterpolateGradNoNeedBufferVarsInference);
212238
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
213239
ops::InterpolateKernel<double>,
214240
ops::InterpolateKernel<uint8_t>);

paddle/fluid/operators/l1_norm_op.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/l1_norm_op.h"
16+
#include <memory>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -62,12 +63,28 @@ Computes the L1 norm of a tensor.
6263
}
6364
};
6465

66+
class L1NormGradDescMaker : public framework::SingleGradOpDescMaker {
67+
public:
68+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
69+
70+
protected:
71+
std::unique_ptr<framework::OpDesc> Apply() const override {
72+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
73+
op->SetType("l1_norm_grad");
74+
op->SetInput("X", Input("X"));
75+
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
76+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
77+
op->SetAttrMap(Attrs());
78+
return op;
79+
}
80+
};
81+
6582
} // namespace operators
6683
} // namespace paddle
6784

6885
namespace ops = paddle::operators;
6986
REGISTER_OPERATOR(l1_norm, ops::L1NormOp, ops::L1NormOpMaker,
70-
paddle::framework::DefaultGradOpDescMaker<true>);
87+
ops::L1NormGradDescMaker);
7188
REGISTER_OPERATOR(l1_norm_grad, ops::L1NormGradOp);
7289
REGISTER_OP_CPU_KERNEL(
7390
l1_norm, ops::L1NormKernel<paddle::platform::CPUDeviceContext, float>);

paddle/fluid/operators/label_smooth_op.cc

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/label_smooth_op.h"
16+
#include <memory>
1617
#include <string>
1718

1819
namespace paddle {
@@ -105,10 +106,23 @@ class LabelSmoothGradOp : public framework::OperatorWithKernel {
105106
: OperatorWithKernel(type, inputs, outputs, attrs) {}
106107

107108
void InferShape(framework::InferShapeContext *ctx) const override {
108-
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
109-
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
110-
"Input(Out@GRAD) shouldn't be null.");
111-
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
109+
ctx->SetOutputDim(framework::GradVarName("X"),
110+
ctx->GetInputDim(framework::GradVarName("Out")));
111+
}
112+
};
113+
114+
class LabelSmoothGradDescMaker : public framework::SingleGradOpDescMaker {
115+
public:
116+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
117+
118+
protected:
119+
std::unique_ptr<framework::OpDesc> Apply() const override {
120+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
121+
op->SetType("label_smooth_grad");
122+
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
123+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
124+
op->SetAttrMap(Attrs());
125+
return op;
112126
}
113127
};
114128

@@ -117,7 +131,7 @@ class LabelSmoothGradOp : public framework::OperatorWithKernel {
117131
namespace ops = paddle::operators;
118132

119133
REGISTER_OPERATOR(label_smooth, ops::LabelSmoothOp, ops::LabelSmoothOpMaker,
120-
paddle::framework::DefaultGradOpDescMaker<true>);
134+
ops::LabelSmoothGradDescMaker);
121135
REGISTER_OPERATOR(label_smooth_grad, ops::LabelSmoothGradOp);
122136
REGISTER_OP_CPU_KERNEL(
123137
label_smooth,

paddle/fluid/operators/linear_chain_crf_op.cc

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/linear_chain_crf_op.h"
16+
#include <memory>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -250,14 +251,46 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
250251
}
251252
};
252253

254+
class LinearChainCRFGradDescMaker : public framework::SingleGradOpDescMaker {
255+
public:
256+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
257+
258+
protected:
259+
std::unique_ptr<framework::OpDesc> Apply() const override {
260+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
261+
op->SetType("linear_chain_crf_grad");
262+
op->SetAttrMap(Attrs());
263+
264+
op->SetInput("Emission", Input("Emission"));
265+
op->SetInput("Transition", Input("Transition"));
266+
op->SetInput("Label", Input("Label"));
267+
268+
op->SetInput("Alpha", Output("Alpha"));
269+
op->SetInput("EmissionExps", Output("EmissionExps"));
270+
op->SetInput("TransitionExps", Output("TransitionExps"));
271+
272+
op->SetInput(framework::GradVarName("LogLikelihood"),
273+
OutputGrad("LogLikelihood"));
274+
275+
op->SetOutput(framework::GradVarName("Emission"), InputGrad("Emission"));
276+
op->SetOutput(framework::GradVarName("Transition"),
277+
InputGrad("Transition"));
278+
279+
return op;
280+
}
281+
};
282+
283+
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
284+
LinearChainCRFGradNoNeedBufferVarsInference, "Transition", "Emission");
285+
253286
} // namespace operators
254287
} // namespace paddle
255288

256289
namespace ops = paddle::operators;
257290
REGISTER_OPERATOR(linear_chain_crf, ops::LinearChainCRFOp,
258-
ops::LinearChainCRFOpMaker,
259-
paddle::framework::DefaultGradOpDescMaker<true>);
260-
REGISTER_OPERATOR(linear_chain_crf_grad, ops::LinearChainCRFGradOp);
291+
ops::LinearChainCRFOpMaker, ops::LinearChainCRFGradDescMaker);
292+
REGISTER_OPERATOR(linear_chain_crf_grad, ops::LinearChainCRFGradOp,
293+
ops::LinearChainCRFGradNoNeedBufferVarsInference);
261294
REGISTER_OP_CPU_KERNEL(
262295
linear_chain_crf,
263296
ops::LinearChainCRFOpKernel<paddle::platform::CPUDeviceContext, float>,

paddle/fluid/operators/log_loss_op.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/log_loss_op.h"
16+
#include <memory>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -100,12 +101,29 @@ class LogLossGradOp : public framework::OperatorWithKernel {
100101
}
101102
};
102103

104+
class LogLossGradDescMaker : public framework::SingleGradOpDescMaker {
105+
public:
106+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
107+
108+
protected:
109+
std::unique_ptr<framework::OpDesc> Apply() const override {
110+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
111+
op->SetType("log_loss_grad");
112+
op->SetInput("Predicted", Input("Predicted"));
113+
op->SetInput("Labels", Input("Labels"));
114+
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
115+
op->SetOutput(framework::GradVarName("Predicted"), InputGrad("Predicted"));
116+
op->SetAttrMap(Attrs());
117+
return op;
118+
}
119+
};
120+
103121
} // namespace operators
104122
} // namespace paddle
105123

106124
namespace ops = paddle::operators;
107125
REGISTER_OPERATOR(log_loss, ops::LogLossOp, ops::LogLossOpMaker<float>,
108-
paddle::framework::DefaultGradOpDescMaker<true>);
126+
ops::LogLossGradDescMaker);
109127
REGISTER_OPERATOR(log_loss_grad, ops::LogLossGradOp);
110128
REGISTER_OP_CPU_KERNEL(
111129
log_loss, ops::LogLossKernel<paddle::platform::CPUDeviceContext, float>);

0 commit comments

Comments
 (0)