Skip to content

Commit 6cf6c6a

Browse files
typhoonzeroYancey0623
authored andcommitted
Merge pull request #11698 from typhoonzero/fix_sparse_dist_paraexe
fix sparse paraexe dist train
1 parent fac1d47 commit 6cf6c6a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
470470
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
471471
const OpDesc &op) const {
472472
int op_dev_id = -1;
473-
if (op.Type() == "split_byref") {
473+
if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
474474
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
475475
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
476476
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());

0 commit comments

Comments
 (0)