Skip to content

Commit dc1d35b

Browse files
committed
Fix
1 parent 5f01912 commit dc1d35b

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

graph_net/torch/naive_graph_decomposer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@ def make_config(
3131
filter_config=None,
3232
post_extract_process_path=None,
3333
post_extract_process_class_name=None,
34+
post_extract_process_config=None,
3435
**kwargs,
3536
):
37+
if post_extract_process_config is None:
38+
post_extract_process_config = {}
3639
for pos in split_positions:
3740
assert isinstance(
3841
pos, int
@@ -46,6 +49,7 @@ def make_config(
4649
"filter_config": filter_config if filter_config is not None else {},
4750
"post_extract_process_path": post_extract_process_path,
4851
"post_extract_process_class_name": post_extract_process_class_name,
52+
"post_extract_process_config": post_extract_process_config,
4953
}
5054

5155
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
@@ -112,8 +116,8 @@ def make_filter(self, config):
112116
return module.GraphFilter(config["filter_config"])
113117

114118
def make_post_extract_process(self, config):
115-
if config["post_extract_process_path"] is None:
116-
return None
119+
if config.get("post_extract_process_path") is None:
120+
return lambda *args, **kwargs: None
117121
module = imp_util.load_module(config["post_extract_process_path"])
118122
cls = getattr(module, config["post_extract_process_class_name"])
119-
return cls(config["post_extract_process_path"])
123+
return cls(config["post_extract_process_config"])

graph_net/torch/typical_sequence_split_points.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import os
44
from pathlib import Path
5-
from typing import Any, Callable, Dict, List
5+
from typing import Any, Dict, List
66

77
import torch
88
import torch.nn as nn
@@ -252,7 +252,15 @@ def _print_analysis(self, name, path, splits, total_len, full_ops):
252252
print("\n")
253253

254254

255-
def main():
255+
def main(args):
256+
analyzer = SplitAnalyzer(window_size=args.window_size)
257+
results = analyzer.analyze(args.model_list, args.device)
258+
if args.output_json:
259+
with open(args.output_json, "w") as f:
260+
json.dump(results, f, indent=4)
261+
262+
263+
if __name__ == "__main__":
256264
parser = argparse.ArgumentParser(
257265
description="Analyze graph and calculate split points."
258266
)
@@ -278,14 +286,4 @@ def main():
278286
help="Path to save the analysis results in JSON format.",
279287
)
280288
args = parser.parse_args()
281-
282-
analyzer = SplitAnalyzer(window_size=args.window_size)
283-
results = analyzer.analyze(args.model_list, args.device)
284-
285-
if args.output_json:
286-
with open(args.output_json, "w") as f:
287-
json.dump(results, f, indent=4)
288-
289-
290-
if __name__ == "__main__":
291-
main()
289+
main(args)

0 commit comments

Comments
 (0)