Skip to content

Commit 0707abb

Browse files
committed
lookup table fix
1 parent 83c85f3 commit 0707abb

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -877,9 +877,15 @@ def _create_table_optimize_block(self, pserver_index, pserver_program,
877877
# create table param and grad var in pserver program
878878
origin_param_var = self.origin_program.global_block().vars[
879879
self.table_name]
880+
881+
zero_dim = long(
882+
math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints)))
883+
table_shape = list(origin_param_var.shape)
884+
table_shape[0] = zero_dim
885+
880886
param_var = pserver_program.global_block().create_var(
881887
name=origin_param_var.name,
882-
shape=origin_param_var.shape,
888+
shape=table_shape,
883889
dtype=origin_param_var.dtype,
884890
type=core.VarDesc.VarType.SELECTED_ROWS,
885891
persistable=True)

0 commit comments

Comments
 (0)