Skip to content

Commit 7bd16fe

Browse files
committed
registry var type infer
1 parent 6e83c00 commit 7bd16fe

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

paddle/fluid/operators/split_selected_rows_op.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
5959
}
6060
};
6161

62+
class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
63+
public:
64+
void operator()(const framework::OpDesc &op_desc,
65+
framework::BlockDesc *block) const override {
66+
for (auto &out_var : op_desc.Output("Out")) {
67+
block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS);
68+
}
69+
}
70+
};
71+
6272
class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
6373
public:
6474
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
8090
namespace ops = paddle::operators;
8191
REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp,
8292
ops::SplitSelectedRowsOpMaker,
83-
ops::SplitSelectedRowsGradMaker);
93+
ops::SplitSelectedRowsGradMaker,
94+
ops::SplitSelectedRowsOpInferVarType);
8495
REGISTER_OP_CPU_KERNEL(
8596
split_selected_rows,
8697
ops::SplitSelectedRowsOpKernel<paddle::platform::CPUPlace, float>);

0 commit comments

Comments
 (0)