@@ -395,7 +395,11 @@ def var(self, name):
395
395
return v
396
396
397
397
def all_parameters (self ):
398
- return {v for k , v in self .vars .iteritems () if isinstance (v , Parameter )}
398
+ return list (self .iter_parameters ())
399
+
400
+ def iter_parameters (self ):
401
+ return (item [1 ] for item in self .vars .iteritems ()
402
+ if isinstance (item [1 ], Parameter ))
399
403
400
404
def create_var (self , * args , ** kwargs ):
401
405
var = Variable (self , * args , ** kwargs )
@@ -469,6 +473,37 @@ def sync_with_cpp(self):
469
473
for index in range (len (self .ops )):
470
474
assert self .ops [index ].desc == ops_in_cpp [index ]
471
475
476
+ def copy_param_info_from (self , other ):
477
+ """
478
+ Copy the information of parameters from other block
479
+ Args:
480
+ other(Block): other block
481
+
482
+ Returns:
483
+ None
484
+ """
485
+ if not isinstance (other , Block ):
486
+ raise TypeError ("copy_param_info_from should be invoked with Block" )
487
+ for p in other .iter_parameters ():
488
+ assert isinstance (p , Parameter )
489
+ v = self .vars .get (p .name , None )
490
+ if v is None :
491
+ raise ValueError ("copy_param_info_from should be invoked with "
492
+ "same topology" )
493
+ assert isinstance (v , Variable )
494
+ new_p = Parameter (
495
+ block = self ,
496
+ shape = v .shape ,
497
+ dtype = v .dtype ,
498
+ type = v .type ,
499
+ lod_level = v .lod_level ,
500
+ stop_gradient = p .stop_gradient ,
501
+ trainable = p .trainable ,
502
+ optimize_attr = p .optimize_attr ,
503
+ regularizer = p .regularizer ,
504
+ name = v .name )
505
+ self .vars [new_p .name ] = new_p
506
+
472
507
473
508
class Program (object ):
474
509
def __init__ (self ):
@@ -489,6 +524,7 @@ def clone(self):
489
524
p .desc = core .ProgramDesc (self .desc )
490
525
p .blocks = [Block (p , i ) for i in xrange (self .desc .num_blocks ())]
491
526
p .sync_with_cpp ()
527
+ p .copy_param_info_from (self )
492
528
return p
493
529
494
530
def prune (self , targets ):
@@ -572,6 +608,24 @@ def sync_with_cpp(self):
572
608
for block in self .blocks :
573
609
block .sync_with_cpp ()
574
610
611
+ def copy_param_info_from (self , other ):
612
+ """
613
+ Copy the information of parameters from other program.
614
+ Args:
615
+ other(Program): Other program
616
+
617
+ Returns:
618
+ None
619
+ """
620
+ if not isinstance (other , Program ):
621
+ raise TypeError ("copy_param_info_from should be invoked with "
622
+ "Program" )
623
+
624
+ if len (self .blocks ) != len (other .blocks ):
625
+ raise ValueError ("copy_param_info_from should be invoked with two "
626
+ "program, with represent the same topology" )
627
+ self .global_block ().copy_param_info_from (other .global_block ())
628
+
575
629
def list_vars (self ):
576
630
for each_block in self .blocks :
577
631
for each_var in each_block .vars .itervalues ():
0 commit comments