Skip to content

Commit d84cdb7

Browse files
authored
Merge pull request #9911 from tonyyang-svail/unify_op_registry
Unify REGISTER_OP and REGISTER_OPERATOR
2 parents 1fd1284 + 68d9638 commit d84cdb7

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

67 files changed

+379
-274
lines changed

paddle/fluid/framework/grad_op_desc_maker.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16+
#include <algorithm>
1617
#include <string>
1718
#include <unordered_set>
1819
#include <vector>
@@ -69,8 +70,7 @@ class GradOpDescMakerBase {
6970
" for input argument with a list of variables, "
7071
" drop_empty_grad is not allowed because it makes"
7172
" the correspondence bewteen a variable and its gradient"
72-
" ambiguous. Use REGISTER_OP_EX to register the op"
73-
" or call InputGrad(?,false) in GradOpDescMaker."
73+
" ambiguous."
7474
" Op type %s",
7575
fwd_op_.Type());
7676

paddle/fluid/framework/op_registry.h

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License. */
1616

1717
#include <algorithm>
1818
#include <atomic>
19+
#include <string>
20+
#include <tuple>
1921
#include <type_traits>
2022
#include <typeinfo>
2123
#include <unordered_map>
@@ -141,36 +143,6 @@ class OpKernelRegistrar : public Registrar {
141143
return 0; \
142144
}
143145

144-
/**
145-
* Macro to register Operator. When the input is duplicable, you should
146-
* use REGISTER_OP_EX with drop_empty_grad=false instead.
147-
*/
148-
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
149-
grad_op_class) \
150-
REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
151-
grad_op_class, true)
152-
153-
// When an argument is duplicable, we need to use this version.
154-
// Perhaps we can omit DropEmptyIG template parameter and
155-
// only have one version of REGISTER_OP.
156-
#define REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
157-
grad_op_class, drop_empty_grad) \
158-
REGISTER_OPERATOR(grad_op_type, grad_op_class); \
159-
class _GradOpDescMaker_##grad_op_type##_ \
160-
: public ::paddle::framework::DefaultGradOpDescMaker<drop_empty_grad> { \
161-
using ::paddle::framework::DefaultGradOpDescMaker< \
162-
drop_empty_grad>::DefaultGradOpDescMaker; \
163-
\
164-
protected: \
165-
virtual std::string GradOpType() const { return #grad_op_type; } \
166-
}; \
167-
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
168-
op_maker_class);
169-
170-
#define REGISTER_OP_WITH_KERNEL(op_type, ...) \
171-
REGISTER_OPERATOR(op_type, ::paddle::framework::OperatorWithKernel, \
172-
##__VA_ARGS__)
173-
174146
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
175147
REGISTER_OPERATOR(op_type, op_class, op_maker_class)
176148

paddle/fluid/operators/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,12 @@ function(op_library TARGET)
110110
# Note that it's enough to just adding one operator to pybind in a *_op.cc file.
111111
# And for detail pybind information, please see generated paddle/pybind/pybind.h.
112112
file(READ ${TARGET}.cc TARGET_CONTENT)
113-
string(REGEX MATCH "REGISTER_OP\\(.*REGISTER_OP\\(" multi_register "${TARGET_CONTENT}")
114-
string(REGEX MATCH "REGISTER_OP\\([a-z0-9_]*," one_register "${multi_register}")
113+
string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
114+
string(REGEX MATCH "REGISTER_OPERATOR\\([a-z0-9_]*," one_register "${multi_register}")
115115
if (one_register STREQUAL "")
116116
string(REPLACE "_op" "" TARGET "${TARGET}")
117117
else ()
118-
string(REPLACE "REGISTER_OP(" "" TARGET "${one_register}")
118+
string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}")
119119
string(REPLACE "," "" TARGET "${TARGET}")
120120
endif()
121121

paddle/fluid/operators/activation_op.cc

Lines changed: 116 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -558,95 +558,126 @@ Swish Activation Operator.
558558

559559
namespace ops = paddle::operators;
560560

561-
REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
562-
ops::ActivationOpGrad);
561+
REGISTER_OPERATOR(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker,
562+
paddle::framework::DefaultGradOpDescMaker<true>)
563+
REGISTER_OPERATOR(sigmoid_grad, ops::ActivationOpGrad)
563564

564-
REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker,
565-
logsigmoid_grad, ops::ActivationOpGrad);
565+
REGISTER_OPERATOR(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker,
566+
paddle::framework::DefaultGradOpDescMaker<true>)
567+
REGISTER_OPERATOR(logsigmoid_grad, ops::ActivationOpGrad)
566568

