@@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel {
24
24
using framework::OperatorWithKernel::OperatorWithKernel;
25
25
26
26
protected:
27
- void InferShape (framework::InferShapeContext * ctx) const override {
27
+ void InferShape (framework::InferShapeContext* ctx) const override {
28
28
PADDLE_ENFORCE (ctx->HasInput (" Param" ),
29
29
" Input(param) of Momentum should not be null." );
30
30
PADDLE_ENFORCE (ctx->HasInput (" Grad" ),
@@ -45,26 +45,50 @@ class MomentumOp : public framework::OperatorWithKernel {
45
45
" Output(VelocityOut) of Momentum should not be null." );
46
46
47
47
auto param_dim = ctx->GetInputDim (" Param" );
48
- PADDLE_ENFORCE_EQ (
49
- param_dim, ctx->GetInputDim (" Grad" ),
50
- " Param and Grad input of MomentumOp should have the same dimension." );
51
- PADDLE_ENFORCE_EQ (
52
- param_dim, ctx->GetInputDim (" Velocity" ),
53
- " Param and Velocity of MomentumOp should have the same dimension." );
48
+ if (ctx->GetInputsVarType (" Grad" )[0 ] ==
49
+ framework::proto::VarType::LOD_TENSOR) {
50
+ PADDLE_ENFORCE_EQ (
51
+ param_dim, ctx->GetInputDim (" Grad" ),
52
+ " Param and Grad input of MomentumOp should have the same dimension." );
53
+ PADDLE_ENFORCE_EQ (
54
+ param_dim, ctx->GetInputDim (" Velocity" ),
55
+ " Param and Velocity of MomentumOp should have the same dimension." );
56
+ }
54
57
PADDLE_ENFORCE_EQ (framework::product (ctx->GetInputDim (" LearningRate" )), 1 ,
55
58
" Learning_rate should be a scalar" );
56
59
57
60
ctx->SetOutputDim (" ParamOut" , param_dim);
58
61
ctx->SetOutputDim (" VelocityOut" , param_dim);
59
62
}
60
63
framework::OpKernelType GetExpectedKernelType (
61
- const framework::ExecutionContext &ctx) const override {
62
- auto input_data_type =
63
- framework::ToDataType (ctx.Input <Tensor>(" Param" )->type ());
64
+ const framework::ExecutionContext& ctx) const override {
65
+ auto input_data_type = framework::GetDataTypeOfVar (ctx.InputVar (" Param" ));
64
66
return framework::OpKernelType (input_data_type, ctx.GetPlace ());
65
67
}
66
68
};
67
69
70
+ class MomentumOpInferVarType : public framework ::VarTypeInference {
71
+ public:
72
+ void operator ()(const framework::OpDesc& op_desc,
73
+ framework::BlockDesc* block) const override {
74
+ auto input_var = op_desc.Input (" Param" )[0 ];
75
+ for (auto & out_var : op_desc.Output (" ParamOut" )) {
76
+ if (block->FindRecursiveOrCreateVar (input_var).GetType () ==
77
+ framework::proto::VarType::SELECTED_ROWS) {
78
+ block->FindRecursiveOrCreateVar (out_var).SetType (
79
+ framework::proto::VarType::SELECTED_ROWS);
80
+ } else if (block->FindRecursiveOrCreateVar (input_var).GetType () ==
81
+ framework::proto::VarType::LOD_TENSOR) {
82
+ block->FindRecursiveOrCreateVar (out_var).SetType (
83
+ framework::proto::VarType::LOD_TENSOR);
84
+ } else {
85
+ PADDLE_THROW (
86
+ " Only support LodTensor and SelectedRows, Unexpected Input Type." );
87
+ }
88
+ }
89
+ }
90
+ };
91
+
68
92
class MomentumOpMaker : public framework ::OpProtoAndCheckerMaker {
69
93
public:
70
94
void Make () override {
@@ -115,6 +139,9 @@ else: \\
115
139
} // namespace paddle
116
140
117
141
namespace ops = paddle::operators;
118
- REGISTER_OP_WITHOUT_GRADIENT (momentum, ops::MomentumOp, ops::MomentumOpMaker);
119
- REGISTER_OP_CPU_KERNEL (momentum, ops::MomentumOpKernel<float >,
120
- ops::MomentumOpKernel<double >);
142
+ REGISTER_OPERATOR (momentum, ops::MomentumOp, ops::MomentumOpMaker,
143
+ paddle::framework::EmptyGradOpMaker,
144
+ ops::MomentumOpInferVarType);
145
+ REGISTER_OP_CPU_KERNEL (
146
+ momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float >,
147
+ ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double >);
0 commit comments