Skip to content

Commit f1554a4

Browse files
committed
fix sparse grad merge on pserver
1 parent c709a04 commit f1554a4

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,13 +1390,11 @@ def _append_pserver_grad_merge_ops(self, optimize_block,
13901390
inputs={"X": vars2merge},
13911391
outputs={"Out": merged_var},
13921392
attrs={"use_mkldnn": False})
1393-
# TODO(panyx0718): What if it's SELECTED_ROWS.
1394-
if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
1395-
optimize_block.append_op(
1396-
type="scale",
1397-
inputs={"X": merged_var},
1398-
outputs={"Out": merged_var},
1399-
attrs={"scale": 1.0 / float(self.trainer_num)})
1393+
optimize_block.append_op(
1394+
type="scale",
1395+
inputs={"X": merged_var},
1396+
outputs={"Out": merged_var},
1397+
attrs={"scale": 1.0 / float(self.trainer_num)})
14001398
return merged_var
14011399

14021400
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,

0 commit comments

Comments
 (0)