Skip to content

Commit 0aa1827

Browse files
authored
[Feature Enhancement] Optimize subgraph_decompose_and_evaluation implementation. (#403)
* Optimize the calculation of split_positions. * Fix the parsing of model_path. * Optimize the run_decomposer process. * Minor modification. * Fix model_name in decompose config. * Optimize the definition of decompose config and fix the config saving of test_target_device. * Redirect the output of run_decomposer to file.
1 parent e2061ef commit 0aa1827

File tree

2 files changed

+198
-225
lines changed

2 files changed

+198
-225
lines changed

graph_net/paddle/naive_graph_decomposer.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -82,29 +82,14 @@ def do_extract(self, **input_dict):
8282
model_path = os.path.join(
8383
self.builtin_extractor.workspace_path, self.builtin_extractor.name
8484
)
85-
for (
86-
subgraph_idx,
87-
samples,
88-
) in self.builtin_extractor.subgraph_idx2samples.items():
89-
for seq_idx in range(len(samples)):
90-
if (
91-
self.builtin_extractor.num_samples_of_all_subgraphs == 1
92-
and len(samples) == 1
93-
):
94-
subgraph_path = model_path
95-
elif len(samples) == 1:
96-
subgraph_path = os.path.join(model_path, f"subgraph_{subgraph_idx}")
97-
else:
98-
subgraph_path = os.path.join(
99-
model_path, f"subgraph_{subgraph_idx}_{seq_idx}"
100-
)
101-
self.subgraph_path_list.append(subgraph_path)
102-
self.builtin_extractor.write_sample_to_file(
103-
subgraph_path, samples[seq_idx]
104-
)
105-
print(
106-
f"Graph and tensors for '{self.builtin_extractor.name}' extracted successfully to: {model_path}"
107-
)
85+
assert len(self.builtin_extractor.subgraph_idx2samples) == 1
86+
87+
samples = self.builtin_extractor.subgraph_idx2samples[0]
88+
for seq_idx in range(len(samples)):
89+
subgraph_path = f"{model_path}_{seq_idx}"
90+
self.subgraph_path_list.append(subgraph_path)
91+
self.builtin_extractor.write_sample_to_file(subgraph_path, samples[seq_idx])
92+
print(f"Save to {subgraph_path}")
10893
return static_model
10994

11095
def __call__(self, **input_dict):

0 commit comments

Comments
 (0)