Skip to content

Commit 417240e

Browse files
authored
Minor enhance on typical_sequence_decompose process (#437)
* debug_typical_sequence * support model-path-prefix in splitting positions * fix * fix * Update RangeDecomposer to handle missing split position in split_results.json * Add level5_subgraph_dataset_test.sh * remove redundant code * handle Exception in model_path_handler * revert graph_net/model_path_handler.py * minor fix * pass models without op list * support resume in GraphVariableRenamer * Minor Change * Use with + NamedTempFile
1 parent 428c02c commit 417240e

File tree

4 files changed

+29
-8
lines changed

4 files changed

+29
-8
lines changed

graph_net/tools/typical_sequence_decompose.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ python3 -m graph_net.model_path_handler \
6868
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_variable_renamer.py",
6969
"handler_class_name": "GraphVariableRenamer",
7070
"handler_config": {
71+
"resume": true,
7172
"model_path_prefix": "$DECOMPOSE_WORKSPACE",
7273
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
7374
"data_input_predicator_class_name": "NaiveDataInputPredicator",

graph_net/torch/graph_decomposer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
import torch
55
import json
6+
import sys
67
from graph_net.torch.decompose_util import convert_to_submodules_graph
78
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
89
import graph_net.imp_util as imp_util
@@ -209,6 +210,12 @@ def __call__(self, rel_model_path):
209210
)
210211
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
211212
split_results = load_json(self.config["split_results_path"])
213+
if (
214+
split_results[rel_model_path]["split_positions"] is None
215+
or len(split_results[rel_model_path]["split_positions"]) == 0
216+
):
217+
sys.stderr.write(f"Error: {rel_model_path} has no split positions.\n")
218+
return
212219
split_positions = split_results[rel_model_path]["split_positions"]
213220
if self.config["resume"] and self._is_model_handled(
214221
rel_model_path, split_positions

graph_net/torch/graph_variable_renamer.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
import shutil
44
import inspect
5+
import tempfile
56
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
67
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
78
from graph_net.tensor_meta import TensorMeta
@@ -37,8 +38,9 @@ def _make_model_runnable_predicator(self, config):
3738

3839
def _make_config(
3940
self,
40-
data_input_predicator_filepath,
41-
model_runnable_predicator_filepath,
41+
resume: bool = False,
42+
data_input_predicator_filepath=None,
43+
model_runnable_predicator_filepath=None,
4244
output_dir="./tmp/graph_variable_renamer_dir",
4345
filter_path=None,
4446
filter_config=None,
@@ -59,6 +61,7 @@ def _make_config(
5961
if model_runnable_predicator_config is None:
6062
model_runnable_predicator_config = {}
6163
return {
64+
"resume": resume,
6265
"output_dir": output_dir,
6366
"filter_path": filter_path,
6467
"filter_config": filter_config if filter_config is not None else {},
@@ -82,12 +85,20 @@ def __call__(self, rel_model_path):
8285
dst_model_path = os.path.realpath(
8386
os.path.join(self.config["output_dir"], rel_model_path)
8487
)
88+
if self.config["resume"] and os.path.exists(
89+
os.path.join(dst_model_path, "model.py")
90+
):
91+
return
8592
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
86-
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
87-
self._update_model_py_file(gm, dst_model_path)
88-
self._update_weight_meta_py_file(src_model_path, dst_model_path)
89-
self._update_input_meta_py_file(src_model_path, dst_model_path)
90-
self._try_run(dst_model_path)
93+
with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
94+
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
95+
shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
96+
self._update_model_py_file(gm, temp_model_path)
97+
self._update_weight_meta_py_file(src_model_path, temp_model_path)
98+
self._update_input_meta_py_file(src_model_path, temp_model_path)
99+
print("Try to run renamed model...")
100+
self._try_run(temp_model_path)
101+
shutil.copytree(temp_model_path, dst_model_path)
91102

92103
def _try_run(self, model_path):
93104
assert self.model_runnable_predicator(

graph_net/torch/typical_sequence_split_points.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def _resolve_token_to_ops(
116116
return [f"Unknown({tid})"]
117117

118118
def _load_op_names_from_file(self, txt_path: Path) -> List[str]:
119-
assert txt_path.exists(), f"{str(txt_path)=}"
119+
if not txt_path.exists():
120+
print(f"File not found: {txt_path}")
121+
return []
120122
return txt_path.read_text().split("\n")
121123

122124
def _calculate_token_lengths(

0 commit comments

Comments
 (0)