Skip to content

Commit 8ee5833

Browse files
XrekiTelGome
andauthored
[Feature Enhancement] Optimize decomposition and evaluation implementation to support test_devices. (#392)
* [Feature Enhancement] Implement iterative subgraph decomposition and evaluation pipeline. * Update paddle decomposer to support group_head_and_tail. * Support paddle. * Use system temp directory if dump path is not set. * Support separate testing of reference and target device. * Run test_target_device successfully. * Update to use the original model_path to decompose. * Remove some changes. * Fix model_name and refine some codes. * Minor change to format codes. * Update commit of Athena. --------- Co-authored-by: TelGome <[email protected]>
1 parent bdd6e35 commit 8ee5833

File tree

7 files changed

+566
-428
lines changed

7 files changed

+566
-428
lines changed

graph_net/paddle/extractor.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import os
22
import json
3-
import importlib.util
3+
import tempfile
44

55
import paddle
6-
from athena.graphnet_samples import GraphnetSample, RunGeneration
6+
from athena.graphnet_samples import RunGeneration
77
from graph_net import imp_util
8-
from graph_net.paddle import utils
98

109

1110
def load_class_from_file(file_path: str, class_name: str):
@@ -102,8 +101,8 @@ def __init__(
102101
self.num_samples_of_all_subgraphs = 0
103102
self.subgraph_idx2samples = None
104103

105-
dump_path = os.environ.get("GRAPH_NET_PIR_DUMP_WORKSPACE", "/tmp")
106-
self.dump_path = os.path.abspath(dump_path)
104+
dump_path = os.environ.get("GRAPH_NET_PIR_DUMP_WORKSPACE", None)
105+
self.dump_path = os.path.abspath(dump_path) if dump_path else tempfile.mkdtemp()
107106

108107
workspace_path = (
109108
workspace_path
@@ -167,7 +166,7 @@ def run_model_with_dump_enabled(self, model_dump_path, **input_dict):
167166
backend=None,
168167
)
169168
static_model.eval()
170-
program = static_model.forward.concrete_program.main_program
169+
# program = static_model.forward.concrete_program.main_program
171170
# print(program)
172171
static_model(**data_dict)
173172

@@ -176,7 +175,10 @@ def run_model_with_dump_enabled(self, model_dump_path, **input_dict):
176175
return static_model
177176

178177
def translate_pir_program_to_sample_codes(
179-
self, model_dump_path, split_positions=None
178+
self,
179+
model_dump_path,
180+
split_positions=None,
181+
group_head_and_tail=True,
180182
):
181183
ir_programs_path = os.path.join(model_dump_path, "exec_programs.py")
182184
example_inputs_path = os.path.join(
@@ -201,7 +203,9 @@ def translate_pir_program_to_sample_codes(
201203
example_inputs=example_inputs_path,
202204
op_example_inputs=op_example_inputs_path,
203205
split_positions=split_positions,
206+
group_head_and_tail=group_head_and_tail,
204207
eval_mode=True,
208+
tmp_dir=model_dump_path,
205209
)
206210

207211
self.subgraph_idx2samples = {}

graph_net/paddle/naive_graph_decomposer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def __init__(self, parent_graph_extractor):
5656
workspace_path=self.parent_graph_extractor.config["output_dir"],
5757
)
5858
self.split_positions = self.parent_graph_extractor.config["split_positions"]
59+
self.group_head_and_tail = self.parent_graph_extractor.config[
60+
"group_head_and_tail"
61+
]
5962
self.post_process = self.make_post_process(self.parent_graph_extractor.config)
6063

6164
def do_extract(self, **input_dict):
@@ -69,7 +72,9 @@ def do_extract(self, **input_dict):
6972

7073
# 2. Convert pir programs to graphnet samples
7174
self.builtin_extractor.translate_pir_program_to_sample_codes(
72-
model_dump_path, split_positions=self.split_positions
75+
model_dump_path,
76+
split_positions=self.split_positions,
77+
group_head_and_tail=self.group_head_and_tail,
7378
)
7479

7580
# 3. Save to model_path

graph_net/path_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ def get_graphnet_root():
77

88

99
def is_single_model_dir(model_dir):
10-
return os.path.isfile(f"{model_dir}/graph_net.json")
10+
return os.path.isfile(f"{model_dir}/graph_net.json") and os.path.isfile(
11+
f"{model_dir}/model.py"
12+
)
1113

1214

1315
def get_recursively_model_path(root_dir):

0 commit comments

Comments
 (0)