Skip to content

Commit d89ff5b

Browse files
authored
Restore the param infos in Program.clone() (#5873)
* Restore the param infos in Program.clone() Program.clone only clones the variables and ops of the program into a new program; the information of the Parameters is not cloned. So we need to restore the information of the Parameters. Fix #5871 * Follow comments * Fix CI * Fix CI * Fix CI
1 parent c9a9657 commit d89ff5b

File tree

2 files changed

+75
-5
lines changed

2 files changed

+75
-5
lines changed

python/paddle/v2/fluid/framework.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,11 @@ def var(self, name):
395395
return v
396396

397397
def all_parameters(self):
    """Collect every Parameter variable of this block into a fresh list."""
    return [param for param in self.iter_parameters()]
399+
400+
def iter_parameters(self):
    """Lazily yield the Parameter instances among this block's variables."""
    # Generator expression: the dict view is grabbed eagerly, but the
    # isinstance filtering happens only as the caller iterates.
    return (var for var in self.vars.itervalues()
            if isinstance(var, Parameter))
399403

400404
def create_var(self, *args, **kwargs):
401405
var = Variable(self, *args, **kwargs)
@@ -469,6 +473,37 @@ def sync_with_cpp(self):
469473
for index in range(len(self.ops)):
470474
assert self.ops[index].desc == ops_in_cpp[index]
471475

476+
def copy_param_info_from(self, other):
    """
    Copy the information of parameters from other block.

    Args:
        other(Block): the block whose Parameter entries are copied in.

    Returns:
        None
    """
    if not isinstance(other, Block):
        raise TypeError("copy_param_info_from should be invoked with Block")
    for param in other.iter_parameters():
        assert isinstance(param, Parameter)
        local_var = self.vars.get(param.name, None)
        if local_var is None:
            raise ValueError("copy_param_info_from should be invoked with "
                             "same topology")
        assert isinstance(local_var, Variable)
        # Rebuild this block's plain Variable as a Parameter: the variable
        # description (shape/dtype/...) comes from the local variable, while
        # the parameter-only attributes (trainable, regularizer, ...) are
        # taken from the source parameter.
        new_param = Parameter(
            block=self,
            shape=local_var.shape,
            dtype=local_var.dtype,
            type=local_var.type,
            lod_level=local_var.lod_level,
            stop_gradient=param.stop_gradient,
            trainable=param.trainable,
            optimize_attr=param.optimize_attr,
            regularizer=param.regularizer,
            name=local_var.name)
        self.vars[new_param.name] = new_param
506+
472507

473508
class Program(object):
474509
def __init__(self):
@@ -489,6 +524,7 @@ def clone(self):
489524
p.desc = core.ProgramDesc(self.desc)
490525
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
491526
p.sync_with_cpp()
527+
p.copy_param_info_from(self)
492528
return p
493529

494530
def prune(self, targets):
@@ -572,6 +608,24 @@ def sync_with_cpp(self):
572608
for block in self.blocks:
573609
block.sync_with_cpp()
574610

611+
def copy_param_info_from(self, other):
    """
    Copy the information of parameters from other program.

    Args:
        other(Program): Other program

    Returns:
        None

    Raises:
        TypeError: if `other` is not a Program.
        ValueError: if the two programs do not have the same number of
            blocks (a cheap necessary condition for identical topology).
    """
    if not isinstance(other, Program):
        raise TypeError("copy_param_info_from should be invoked with "
                        "Program")

    # NOTE(review): comparing block counts only approximates "same
    # topology"; per-variable mismatches are caught later inside
    # Block.copy_param_info_from.
    if len(self.blocks) != len(other.blocks):
        raise ValueError("copy_param_info_from should be invoked with two "
                         "programs that represent the same topology")
    # Parameters only live in the global (root) block.
    self.global_block().copy_param_info_from(other.global_block())
628+
575629
def list_vars(self):
576630
for each_block in self.blocks:
577631
for each_var in each_block.vars.itervalues():

python/paddle/v2/fluid/tests/test_program.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import print_function
12
import unittest
23

34
from paddle.v2.fluid.framework import Program
45
from paddle.v2.fluid.framework import g_main_program
6+
import paddle.v2.fluid.layers as layers
57

68

79
class TestProgram(unittest.TestCase):
@@ -48,8 +50,8 @@ def test_program_clone(self):
4850

4951
# FIXME(yuyang18): We manual compare the output string, since the order
5052
# of variable could be changed.
51-
print prog
52-
print prog.clone()
53+
print(prog)
54+
print(prog.clone())
5355

5456
def test_parse_program_from_string(self):
5557
prog = Program()
@@ -67,8 +69,8 @@ def test_parse_program_from_string(self):
6769
binary_str = prog.desc.serialize_to_string()
6870
prog_restored = Program.parse_from_string(binary_str)
6971

70-
print prog
71-
print prog_restored
72+
print(prog)
73+
print(prog_restored)
7274

7375
def test_append_backward(self):
7476
prog = Program()
@@ -123,6 +125,20 @@ def grad_name(name):
123125
actual_ops.append(op.type)
124126
self.assertEqual(actual_ops, expect_ops)
125127

128+
def test_program_clone_with_parameter(self):
    """A cloned program must still expose the Parameters of the original."""
    main_program = Program()
    startup_program = Program()
    programs = dict(
        main_program=main_program, startup_program=startup_program)
    d = layers.data(name='x', shape=[784], dtype='float32', **programs)
    hidden = layers.fc(input=d, size=100, **programs)
    layers.fc(input=hidden, size=100, **programs)

    cloned = main_program.clone()
    self.assertNotEqual(0, len(cloned.blocks[0].all_parameters()))
141+
126142

127143
if __name__ == '__main__':
128144
unittest.main()

0 commit comments

Comments
 (0)