Skip to content

Commit 39277e9

Browse files
committed
fix transpiler condition op in optimize
1 parent 9571045 commit 39277e9

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

python/paddle/fluid/distribute_transpiler.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -408,11 +408,16 @@ def _get_splited_name_and_shape(varname):
408408
pserver_vars = pserver_program.global_block().vars
409409
created_var_map = dict()
410410
for _, var in pserver_vars.iteritems():
411-
tmpvar = s_prog.global_block().create_var(
412-
name=var.name,
413-
persistable=var.persistable,
414-
dtype=var.dtype,
415-
shape=var.shape)
411+
if var.type == core.VarDesc.VarType.STEP_SCOPES:
412+
tmpvar = s_prog.global_block().create_var(
413+
name=var.name, persistable=var.persistable, type=var.type)
414+
else:
415+
tmpvar = s_prog.global_block().create_var(
416+
name=var.name,
417+
persistable=var.persistable,
418+
type=var.type,
419+
dtype=var.dtype,
420+
shape=var.shape)
416421
created_var_map[var.name] = tmpvar
417422

418423
# 2. rename op outputs
@@ -708,11 +713,18 @@ def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
708713
varlist = [varlist]
709714

710715
for var in varlist:
711-
program.global_block().create_var(
712-
name=var.name,
713-
persistable=var.persistable,
714-
dtype=var.dtype,
715-
shape=var.shape)
716+
print("##### deal var: ", var)
717+
if var.type == core.VarDesc.VarType.STEP_SCOPES:
718+
program.global_block().create_var(
719+
name=var.name,
720+
persistable=var.persistable,
721+
type=var.type)
722+
else:
723+
program.global_block().create_var(
724+
name=var.name,
725+
persistable=var.persistable,
726+
dtype=var.dtype,
727+
shape=var.shape)
716728

717729
optimize_block.append_op(
718730
type=opt_op.type,

0 commit comments

Comments
 (0)