Skip to content

Commit 9c02d0a

Browse files
roll-away and lixinqi authored
get fusible subgraph test (#431)
* 1119 * 1120 * 1120.2 * model_path * remove unnecessary files and pre-committed * remove unnecessary files and pre-committed * 1121 remove unnecessary files * modify rev version * modify rev version * modify rev version * accuracy issues targeted * test script and modify feature * return set[str] * add logfile for test * filter can get the number of kernels in naive_graph_decomposer * post extract process feature * remove unnecessary code blocks and variables * modify the way of counting kernels used * modify the way of counting kernels used * modify script, rename files and variables * add failure protection and log output when removing directories * add a script to check fusability of a given model * add a script to check if a given model is fully fusable * add a script to check if a given model is fully fusable * a script to check if a given model is fully fusable * add a script to check if a given model is fully fusionable * add a script to find fully fusionable subgraph * find the biggest fully fusionable subgraph * get fusible subgraph test * modify get fully fusible subgraph * improve fully_fusible_subgraph_extractor.py efficiency * backup code * Improve efficiency of test/fully_fusible_subgraph_extractor_test.sh * fully_fusible_subgraph_extractor test * delete empty file * modify input file path --------- Co-authored-by: Xinqi Li <[email protected]>
1 parent de2ba85 commit 9c02d0a

15 files changed

+535
-199
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
samples/timm/crossvit_small_240.in1k
2+
samples/timm/poolformerv2_s12.sail_in1k
3+
samples/timm/regnety_080.pycls_in1k
4+
samples/timm/dla46x_c.in1k
5+
samples/timm/mobilenetv1_100.ra4_e3600_r224_in1k
6+
samples/timm/efficientnetv2_rw_s.ra2_in1k
7+
samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
8+
samples/timm/fastvit_t8.apple_dist_in1k
9+
samples/timm/test_byobnet.r160_in1k
10+
samples/timm/mambaout_base.in1k
11+
samples/timm/davit_small
12+
samples/timm/resnet61q.ra2_in1k
13+
samples/timm/coat_tiny
14+
samples/timm/regnetx_004.pycls_in1k
15+
samples/timm/convnextv2_large.fcmae
16+
samples/timm/regnety_640.seer
17+
samples/timm/repvit_m1_1.dist_300e_in1k
18+
samples/timm/tinynet_d.in1k
19+
samples/timm/resnetrs270.tf_in1k
20+
samples/timm/cait_m48_448
21+
samples/timm/legacy_seresnet50.in1k
22+
samples/timm/tinynet_a.in1k
23+
samples/timm/convnext_small.fb_in1k
24+
samples/timm/vit_huge_patch14_clip_quickgelu_224.dfn5b
25+
samples/timm/dpn131.mx_in1k
26+
samples/timm/convnextv2_large.fcmae_ft_in1k
27+
samples/timm/convnextv2_small
28+
samples/timm/repvit_m1.dist_in1k
29+
samples/timm/cs3darknet_s
30+
samples/timm/resnet50d.a1_in1k
31+
samples/timm/dm_nfnet_f6
32+
samples/timm/coatnet_1_rw_224
33+
samples/timm/lcnet_050.ra2_in1k
34+
samples/timm/efficientnet_em.ra2_in1k
35+
samples/timm/dpn48b
36+
samples/timm/semnasnet_075.rmsp_in1k
37+
samples/timm/skresnet34.ra_in1k
38+
samples/timm/crossvit_15_dagger_240.in1k
39+
samples/timm/mnasnet_100.rmsp_in1k
40+
samples/timm/mobilenetv3_rw.rmsp_in1k
41+
samples/timm/xception65p.ra3_in1k
42+
samples/timm/coatnet_0_rw_224
43+
samples/timm/eca_nfnet_l3
44+
samples/timm/deit3_base_patch16_224.fb_in1k
45+
samples/timm/mambaout_base_short_rw.sw_e500_in1k
46+
samples/timm/mobilenetv4_conv_small.e1200_r224_in1k
47+
samples/timm/xception71.tf_in1k
48+
samples/timm/dla60.in1k
49+
samples/timm/repghostnet_130.in1k
50+
samples/timm/mambaout_base_plus_rw.sw_e150_in12k
51+
samples/timm/poolformerv2_s36.sail_in1k
52+
samples/timm/deit3_huge_patch14_224.fb_in1k
53+
samples/timm/vit_base_patch32_clip_224.datacompxl
54+
samples/timm/poolformer_m48.sail_in1k
55+
samples/timm/regnety_006.pycls_in1k
56+
samples/timm/starnet_s4.in1k
57+
samples/timm/poolformer_m36.sail_in1k
58+
samples/timm/vit_huge_patch14_gap_224.in1k_ijepa
59+
samples/timm/efficientnet_b3.ra2_in1k
60+
samples/timm/mobilenetv3_large_150d.ra4_e3600_r256_in1k
61+
samples/timm/hgnetv2_b0.ssld_stage1_in22k_in1k
62+
samples/timm/convnextv2_huge.fcmae
63+
samples/timm/davit_huge
64+
samples/timm/regnetx_004_tv.tv2_in1k
65+
samples/timm/dla34.in1k
66+
samples/timm/convnext_xlarge.fb_in22k
67+
samples/timm/resmlp_12_224.fb_dino
68+
samples/timm/fasternet_t1.in1k
69+
samples/timm/resnetblur50.bt_in1k
70+
samples/timm/res2net50d.in1k
71+
samples/timm/vit_base_patch32_224.augreg_in1k
72+
samples/timm/mambaout_base_wide_rw.sw_e500_in1k
73+
samples/timm/vgg19_bn.tv_in1k
74+
samples/timm/vit_small_patch16_rope_ape_224.naver_in1k
75+
samples/timm/hardcorenas_b.miil_green_in1k
76+
samples/timm/vgg16.tv_in1k
77+
samples/timm/xception41p.ra3_in1k
78+
samples/timm/efficientnet_lite0.ra_in1k
79+
samples/timm/regnetv_064.ra3_in1k
80+
samples/timm/regnety_320.pycls_in1k
81+
samples/timm/convnext_pico.d1_in1k
82+
samples/timm/repvit_m1_0.dist_300e_in1k
83+
samples/timm/resnet50c.gluon_in1k
84+
samples/timm/mobileone_s4.apple_in1k
85+
samples/timm/ghostnet_100.in1k
86+
samples/timm/deit_base_distilled_patch16_384
87+
samples/timm/dpn68b.mx_in1k
88+
samples/timm/dla60_res2next
89+
samples/timm/resnet101d.gluon_in1k
90+
samples/timm/eva02_large_patch14_clip_224.merged2b
91+
samples/timm/fasternet_m.in1k
92+
samples/timm/mobilenetv2_110d.ra_in1k
93+
samples/timm/regnetx_064.pycls_in1k
94+
samples/timm/cspresnet50.ra_in1k
95+
samples/timm/resmlp_24_224.fb_dino
96+
samples/timm/mobileone_s3.apple_in1k
97+
samples/timm/mobileone_s2.apple_in1k
98+
samples/timm/res2net101d
99+
samples/timm/hardcorenas_f.miil_green_in1k
100+
samples/timm/hrnet_w18_ssld.paddle_in1k
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#samples/timm/crossvit_small_240.in1k
2+
#samples/timm/poolformerv2_s12.sail_in1k
3+
#samples/timm/regnety_080.pycls_in1k
4+
#samples/timm/dla46x_c.in1k
5+
#samples/timm/mobilenetv1_100.ra4_e3600_r224_in1k
6+
samples/timm/efficientnetv2_rw_s.ra2_in1k
7+
#samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
8+
#samples/timm/fastvit_t8.apple_dist_in1k
9+
#samples/timm/test_byobnet.r160_in1k
10+
#samples/timm/mambaout_base.in1k
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#!/bin/bash
# Run the fully-fusible-subgraph extractor over a list of sample models.

# Locate the installed graph_net package so paths resolve from any cwd.
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# Input model path (used only by the commented single-model command below).
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
# Full sample list:
# INPUT_MODEL_LIST=$GRAPH_NET_ROOT/test/dev_model_list/get_fusible_subgraph_sample_list.txt
INPUT_MODEL_LIST=$GRAPH_NET_ROOT/test/dev_model_list/small_sample_list_for_get_fusible_subgraph.txt

OUTPUT_DIR="/tmp/find_fully_fusible_output"

# Handler configuration, passed base64-encoded on the command line.
config_json_str=$(cat <<EOF
{
  "handler_path": "$GRAPH_NET_ROOT/torch/fully_fusible_subgraph_extractor.py",
  "handler_class_name":"FullyFusibleSubgraphExtractor",
  "handler_config": {
    "resume": false,
    "model_path_prefix": "$GRAPH_NET_ROOT/../",
    "output_dir": "$OUTPUT_DIR",
    "nn_module_fully_fusible_decorator_path": "$GRAPH_NET_ROOT/torch/count_kernels_util.py",
    "nn_module_fully_fusible_decorator_class_name": "TorchSubModuleFullyFusibleDecorator",
    "max_step": 3,
    "min_step": 2,
    "max_nodes": 4
  }
}
EOF
)
# Quote the JSON so its whitespace survives word splitting and no
# glob expansion can occur before it reaches base64.
CONFIG=$(echo "$config_json_str" | base64 -w 0)

# Single-model variant:
# python3 -m graph_net.model_path_handler --model-path "$GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES" --handler-config="$CONFIG"
python3 -m graph_net.model_path_handler --model-path-list "$INPUT_MODEL_LIST" --handler-config="$CONFIG"
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/bin/bash
# Check whether a single sample model is fully fusionable.

# Locate the installed graph_net package so paths resolve from any cwd.
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# Input model path under the samples/ directory.
MODEL_NAME=resnet18d.ra2_in1k
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME

# Checker configuration, passed base64-encoded on the command line.
checker_config_json_str=$(cat <<EOF
{
  "post_extract_process_config": {
    "post_extract_process_path":"$GRAPH_NET_ROOT/torch/count_kernels_util.py",
    "post_extract_process_class_name": "GraphFullyFusionable"
  }
}
EOF
)
# Quote the JSON so its whitespace survives word splitting and no
# glob expansion can occur before it reaches base64.
CHECKER_CONFIG=$(echo "$checker_config_json_str" | base64 -w 0)

python3 -m graph_net.torch.check_if_a_given_model_is_fully_fusionable --model-path "$GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES" --checker-config="$CHECKER_CONFIG"

graph_net/test/naive_decomposer_and_post_extract_process_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ decorator_config_json_str=$(cat <<EOF
1919
"group_head_and_tail": true,
2020
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
2121
"filter_config": {},
22-
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
22+
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/count_kernels_util.py",
2323
"post_extract_process_class_name": "GraphFullyFusible"
2424
}
2525
}