567-
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
568-
ops::ActivationOpGrad);
569+
REGISTER_OPERATOR(exp, ops::ActivationOp, ops::ExpOpMaker,
570+
paddle::framework::DefaultGradOpDescMaker<true>)
571+
REGISTER_OPERATOR(exp_grad, ops::ActivationOpGrad)
572+
573+
REGISTER_OPERATOR(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker,
574+
paddle::framework::DefaultGradOpDescMaker<true>)
575+
REGISTER_OPERATOR(relu_grad, ops::ActivationWithMKLDNNOpGrad)
576+
577+
REGISTER_OPERATOR(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker,
578+
paddle::framework::DefaultGradOpDescMaker<true>)
579+
REGISTER_OPERATOR(tanh_grad, ops::ActivationWithMKLDNNOpGrad)
580+
581+
REGISTER_OPERATOR(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
582+
paddle::framework::DefaultGradOpDescMaker<true>)
583+
REGISTER_OPERATOR(tanh_shrink_grad, ops::ActivationOpGrad)
584+
585+
REGISTER_OPERATOR(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker,
586+
paddle::framework::DefaultGradOpDescMaker<true>)
587+
REGISTER_OPERATOR(softshrink_grad, ops::ActivationOpGrad)
588+
589+
REGISTER_OPERATOR(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker,
590+
paddle::framework::DefaultGradOpDescMaker<true>)
591+
REGISTER_OPERATOR(sqrt_grad, ops::ActivationWithMKLDNNOpGrad)
592+
593+
REGISTER_OPERATOR(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker,
594+
paddle::framework::DefaultGradOpDescMaker<true>)
595+
REGISTER_OPERATOR(abs_grad, ops::ActivationWithMKLDNNOpGrad)
596+
597+
REGISTER_OPERATOR(ceil, ops::ActivationOp, ops::CeilOpMaker,
598+
paddle::framework::DefaultGradOpDescMaker<true>)
599+
REGISTER_OPERATOR(ceil_grad, ops::ActivationOpGrad)
600+
601+
REGISTER_OPERATOR(floor, ops::ActivationOp, ops::FloorOpMaker,
602+
paddle::framework::DefaultGradOpDescMaker<true>)
603+
REGISTER_OPERATOR(floor_grad, ops::ActivationOpGrad)
604+
605+
REGISTER_OPERATOR(cos, ops::ActivationOp, ops::CosOpMaker,
606+
paddle::framework::DefaultGradOpDescMaker<true>)
607+
REGISTER_OPERATOR(cos_grad, ops::ActivationOpGrad)
608+
609+
REGISTER_OPERATOR(sin, ops::ActivationOp, ops::SinOpMaker,
610+
paddle::framework::DefaultGradOpDescMaker<true>)
611+
REGISTER_OPERATOR(sin_grad, ops::ActivationOpGrad)
612+
613+
REGISTER_OPERATOR(round, ops::ActivationOp, ops::RoundOpMaker,
614+
paddle::framework::DefaultGradOpDescMaker<true>)
615+
REGISTER_OPERATOR(round_grad, ops::ActivationOpGrad)
616+
617+
REGISTER_OPERATOR(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker,
618+
paddle::framework::DefaultGradOpDescMaker<true>)
619+
REGISTER_OPERATOR(reciprocal_grad, ops::ActivationOpGrad)
620+
621+
REGISTER_OPERATOR(log, ops::ActivationOp, ops::LogOpMaker,
622+
paddle::framework::DefaultGradOpDescMaker<true>)
623+
REGISTER_OPERATOR(log_grad, ops::ActivationOpGrad)
624+
625+
REGISTER_OPERATOR(square, ops::ActivationOp, ops::SquareOpMaker,
626+
paddle::framework::DefaultGradOpDescMaker<true>)
627+
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad)
628+
629+
REGISTER_OPERATOR(softplus, ops::ActivationOp, ops::SoftplusOpMaker,
630+
paddle::framework::DefaultGradOpDescMaker<true>)
631+
REGISTER_OPERATOR(softplus_grad, ops::ActivationOpGrad)
632+
633+
REGISTER_OPERATOR(softsign, ops::ActivationOp, ops::SoftsignOpMaker,
634+
paddle::framework::DefaultGradOpDescMaker<true>)
635+
REGISTER_OPERATOR(softsign_grad, ops::ActivationOpGrad)
636+
637+
REGISTER_OPERATOR(brelu, ops::ActivationOp, ops::BReluOpMaker,
638+
paddle::framework::DefaultGradOpDescMaker<true>)
639+
REGISTER_OPERATOR(brelu_grad, ops::ActivationOpGrad)
640+
641+
REGISTER_OPERATOR(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
642+
paddle::framework::DefaultGradOpDescMaker<true>)
643+
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad)
644+
645+
REGISTER_OPERATOR(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker,
646+
paddle::framework::DefaultGradOpDescMaker<true>)
647+
REGISTER_OPERATOR(soft_relu_grad, ops::ActivationOpGrad)
648+
649+
REGISTER_OPERATOR(elu, ops::ActivationOp, ops::ELUOpMaker,
650+
paddle::framework::DefaultGradOpDescMaker<true>)
651+
REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad)
652+
653+
REGISTER_OPERATOR(relu6, ops::ActivationOp, ops::Relu6OpMaker,
654+
paddle::framework::DefaultGradOpDescMaker<true>)
655+
REGISTER_OPERATOR(relu6_grad, ops::ActivationOpGrad)
656+
657+
REGISTER_OPERATOR(pow, ops::ActivationOp, ops::PowOpMaker,
658+
paddle::framework::DefaultGradOpDescMaker<true>)
659+
REGISTER_OPERATOR(pow_grad, ops::ActivationOpGrad)
660+
661+
REGISTER_OPERATOR(stanh, ops::ActivationOp, ops::STanhOpMaker,
662+
paddle::framework::DefaultGradOpDescMaker<true>)
663+
REGISTER_OPERATOR(stanh_grad, ops::ActivationOpGrad)
569664

