@@ -464,7 +464,30 @@ def get_sub_ops(op, op_name, all_ops, all_vars):
464464 all_ops [sub_op_name ]['is_leaf_node' ] = True
465465 now_var = utils .gen_var_name (sub_op .results ())
466466 for source in sub_op .operands_source ():
467+ if not source .type ():
468+ # if source.type() == Value().type():
469+ continue
467470 input_name = utils .gen_var_name (source )
471+ if input_name not in all_vars .keys ():
472+ all_vars [input_name ] = {}
473+ all_vars [input_name ]['name' ] = input_name
474+ try :
475+ attrs = source .results ()[0 ].get_defining_op ().attrs ()
476+ if 'place' in attrs :
477+ attrs ['place' ] = str (attrs ['place' ])
478+ attrs ['dtype' ] = safe_get_dtype (source )
479+ except Exception :
480+ attrs = {}
481+
482+ all_vars [input_name ]['shape' ] = safe_get_shape (source )
483+ all_vars [input_name ]['type' ] = safe_get_type (source )
484+ all_vars [input_name ]['dtype' ] = safe_get_dtype (source )
485+ all_vars [input_name ]['value' ] = []
486+ all_vars [input_name ]['persistable' ] = safe_get_persistable (source )
487+ all_vars [input_name ]['attrs' ] = attrs
488+ all_vars [input_name ]['from_node' ] = ''
489+ all_vars [input_name ]['to_nodes' ] = []
490+
468491 if sub_op .name () == "pd_op.increment_" :
469492 all_vars [now_var ]['to_nodes' ].append (all_vars [input_name ]['from_node' ])
470493 all_ops [all_vars [input_name ]['from_node' ]]['input_vars' ][now_var ] = [now_var ]
@@ -633,7 +656,30 @@ def analyse_pir(program):
633656 all_ops [op_name ]['is_leaf_node' ] = True
634657 now_var = utils .gen_var_name (op .results ())
635658 for source in op .operands_source ():
659+ if not source .type ():
660+ # if source.type() == Value().type():
661+ continue
636662 input_name = utils .gen_var_name (source )
663+ if input_name not in all_vars .keys ():
664+ all_vars [input_name ] = {}
665+ all_vars [input_name ]['name' ] = input_name
666+ try :
667+ attrs = source .results ()[0 ].get_defining_op ().attrs ()
668+ if 'place' in attrs :
669+ attrs ['place' ] = str (attrs ['place' ])
670+ attrs ['dtype' ] = safe_get_dtype (source )
671+ except Exception :
672+ attrs = {}
673+
674+ all_vars [input_name ]['shape' ] = safe_get_shape (source )
675+ all_vars [input_name ]['type' ] = safe_get_type (source )
676+ all_vars [input_name ]['dtype' ] = safe_get_dtype (source )
677+ all_vars [input_name ]['value' ] = []
678+ all_vars [input_name ]['persistable' ] = safe_get_persistable (source )
679+ all_vars [input_name ]['attrs' ] = attrs
680+ all_vars [input_name ]['from_node' ] = ''
681+ all_vars [input_name ]['to_nodes' ] = []
682+
637683 if op .name () == "pd_op.increment_" :
638684 all_vars [now_var ]['to_nodes' ].append (all_vars [input_name ]['from_node' ])
639685 all_ops [all_vars [input_name ]['from_node' ]]['input_vars' ][now_var ] = [now_var ]
0 commit comments