File tree Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
59
59
}
60
60
};
61
61
62
+ class SplitSelectedRowsOpInferVarType : public framework ::VarTypeInference {
63
+ public:
64
+ void operator ()(const framework::OpDesc &op_desc,
65
+ framework::BlockDesc *block) const override {
66
+ for (auto &out_var : op_desc.Output (" Out" )) {
67
+ block->Var (out_var)->SetType (framework::proto::VarType::SELECTED_ROWS);
68
+ }
69
+ }
70
+ };
71
+
62
72
class SplitSelectedRowsGradMaker : public framework ::SingleGradOpDescMaker {
63
73
public:
64
74
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
80
90
namespace ops = paddle::operators;
81
91
REGISTER_OPERATOR (split_selected_rows, ops::SplitSelectedRowsOp,
82
92
ops::SplitSelectedRowsOpMaker,
83
- ops::SplitSelectedRowsGradMaker);
93
+ ops::SplitSelectedRowsGradMaker,
94
+ ops::SplitSelectedRowsOpInferVarType);
84
95
REGISTER_OP_CPU_KERNEL (
85
96
split_selected_rows,
86
97
ops::SplitSelectedRowsOpKernel<paddle::platform::CPUPlace, float >);
You can’t perform that action at this time.
0 commit comments