570-
REGISTER_OP(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker, relu_grad,
571-
ops::ActivationWithMKLDNNOpGrad);
665+
REGISTER_OPERATOR(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker,
666+
paddle::framework::DefaultGradOpDescMaker<true>)
667+
REGISTER_OPERATOR(hard_shrink_grad, ops::ActivationOpGrad)
572668

573-
REGISTER_OP(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker, tanh_grad,
574-
ops::ActivationWithMKLDNNOpGrad);
575-
576-
REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
577-
tanh_shrink_grad, ops::ActivationOpGrad);
578-
579-
REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker,
580-
softshrink_grad, ops::ActivationOpGrad);
581-
582-
REGISTER_OP(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker, sqrt_grad,
583-
ops::ActivationWithMKLDNNOpGrad);
584-
585-
REGISTER_OP(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker, abs_grad,
586-
ops::ActivationWithMKLDNNOpGrad);
587-
588-
REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
589-
ops::ActivationOpGrad);
590-
591-
REGISTER_OP(floor, ops::ActivationOp, ops::FloorOpMaker, floor_grad,
592-
ops::ActivationOpGrad);
593-
594-
REGISTER_OP(cos, ops::ActivationOp, ops::CosOpMaker, cos_grad,
595-
ops::ActivationOpGrad);
596-
597-
REGISTER_OP(sin, ops::ActivationOp, ops::SinOpMaker, sin_grad,
598-
ops::ActivationOpGrad);
599-
600-
REGISTER_OP(round, ops::ActivationOp, ops::RoundOpMaker, round_grad,
601-
ops::ActivationOpGrad);
602-
603-
REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker,
604-
reciprocal_grad, ops::ActivationOpGrad);
605-
606-
REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad,
607-
ops::ActivationOpGrad);
608-
609-
REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad,
610-
ops::ActivationOpGrad);
611-
612-
REGISTER_OP(softplus, ops::ActivationOp, ops::SoftplusOpMaker, softplus_grad,
613-
ops::ActivationOpGrad);
614-
615-
REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad,
616-
ops::ActivationOpGrad);
617-
618-
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker, brelu_grad,
619-
ops::ActivationOpGrad);
620-
621-
REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
622-
leaky_relu_grad, ops::ActivationOpGrad);
623-
624-
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, soft_relu_grad,
625-
ops::ActivationOpGrad);
626-
627-
REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker, elu_grad,
628-
ops::ActivationOpGrad);
629-
630-
REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker, relu6_grad,
631-
ops::ActivationOpGrad);
632-
633-
REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad,
634-
ops::ActivationOpGrad);
635-
636-
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad,
637-
ops::ActivationOpGrad);
638-
639-
REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker,
640-
hard_shrink_grad, ops::ActivationOpGrad);
641-
642-
REGISTER_OP(thresholded_relu, ops::ActivationOp, ops::ThresholdedReluOpMaker,
643-
thresholded_relu_grad, ops::ActivationOpGrad);
644-
645-
REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
646-
hard_sigmoid_grad, ops::ActivationOpGrad);
647-
648-
REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
649-
ops::ActivationOpGrad);
669+
REGISTER_OPERATOR(thresholded_relu, ops::ActivationOp,
670+
ops::ThresholdedReluOpMaker,
671+
paddle::framework::DefaultGradOpDescMaker<true>)
672+
REGISTER_OPERATOR(thresholded_relu_grad, ops::ActivationOpGrad)
673+
674+
REGISTER_OPERATOR(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
675+
paddle::framework::DefaultGradOpDescMaker<true>)
676+
REGISTER_OPERATOR(hard_sigmoid_grad, ops::ActivationOpGrad)
677+
678+
REGISTER_OPERATOR(swish, ops::ActivationOp, ops::SwishOpMaker,
679+
paddle::framework::DefaultGradOpDescMaker<true>)
680+
REGISTER_OPERATOR(swish_grad, ops::ActivationOpGrad)
650681

