Skip to content

Commit 3fb8890

Browse files
authored
Update some scripts in subgraph decomposer (#432)
* Fix * Optimize typical_sequence_decomposer_test * change the entry of naive_graph_decomposer from graph_net.torch.run_model to graph_net.model_path_handler * update test_compiler and validator backend to support config and model_list * Tidy model lists in repo * Fix relative script * Update the process of typical_sequence_decomposer * renamed: naive_graph_decomposer -> graph_decomposer * remove antique test code * Fix: process torch.device and torch.dtype in inputs * change the key of split_results.json from model_name to model_path * Remove range_decomposer_backend for graph_net.torch.test_compiler, which is replaced by graph_net.torch.graph_decomposer * change the debug info of fx.GraphModule replay error * revert some change in decompose_util * revert some change in decompose_util
1 parent 21d5850 commit 3fb8890

File tree

8 files changed

+16
-150
lines changed

8 files changed

+16
-150
lines changed

graph_net/test/decomposer_validator_test.sh

Lines changed: 0 additions & 51 deletions
This file was deleted.

graph_net/torch/backend/range_decomposer_backend.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

graph_net/torch/decompose_util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,10 @@ def get_args_node(arg):
248248
yield arg.start
249249
yield arg.stop
250250
yield arg.step
251+
elif isinstance(arg, torch.device):
252+
pass
253+
elif isinstance(arg, torch.dtype):
254+
pass
251255
else:
252256
assert isinstance(
253257
arg, (int, bool, float, str, type(...), type(None))

graph_net/torch/fx_graph_module_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def get_torch_module_and_inputs(model_path):
1515
def _get_torch_module(model_path):
1616
py_module = load_module(f"{model_path}/model.py")
1717
torch_module_cls = py_module.GraphModule
18+
torch_module_cls.__graph_net_file_path__ = model_path
1819
return torch_module_cls()
1920

2021

graph_net/torch/fx_graph_parse_util.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,11 @@ def handle_underscore_suffix_difference():
221221

222222
zip_filter_names = get_zip_filter_names()
223223

224-
def zip_filter_names_str():
224+
def zip_filter_names_error_str():
225225
for triple in zip_filter_names:
226226
print(triple)
227-
return "<printed before>"
227+
error_model_path = module.__graph_net_file_path__
228+
return f"{error_model_path=}"
228229

229230
from pathlib import Path
230231

graph_net/torch/graph_decomposer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,7 @@ def _make_config(
193193
def __call__(self, rel_model_path):
194194
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
195195
split_results = load_json(self.config["split_results_path"])
196-
split_positions = split_results[os.path.basename(rel_model_path)][
197-
"split_points"
198-
]
196+
split_positions = split_results[rel_model_path]["split_positions"]
199197
config = {
200198
"split_positions": split_positions,
201199
"group_head_and_tail": self.config.get("group_head_and_tail", False),

graph_net/torch/test_compiler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend
2222
from graph_net.torch.backend.nope_backend import NopeBackend
2323
from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend
24-
from graph_net.torch.backend.range_decomposer_backend import RangeDecomposerBackend
2524
from graph_net.torch.backend.range_decomposer_validator_backend import (
2625
RangeDecomposerValidatorBackend,
2726
)
@@ -37,7 +36,6 @@
3736
"bladedisc": BladeDISCBackend(),
3837
"nope": NopeBackend(),
3938
"unstable_to_stable": UnstableToStableBackend(),
40-
"range_decomposer": RangeDecomposerBackend(),
4139
"range_decomposer_validator": RangeDecomposerValidatorBackend(),
4240
}
4341

graph_net/torch/typical_sequence_split_points.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
203203
)
204204

205205
current_idx = 0
206-
split_points_set = set()
206+
split_positions = set()
207207
total_len = sum(token2len.get(t, 1) for t in seq_tokens)
208208

209209
for token_id in seq_tokens:
@@ -212,22 +212,22 @@ def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
212212

213213
if is_pattern:
214214
if current_idx > 0:
215-
split_points_set.add(current_idx)
215+
split_positions.add(current_idx)
216216
end_idx = current_idx + length
217217
if end_idx < total_len:
218-
split_points_set.add(end_idx)
218+
split_positions.add(end_idx)
219219

220220
current_idx += length
221221

222-
sorted_splits = sorted(list(split_points_set))
222+
sorted_splits = sorted(list(split_positions))
223223

224224
self._print_analysis(
225225
model_name, str(original_path), sorted_splits, total_len, full_model_ops
226226
)
227227

228-
results[model_name] = {
229-
"path": str(original_path),
230-
"split_points": sorted_splits,
228+
results[str(original_path)] = {
229+
"model_name": model_name,
230+
"split_positions": sorted_splits,
231231
"total_length": total_len,
232232
}
233233

0 commit comments

Comments
 (0)