Skip to content

Commit 5833d88

Browse files
committed
fully_fusible_subgraph_extractor test
1 parent 48467f7 commit 5833d88

File tree

2 files changed

+3
-42
lines changed

2 files changed

+3
-42
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#samples/timm/crossvit_small_240.in1k
22
#samples/timm/poolformerv2_s12.sail_in1k
3-
#samples/timm/regnety_080.pycls_in1k
3+
samples/timm/regnety_080.pycls_in1k
44
#samples/timm/dla46x_c.in1k
55
#samples/timm/mobilenetv1_100.ra4_e3600_r224_in1k
6-
samples/timm/efficientnetv2_rw_s.ra2_in1k
7-
samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
6+
#samples/timm/efficientnetv2_rw_s.ra2_in1k
7+
#samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
88
#samples/timm/fastvit_t8.apple_dist_in1k
99
#samples/timm/test_byobnet.r160_in1k
1010
#samples/timm/mambaout_base.in1k

graph_net/torch/graph_decomposer.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,8 @@ def make_config(
4848
output_dir="./tmp/naive_decomposer_dir",
4949
filter_path=None,
5050
filter_config=None,
51-
post_extract_process_path=None,
52-
post_extract_process_class_name=None,
53-
post_extract_process_config=None,
5451
**kwargs,
5552
):
56-
if post_extract_process_config is None:
57-
post_extract_process_config = {}
5853
for pos in split_positions:
5954
assert isinstance(
6055
pos, int
@@ -66,9 +61,6 @@ def make_config(
6661
"output_dir": output_dir,
6762
"filter_path": filter_path,
6863
"filter_config": filter_config if filter_config is not None else {},
69-
"post_extract_process_path": post_extract_process_path,
70-
"post_extract_process_class_name": post_extract_process_class_name,
71-
"post_extract_process_config": post_extract_process_config,
7264
}
7365

7466
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
@@ -111,14 +103,9 @@ def _make_config(
111103
chain_style=False,
112104
filter_path=None,
113105
filter_config=None,
114-
post_extract_process_path=None,
115-
post_extract_process_class_name=None,
116-
post_extract_process_config=None,
117106
model_path_prefix="",
118107
**kwargs,
119108
):
120-
if post_extract_process_config is None:
121-
post_extract_process_config = {}
122109
for pos in split_positions:
123110
assert isinstance(
124111
pos, int
@@ -130,9 +117,6 @@ def _make_config(
130117
"output_dir": output_dir,
131118
"filter_path": filter_path,
132119
"filter_config": filter_config if filter_config is not None else {},
133-
"post_extract_process_path": post_extract_process_path,
134-
"post_extract_process_class_name": post_extract_process_class_name,
135-
"post_extract_process_config": post_extract_process_config,
136120
"model_path_prefix": model_path_prefix,
137121
}
138122

@@ -186,9 +170,6 @@ def _make_config(
186170
output_dir="./tmp/naive_decomposer_dir",
187171
filter_path=None,
188172
filter_config=None,
189-
post_extract_process_path=None,
190-
post_extract_process_class_name=None,
191-
post_extract_process_config=None,
192173
model_path_prefix="",
193174
**kwargs,
194175
):
@@ -198,18 +179,13 @@ def _make_config(
198179
raise ValueError(
199180
f"split_results_path should be a valid JSON file path, but got {split_results_path=}"
200181
)
201-
if post_extract_process_config is None:
202-
post_extract_process_config = {}
203182
return {
204183
"split_results_path": split_results_path,
205184
"group_head_and_tail": group_head_and_tail,
206185
"chain_style": chain_style,
207186
"output_dir": output_dir,
208187
"filter_path": filter_path,
209188
"filter_config": filter_config if filter_config is not None else {},
210-
"post_extract_process_path": post_extract_process_path,
211-
"post_extract_process_class_name": post_extract_process_class_name,
212-
"post_extract_process_config": post_extract_process_config,
213189
"model_path_prefix": model_path_prefix,
214190
}
215191

@@ -274,7 +250,6 @@ def __init__(
274250
),
275251
)
276252
self.filter = self.make_filter(self.config)
277-
self.post_extract_process = self.make_post_extract_process(self.config)
278253

279254
def _get_model_path(self):
280255
return os.path.join(
@@ -284,33 +259,19 @@ def _get_model_path(self):
284259
)
285260

286261
def forward(self, *args):
287-
logger.warning("naive decomposer forwarding")
288262
if not self.extracted:
289263
if self.need_extract(self.submodule, args):
290264
self.builtin_extractor(self.submodule, args)
291-
self._post_extract_process()
292265
self.extracted = True
293-
logger.warning("naive decomposer end")
294266
return self.submodule(*args)
295267

296268
def need_extract(self, gm, sample_inputs):
297269
if self.filter is None:
298270
return True
299271
return self.filter(gm, sample_inputs)
300272

301-
def _post_extract_process(self):
302-
model_path = self._get_model_path()
303-
return self.post_extract_process(model_path)
304-
305273
def make_filter(self, config):
306274
if config["filter_path"] is None:
307275
return None
308276
module = imp_util.load_module(config["filter_path"])
309277
return module.GraphFilter(config["filter_config"])
310-
311-
def make_post_extract_process(self, config):
312-
if config.get("post_extract_process_path") is None:
313-
return lambda *args, **kwargs: None
314-
module = imp_util.load_module(config["post_extract_process_path"])
315-
cls = getattr(module, config["post_extract_process_class_name"])
316-
return cls(config["post_extract_process_config"])

0 commit comments

Comments
 (0)