651682
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
652683
REGISTER_OP_CPU_KERNEL( \

paddle/fluid/operators/bilinear_tensor_product_op.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,11 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
153153
} // namespace paddle
154154

155155
namespace ops = paddle::operators;
156-
REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp,
157-
ops::BilinearTensorProductOpMaker, bilinear_tensor_product_grad,
158-
ops::BilinearTensorProductOpGrad);
156+
REGISTER_OPERATOR(bilinear_tensor_product, ops::BilinearTensorProductOp,
157+
ops::BilinearTensorProductOpMaker,
158+
paddle::framework::DefaultGradOpDescMaker<true>)
159+
REGISTER_OPERATOR(bilinear_tensor_product_grad,
160+
ops::BilinearTensorProductOpGrad)
159161
REGISTER_OP_CPU_KERNEL(
160162
bilinear_tensor_product,
161163
ops::BilinearTensorProductKernel<paddle::platform::CPUDeviceContext, float>,

paddle/fluid/operators/clip_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ class ClipOpGrad : public framework::OperatorWithKernel {
8181
} // namespace paddle
8282

8383
namespace ops = paddle::operators;
84-
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
85-
ops::ClipOpGrad);
84+
REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
85+
paddle::framework::DefaultGradOpDescMaker<true>)
86+
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad)
8687
REGISTER_OP_CPU_KERNEL(
8788
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>);
8889
REGISTER_OP_CPU_KERNEL(

paddle/fluid/operators/concat_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
103103
} // namespace paddle
104104

105105
namespace ops = paddle::operators;
106-
REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
107-
ops::ConcatOpGrad, false)
106+
REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
107+
paddle::framework::DefaultGradOpDescMaker<
108+
false> /* set false to disable empty grad */)
109+
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad)
108110
REGISTER_OP_CPU_KERNEL(
109111
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>)
110112
REGISTER_OP_CPU_KERNEL(

paddle/fluid/operators/conv_op.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,17 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
335335
} // namespace paddle
336336

337337
namespace ops = paddle::operators;
338-
REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad,
339-
ops::ConvOpGrad);
338+
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
339+
paddle::framework::DefaultGradOpDescMaker<true>)
340+
REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad)
340341

341342
// depthwise convolution op
342-
REGISTER_OP(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
343-
depthwise_conv2d_grad, ops::ConvOpGrad);
344-
REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
345-
ops::ConvOpGrad);
343+
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
344+
paddle::framework::DefaultGradOpDescMaker<true>)
345+
REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad)
346+
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
347+
paddle::framework::DefaultGradOpDescMaker<true>)
348+
REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad)
346349

347350
// depthwise conv kernel
348351
// TODO(xingzhaolong): neon kernel for mobile

paddle/fluid/operators/conv_shift_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,9 @@ class ConvShiftGradKernel<platform::CPUPlace, T>
193193
} // namespace paddle
194194

195195
namespace ops = paddle::operators;
196-
REGISTER_OP(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker,
197-
conv_shift_grad, ops::ConvShiftGradOp);
196+
REGISTER_OPERATOR(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker,
197+
paddle::framework::DefaultGradOpDescMaker<true>)
198+
REGISTER_OPERATOR(conv_shift_grad, ops::ConvShiftGradOp)
198199
REGISTER_OP_CPU_KERNEL(conv_shift,
199200
ops::ConvShiftKernel<paddle::platform::CPUPlace, float>);
200201
REGISTER_OP_CPU_KERNEL(

paddle/fluid/operators/conv_transpose_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,10 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
298298

299299
namespace ops = paddle::operators;
300300

301-
REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
302-
conv2d_transpose_grad, ops::ConvTransposeOpGrad);
301+
REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
302+
ops::Conv2DTransposeOpMaker,
303+
paddle::framework::DefaultGradOpDescMaker<true>)
304+
REGISTER_OPERATOR(conv2d_transpose_grad, ops::ConvTransposeOpGrad)
303305

304306
REGISTER_OP_CPU_KERNEL(
305307
conv2d_transpose,
@@ -311,8 +313,10 @@ REGISTER_OP_CPU_KERNEL(
311313
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
312314
double>);
313315

314-
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
315-
conv3d_transpose_grad, ops::ConvTransposeOpGrad);
316+
REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
317+
ops::Conv3DTransposeOpMaker,
318+
paddle::framework::DefaultGradOpDescMaker<true>)
319+
REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad)
316320

317321
REGISTER_OP_CPU_KERNEL(
318322
conv3d_transpose,

0 commit comments

Comments (0)