@@ -559,19 +559,8 @@ def find_name(var_list, name):
559
559
self .attrs [attr_name ] is None ):
560
560
continue
561
561
attr_val = self .attrs [attr_name ]
562
- if isinstance (attr_val , Block ):
563
- self .desc .set_block_attr (attr_name ,
564
- self .attrs [attr_name ].desc )
565
- elif isinstance (attr_val , list ) and attr_val and \
566
- all (isinstance (v , Block ) for v in attr_val ):
567
- self .desc .set_blocks_attr (attr_name ,
568
- [v .desc for v in attr_val ])
569
- elif isinstance (attr_val , core .BlockDesc ) or \
570
- isinstance (attr_val , core .ProgramDesc ):
571
- self .desc .set_serialized_attr (
572
- attr_name , attr_val .serialize_to_string ())
573
- else :
574
- self .desc .set_attr (attr_name , attr_val )
562
+ self ._update_desc_attr (attr_name , attr_val )
563
+
575
564
self .desc .check_attrs ()
576
565
if self .has_kernel (type ):
577
566
self .desc .infer_var_type (self .block .desc )
@@ -718,6 +707,19 @@ def set_attr(self, name, val):
718
707
ValueError: If the type of value doesn't match with desc.attr_type(name).
719
708
"""
720
709
self .attrs [name ] = val
710
+ self ._update_desc_attr (name , val )
711
+
712
+ def _update_desc_attr (self , name , val ):
713
+ """
714
+ Update the value of desc's attribute by attribute's name.
715
+
716
+ Args:
717
+ name(str): the attribute name.
718
+ val(bool|int|str|float|list): the value of the attribute.
719
+
720
+ Raises:
721
+ ValueError: If the type of value doesn't match with desc.attr_type(name).
722
+ """
721
723
if isinstance (val , Block ):
722
724
self .desc .set_block_attr (name , val .desc )
723
725
elif isinstance (val , list ) and val and all (
0 commit comments