Skip to content

Commit c99bf74

Browse files
committed
use tempfile, fix sys problem, remove useless configs
1 parent 5292810 commit c99bf74

File tree

4 files changed

+37
-52
lines changed

4 files changed

+37
-52
lines changed

graph_net/test/graph_decompose_and_look_for_fully_fusable_subgraph_test.sh

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,16 @@ decorator_config_json_str=$(cat <<EOF
1313
"name": "$MODEL_NAME",
1414
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/fully_fusable_subgraph_extractor.py",
1515
"custom_extractor_config": {
16-
"output_dir": "/tmp/naive_decompose_workspace",
1716
"split_positions": [],
1817
"group_head_and_tail": true,
19-
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
20-
"filter_config": {},
21-
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
22-
"post_extract_process_class_name": "GraphFullyFusable",
2318
"max_step": 5,
2419
"min_step": 2,
25-
"max_nodes": 32
20+
"max_nodes": 6
2621
}
2722
}
2823
}
2924
EOF
3025
)
3126
DECORATOR_CONFIG=$(echo $decorator_config_json_str | base64 -w 0)
3227

33-
python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG
28+
python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG
Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import os
22
import torch
3-
import sys
43
import graph_net
4+
import tempfile
5+
from graph_net.torch import constraint_util
56

67

