Skip to content

Commit 718642e

Browse files
author
Yancey
authored
Merge pull request #8659 from Yancey1989/fix_dist_bug
Register var type inference in split_selected_rows op
2 parents e9f2033 + 7bd16fe commit 718642e

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

paddle/fluid/operators/split_selected_rows_op.cc

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
   }
 };
 
+class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    for (auto &out_var : op_desc.Output("Out")) {
+      block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS);
+    }
+  }
+};
+
 class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
  public:
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp,
                   ops::SplitSelectedRowsOpMaker,
-                  ops::SplitSelectedRowsGradMaker);
+                  ops::SplitSelectedRowsGradMaker,
+                  ops::SplitSelectedRowsOpInferVarType);
 REGISTER_OP_CPU_KERNEL(
     split_selected_rows,
     ops::SplitSelectedRowsOpKernel<paddle::platform::CPUPlace, float>);

python/paddle/fluid/distribute_transpiler.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -276,13 +276,15 @@ def get_pserver_program(self, endpoint):
             pserver_program.global_block().create_var(
                 name=orig_var_name,
                 persistable=True,
+                type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
             print("create origin var: ", orig_var_name)
             for trainer_id in xrange(self.trainers):
                 var = pserver_program.global_block().create_var(
                     name="%s.trainer_%d" % (orig_var_name, trainer_id),
                     persistable=False,
+                    type=v.type,
                     dtype=v.dtype,
                     shape=v.shape)
                 recv_inputs.append(var)
@@ -551,11 +553,12 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
                 type="sum",
                 inputs={"X": vars2merge},
                 outputs={"Out": merged_var})
-            optimize_block.append_op(
-                type="scale",
-                inputs={"X": merged_var},
-                outputs={"Out": merged_var},
-                attrs={"scale": 1.0 / float(self.trainers)})
+            if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+                optimize_block.append_op(
+                    type="scale",
+                    inputs={"X": merged_var},
+                    outputs={"Out": merged_var},
+                    attrs={"scale": 1.0 / float(self.trainers)})
             new_inputs[key] = merged_var
         elif key == "Param":
             # param is already created on global program

0 commit comments

Comments
 (0)