Skip to content

Commit b699698

Browse files
committed
fix
1 parent 90b622d commit b699698

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

graph_net/tensor_meta.py

100644100755
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ def serialize_to_py_str(self) -> str:
5858
lines = [
5959
(f"class {self.record_class_name}:"),
6060
(f'\tname = "{self.name}"'),
61+
*(
62+
[f'\toriginal_name = "{self.original_name}"']
63+
if self.original_name is not None
64+
else []
65+
),
6166
(f"\tshape = {self.shape}"),
6267
(f'\tdtype = "{self.dtype}"'),
6368
(f'\tdevice = "{self.device}"'),

graph_net/test/graph_variable_rename_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@ EOF
2323
)
2424
CONFIG=$(echo $config_json_str | base64 -w 0)
2525

26-
# python3 -m graph_net.model_path_handler --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
27-
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG
26+
python3 -m graph_net.model_path_handler --model-path samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
27+
# python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG

graph_net/torch/graph_variable_renamer.py

100644100755
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,9 @@ def __call__(self, rel_model_path):
7878
module, inputs = get_torch_module_and_inputs(src_model_path)
7979
gm = parse_sole_graph_module(module, inputs)
8080
gm = self.rename_graph_variables(gm, inputs, src_model_path)
81-
# print(gm)
82-
dst_model_path = os.path.join(self.config["output_dir"], rel_model_path)
81+
dst_model_path = os.path.realpath(
82+
os.path.join(self.config["output_dir"], rel_model_path)
83+
)
8384
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
8485
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
8586
self._update_model_py_file(gm, dst_model_path)

0 commit comments

Comments
 (0)