graph_net/test/torch_extractor_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def forward(self, x):
7676
start_node_idx=0,
7777
end_node_idx=2,
7878
submodule_hook=submodule_hook,
79-
# group_head_and_tail=False,
79+
group_head_and_tail=True,
8080
)
8181
folded_output = folded(inp)
8282

graph_net/torch/constraint_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import sys
22
import os
33
import graph_net
4+
import logging
5+
6+
logger = logging.getLogger(__name__)
47

58

69
class NaiveDataInputPredicator:
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import traceback
2+
from graph_net.torch import utils
3+
import importlib.util
4+
import torch
5+
import sys
6+
from typing import Type
7+
from torch.profiler import profile, record_function, ProfilerActivity
8+
9+
from graph_net.torch.graph_fusibility_status import (
10+
GraphFusibilityStatus,
11+
GraphFusibility,
12+
)
13+
14+
15+
class TorchSubModuleFullyFusibleDecorator:
    """Factory hook that wraps a submodule in a fusibility-probing module.

    Calling an instance with a submodule returns a
    ``TorchNNModuleFullyFusiblePredicator`` around it; the submodule index
    is accepted to satisfy the hook signature but is not used.
    """

    def __init__(self, config):
        # Kept for hook-interface compatibility; not consulted when wrapping.
        self.config = config

    def __call__(self, module, sub_module_idx):
        """Return a fusibility predicator wrapping *module*."""
        return TorchNNModuleFullyFusiblePredicator(module)
