Skip to content

Commit 3bacc46

Browse files
authored
Fix var bugs (#1288)
* Fix wasm-pack not found error * add parameters --modelfile * Fix operator attribute visualization bug * Fix var bugs * check source type
1 parent 85b7887 commit 3bacc46

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

visualdl/component/graph/graph_component.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,30 @@ def get_sub_ops(op, op_name, all_ops, all_vars):
464464
all_ops[sub_op_name]['is_leaf_node'] = True
465465
now_var = utils.gen_var_name(sub_op.results())
466466
for source in sub_op.operands_source():
467+
if not source.type():
468+
# if source.type() == Value().type():
469+
continue
467470
input_name = utils.gen_var_name(source)
471+
if input_name not in all_vars.keys():
472+
all_vars[input_name] = {}
473+
all_vars[input_name]['name'] = input_name
474+
try:
475+
attrs = source.results()[0].get_defining_op().attrs()
476+
if 'place' in attrs:
477+
attrs['place'] = str(attrs['place'])
478+
attrs['dtype'] = safe_get_dtype(source)
479+
except Exception:
480+
attrs = {}
481+
482+
all_vars[input_name]['shape'] = safe_get_shape(source)
483+
all_vars[input_name]['type'] = safe_get_type(source)
484+
all_vars[input_name]['dtype'] = safe_get_dtype(source)
485+
all_vars[input_name]['value'] = []
486+
all_vars[input_name]['persistable'] = safe_get_persistable(source)
487+
all_vars[input_name]['attrs'] = attrs
488+
all_vars[input_name]['from_node'] = ''
489+
all_vars[input_name]['to_nodes'] = []
490+
468491
if sub_op.name() == "pd_op.increment_":
469492
all_vars[now_var]['to_nodes'].append(all_vars[input_name]['from_node'])
470493
all_ops[all_vars[input_name]['from_node']]['input_vars'][now_var] = [now_var]
@@ -633,7 +656,30 @@ def analyse_pir(program):
633656
all_ops[op_name]['is_leaf_node'] = True
634657
now_var = utils.gen_var_name(op.results())
635658
for source in op.operands_source():
659+
if not source.type():
660+
# if source.type() == Value().type():
661+
continue
636662
input_name = utils.gen_var_name(source)
663+
if input_name not in all_vars.keys():
664+
all_vars[input_name] = {}
665+
all_vars[input_name]['name'] = input_name
666+
try:
667+
attrs = source.results()[0].get_defining_op().attrs()
668+
if 'place' in attrs:
669+
attrs['place'] = str(attrs['place'])
670+
attrs['dtype'] = safe_get_dtype(source)
671+
except Exception:
672+
attrs = {}
673+
674+
all_vars[input_name]['shape'] = safe_get_shape(source)
675+
all_vars[input_name]['type'] = safe_get_type(source)
676+
all_vars[input_name]['dtype'] = safe_get_dtype(source)
677+
all_vars[input_name]['value'] = []
678+
all_vars[input_name]['persistable'] = safe_get_persistable(source)
679+
all_vars[input_name]['attrs'] = attrs
680+
all_vars[input_name]['from_node'] = ''
681+
all_vars[input_name]['to_nodes'] = []
682+
637683
if op.name() == "pd_op.increment_":
638684
all_vars[now_var]['to_nodes'].append(all_vars[input_name]['from_node'])
639685
all_ops[all_vars[input_name]['from_node']]['input_vars'][now_var] = [now_var]

0 commit comments

Comments
 (0)