78
class GraphExtractor:
@@ -25,11 +26,6 @@ def make_config(
2526
split_positions=(),
2627
group_head_and_tail=False,
2728
chain_style=False,
28-
output_dir="./tmp/naive_decomposer_dir",
29-
filter_path=None,
30-
filter_config=None,
31-
post_extract_process_path=None,
32-
post_extract_process_class_name=None,
3329
max_step=8,
3430
min_step=2,
3531
max_nodes=32,
@@ -42,40 +38,38 @@ def make_config(
4238
"split_positions": split_positions,
4339
"group_head_and_tail": group_head_and_tail,
4440
"chain_style": chain_style,
45-
"output_dir": output_dir,
46-
"filter_path": filter_path,
47-
"filter_config": filter_config if filter_config is not None else {},
48-
"post_extract_process_path": post_extract_process_path,
49-
"post_extract_process_class_name": post_extract_process_class_name,
5041
"max_step": max_step,
5142
"min_step": min_step,
5243
"max_nodes": max_nodes,
5344
}
5445

5546
def _get_sub_ranges(self):
56-
kMinLenOps = self.config["min_step"]
57-
num_ops = self.config["max_nodes"]
58-
for length in reversed(range(kMinLenOps, num_ops)):
59-
for start_pos in range(num_ops - length):
60-
end_pos = start_pos + length
47+
for step in reversed(
48+
range(self.config["min_step"], self.config["max_step"] + 1)
49+
):
50+
for start_pos in range(self.config["max_nodes"] - step):
51+
end_pos = start_pos + step
6152
yield start_pos, end_pos
6253

6354
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
64-
import json
65-
import base64
66-
55+
temp_dir_obj = tempfile.TemporaryDirectory(prefix="_check_fusable_")
56+
temp_output_dir = temp_dir_obj.name
57+
found_fusable_subgraph = False
58+
print(f"Using temp output dir: {temp_output_dir}")
6759
for start_pos, end_pos in self._get_sub_ranges():
6860
self.config["split_positions"] = [start_pos, end_pos]
6961
print("current split_positions:", self.config["split_positions"])
7062
graph_net_root = os.path.dirname(graph_net.__file__)
71-
model_path = f"{graph_net_root}/../samples//timm/{self.name}"
63+
model_path = os.path.join(
64+
graph_net_root, "..", "samples", "timm", self.name
65+
)
7266
check_fusable_config = {
7367
"decorator_path": f"{graph_net_root}/torch/extractor.py",
7468
"decorator_config": {
7569
"name": f"{self.name}",
7670
"custom_extractor_path": f"{graph_net_root}/torch/naive_graph_decomposer.py",
7771
"custom_extractor_config": {
78-
"output_dir": "/tmp/naive_decompose_workspace",
72+
"output_dir": temp_output_dir,
7973
"split_positions": self.config["split_positions"],
8074
"group_head_and_tail": False,
8175
"filter_path": f"{graph_net_root}/torch/naive_subgraph_filter.py",
@@ -85,15 +79,20 @@ def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
8579
},
8680
},
8781
}
88-
json_string = json.dumps(check_fusable_config)
89-
json_bytes = json_string.encode("utf-8")
90-
b64_encoded_bytes = base64.b64encode(json_bytes)
91-
checker_config = b64_encoded_bytes.decode("utf-8")
92-
cmd = f"{sys.executable} -m graph_net.torch.run_model --model-path {model_path} --decorator-config '{checker_config}'"
93-
res_code = os.system(cmd)
94-
if res_code == 0:
95-
print("find the biggest fully fusable subgraph")
82+
success = constraint_util.RunModelPredicator(check_fusable_config)(
83+
model_path
84+
)
85+
if success:
86+
found_fusable_subgraph = True
87+
temp_dir_obj.cleanup = lambda: None
88+
print(
89+
f"SUCCESS in finding the biggest fully fusable subgraph saved in: {temp_output_dir}."
90+
)
9691
break
9792
else:
93+
print("Failed attempt. clean up the workspace and continue the search.")
94+
temp_dir_obj.cleanup()
9895
continue
96+
if not found_fusable_subgraph:
97+
print("No fusable subgraph found")
9998
return gm

graph_net/torch/naive_graph_decomposer.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import os
22
import torch
3-
import sys
4-
import shutil
53
from graph_net.torch.decompose_util import convert_to_submodules_graph
64
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
75
import graph_net.imp_util as imp_util
@@ -104,15 +102,7 @@ def _post_extract_process(self):
104102
model_path = os.path.join(
105103
self.parent_graph_extractor.config["output_dir"], self.model_name
106104
)
107-
fully_fusable = self.post_extract_process(model_path)
108-
if fully_fusable:
109-
print(f"{model_path} is the biggest fully fusable subgraph!")
110-
sys.exit(0)
111-
else:
112-
# remove if not fully fusable
113-
shutil.rmtree(model_path)
114-
print(f"remove: {model_path}")
115-
sys.exit(1)
105+
return self.post_extract_process(model_path)
116106

117107
def make_filter(self, config):
118108
if config["filter_path"] is None:

graph_net/torch/post_extract_process_count_kernels.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from graph_net.torch import utils
22
import importlib.util
33
import torch
4+
import sys
45
from typing import Type
56
from torch.profiler import profile, record_function, ProfilerActivity
67

@@ -12,7 +13,7 @@ def __init__(self, config):
1213
def __call__(self, model_path=None):
1314
torch._dynamo.reset()
1415
if model_path is None:
15-
return False
16+
sys.exit(1)
1617
# model
1718
model_class = load_class_from_file(
1819
f"{model_path}/model.py", class_name="GraphModule"
@@ -30,20 +31,20 @@ def __call__(self, model_path=None):
3031
model(**state_dict)
3132
except Exception as e:
3233
print(f"failed in running model:{e}")
33-
return False
34+
sys.exit(1)
3435
# try to compile the model
3536
try:
3637
compiled_model = torch.compile(model)
3738
except Exception as e:
3839
print(f"failed in compiling model:{e}")
39-
return False
40+
sys.exit(1)
4041
compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
4142
if compiled_num_of_kernels == 1:
4243
print(model_path, "can be fully integrated!!!!!!!!!!!")
43-
return True
44+
sys.exit(0)
4445
else:
4546
print(f"{model_path} can not be fully integrated, to be removed...")
46-
return False
47+
sys.exit(1)
4748

4849

4950
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:

0 commit comments

Comments
 (0)