Skip to content

Commit c00a5de

Browse files
authored
Merge pull request #9644 from typhoonzero/fix_condition_op_in_optimize
Fix transpiler condition op in optimize
2 parents a84a580 + a16a872 commit c00a5de

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

python/paddle/fluid/distribute_transpiler.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -408,11 +408,7 @@ def _get_splited_name_and_shape(varname):
408408
pserver_vars = pserver_program.global_block().vars
409409
created_var_map = dict()
410410
for _, var in pserver_vars.iteritems():
411-
tmpvar = s_prog.global_block().create_var(
412-
name=var.name,
413-
persistable=var.persistable,
414-
dtype=var.dtype,
415-
shape=var.shape)
411+
tmpvar = s_prog.global_block().clone_variable(var)
416412
created_var_map[var.name] = tmpvar
417413

418414
# 2. rename op outputs
@@ -708,11 +704,7 @@ def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
708704
varlist = [varlist]
709705

710706
for var in varlist:
711-
program.global_block().create_var(
712-
name=var.name,
713-
persistable=var.persistable,
714-
dtype=var.dtype,
715-
shape=var.shape)
707+
program.global_block().clone_variable(var)
716708

717709
optimize_block.append_op(
718710
type=opt_op.type,

python/paddle/fluid/framework.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -946,13 +946,20 @@ def clone_variable(self, var):
946946
The new variable cloned from 'var' in current block.
947947
"""
948948
assert isinstance(var, Variable)
949-
return self.create_var(
950-
name=var.name,
951-
shape=var.shape,
952-
dtype=var.dtype,
953-
type=var.type,
954-
lod_level=var.lod_level,
955-
persistable=True)
949+
ret_var = None
950+
# make STEP_SCOPES var can be safely cloned.
951+
if var.type == core.VarDesc.VarType.STEP_SCOPES:
952+
ret_var = self.create_var(
953+
name=var.name, persistable=var.persistable, type=var.type)
954+
else:
955+
ret_var = self.create_var(
956+
name=var.name,
957+
shape=var.shape,
958+
dtype=var.dtype,
959+
type=var.type,
960+
lod_level=var.lod_level,
961+
persistable=True)
962+
return ret_var
956963

957964

958965
class Program(object):

0 commit comments

Comments
 (0)