21+
22+
23+
class TorchNNModuleFullyFusiblePredicator(torch.nn.Module):
24+
def __init__(self, module):
25+
super().__init__()
26+
self.module = module
27+
28+
def forward(self, *inputs):
29+
try:
30+
compiled_model = torch.compile(self.module)
31+
except Exception:
32+
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
33+
ret_tensors, compiled_num_of_kernels = count_kernels(compiled_model, inputs)
34+
if compiled_num_of_kernels == 1:
35+
raise GraphFusibilityStatus(GraphFusibility.kFullyFusible)
36+
else:
37+
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
38+
return ret_tensors
39+
40+
41+
class ThrowExitStatusIfGraphFullyFusible:
    """Load a saved graph model and report its fusibility via exceptions.

    ``__call__`` always raises ``GraphFusibilityStatus``: ``kFullyFusible``
    when the compiled model executes as a single kernel launch, and
    ``kNotFullyFusible`` on any failure (missing path, eager run error,
    compile error, or more than one kernel launch).
    """

    def __init__(self, config):
        self.config = config

    def __call__(self, model_path=None):
        # Clear cached compilation state so kernel counts are not polluted
        # by earlier torch.compile runs in this process.
        torch._dynamo.reset()
        if model_path is None:
            raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
        # Load the serialized GraphModule class from <model_path>/model.py.
        model_class = load_class_from_file(
            f"{model_path}/model.py", class_name="GraphModule"
        )
        assert model_class is not None
        model = model_class()

        # Build dummy weights/inputs from the recorded tensor metadata.
        inputs_params = utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        state_dict = {k: utils.get_dummy_tensor(v) for k, v in params.items()}

        # The model must run eagerly before fusibility can be judged.
        try:
            model(**state_dict)
        except Exception:
            raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
        # It must also be compilable.
        try:
            compiled_model = torch.compile(model)
        except Exception:
            raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
        _, compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
        # A single launch means the whole graph fused into one kernel.
        if compiled_num_of_kernels == 1:
            raise GraphFusibilityStatus(GraphFusibility.kFullyFusible)
        else:
            raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
