Skip to content

Commit c33dd79

Browse files
authored
refactor ResumableSamplePass.sample_handled to a abstract method (#466)
* init 'symbolic_dimension_reifier' field in graph_net.json * remove unused files * support --subgraph-ranges-json for typical_sequence_split_points.py * minor fix for torch/decompose_util.py * refactor ResumableSamplePass.sample_handled to a abstract method
1 parent ddc026b commit c33dd79

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

graph_net/sample_pass/resumable_sample_pass_mixin.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@ def declare_config(
1818
):
1919
pass
2020

21+
@abc.abstractmethod
2122
def sample_handled(self, rel_model_path: str) -> bool:
23+
raise NotImplementedError()
24+
25+
def naive_sample_handled(self, rel_model_path: str, search_file_name: str) -> bool:
2226
dst_model_path = Path(self.config["output_dir"]) / rel_model_path
2327
if not dst_model_path.exists():
2428
return False
25-
num_model_py_files = len(list(dst_model_path.rglob("model.py")))
29+
num_model_py_files = len(list(dst_model_path.rglob(search_file_name)))
2630
assert num_model_py_files <= 1
2731
return num_model_py_files == 1
2832

graph_net/torch/fx_graph_module_util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
from dataclasses import asdict
66

77

8+
def get_fx_graph_num_ops(fx_graph_module):
9+
def get_num_ops(node):
10+
return 0 if node.op in {"placeholder", "output"} else 1
11+
12+
return sum(map(get_num_ops, fx_graph_module.graph.nodes))
13+
14+
815
def get_torch_module_and_inputs(model_path, use_dummy_inputs=True):
916
module = _get_torch_module(model_path)
1017
tensor_metas = _get_tensor_metas(model_path)

graph_net/torch/sample_passes/device_rewrite_sample_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def declare_config(
2626
def __call__(self, rel_model_path: str):
2727
self.resumable_handle_sample(rel_model_path)
2828

29+
def sample_handled(self, rel_model_path: str) -> bool:
30+
return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
31+
2932
def resume(self, rel_model_path: str):
3033
return self.copy_sample_and_handle_model_py_file(rel_model_path)
3134

0 commit comments

Comments
 (0)