79+
80+
81+
class GraphFullyFusible:
    """Post-extract processor that turns a fusibility verdict into an exit code.

    Exit status 0 means the graph at *model_path* is fully fusible; 1 means
    it is not, or that checking it failed with an unexpected error.
    """

    def __init__(self, config):
        self.predicator = ThrowExitStatusIfGraphFullyFusible(config)

    def __call__(self, model_path=None):
        try:
            self.predicator(model_path)
        except GraphFusibilityStatus as status:
            verdict = status.graph_fusibility
            if verdict == GraphFusibility.kFullyFusible:
                exit_code = 0
            elif verdict == GraphFusibility.kNotFullyFusible:
                exit_code = 1
            else:
                # Unknown verdict values are a programming error, not a result.
                raise NotImplementedError(f"{status.graph_fusibility=}")
            sys.exit(exit_code)
        except Exception:
            # Unexpected failure: surface the traceback, report "not fusible".
            traceback.print_exc()
            sys.exit(1)
98+
99+
100+
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
    """Execute *file_path* as an anonymous module and fetch *class_name*.

    Returns the class object, or ``None`` when the executed module defines
    no attribute with that name. The module is executed once and is not
    registered in ``sys.modules``.
    """
    module_spec = importlib.util.spec_from_file_location("unnamed", file_path)
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)
    return getattr(loaded_module, class_name, None)
106+
107+
108+
def count_kernels(model, sample_inputs) -> tuple:
    """Run one inference under the profiler and count kernel launches.

    Args:
        model: A callable torch model (eager or ``torch.compile``-d).
        sample_inputs: Forward inputs — a dict of keyword arguments or a
            list/tuple of positional arguments.

    Returns:
        tuple: ``(forward outputs, launch count)``. The count sums the
        ``cuLaunchKernel`` (driver API) and ``cudaLaunchKernel`` (runtime
        API) profiler events; without CUDA activity it is 0.

    Raises:
        NotImplementedError: If *sample_inputs* is neither a dict nor a
            list/tuple.

    Behavior:
        - Runs the model once inside a PyTorch profiler context.
        - Sums the counts of the kernel-launch events, which correspond to
          the number of CUDA kernel launches.
    """
    model.eval()
    # Use the PyTorch profiler to observe the single inference pass.
    with profile(
        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
        record_shapes=True,
    ) as prof:
        with record_function("model_inference"):
            if isinstance(sample_inputs, dict):
                ret_tensors = model(**sample_inputs)
            elif isinstance(sample_inputs, (list, tuple)):
                ret_tensors = model(*sample_inputs)
            else:
                raise NotImplementedError(f"{type(sample_inputs)=}")

    # Kernel launches appear as driver-level (cuLaunchKernel) or
    # runtime-level (cudaLaunchKernel) API events; count both.
    total_count = sum(
        e.count
        for e in prof.key_averages()
        if e.key in ("cuLaunchKernel", "cudaLaunchKernel")
    )
    return ret_tensors, total_count

graph_net/torch/decompose_util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def fold_range_to_submodule(
181181
end_node_idx: int,
182182
submodule_hook=None,
183183
submodule_name="extracted_submodule",
184-
group_head_and_tail=True,
184+
group_head_and_tail=False,
185185
):
186186
return convert_to_submodules_graph(
187187
gm,
@@ -249,7 +249,9 @@ def get_args_node(arg):
249249
yield arg.stop
250250
yield arg.step
251251
else:
252-
assert isinstance(arg, (int, bool, float, str, type(None))), f"{type(arg)=}"
252+
assert isinstance(
253+
arg, (int, bool, float, str, type(...), type(None))
254+
), f"{type(arg)=}"
253255

254256
def get_args_node_and_self_node(node):
255257
for arg in node.args:

0 commit comments

Comments
 (0)