Skip to content

Commit 513cc38

Browse files
committed
add a script to check if a given model is fully fusable
1 parent 9a28d45 commit 513cc38

File tree

13 files changed

+4522 -9 lines changed

13 files changed

+4522 -9 lines changed

graph_net/torch/post_extract_process_count_kernels.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from graph_net.torch import utils
22
import importlib.util
3-
import shutil
43
import torch
54
from typing import Type
65
from torch.profiler import profile, record_function, ProfilerActivity
@@ -30,25 +29,25 @@ def __call__(self, model_path=None):
3029
model(**state_dict)
3130
except Exception as e:
3231
print(f"failed in running model:{e}")
33-
print(f"removing: {model_path}")
34-
shutil.rmtree(model_path)
32+
# print(f"removing: {model_path}")
33+
# shutil.rmtree(model_path)
3534
return False
3635
# try to compile the model
3736
try:
3837
compiled_model = torch.compile(model)
3938
except Exception as e:
4039
print(f"failed in compiling model:{e}")
41-
print(f"removing: {model_path}")
42-
shutil.rmtree(model_path)
40+
# print(f"removing: {model_path}")
41+
# shutil.rmtree(model_path)
4342
return False
4443
compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
4544
if compiled_num_of_kernels == 1:
4645
print(model_path, "can be fully integrated")
4746
return True
4847
else:
49-
print(f"{model_path} can not be fully integrated, to be removed")
50-
print(f"removing: {model_path}")
51-
shutil.rmtree(model_path)
48+
print(f"{model_path} can not be fully integrated")
49+
# print(f"removing: {model_path}")
50+
# shutil.rmtree(model_path)
5251
return False
5352

5453

@@ -84,7 +83,7 @@ def count_kernels(model, sample_inputs) -> int:
8483
record_shapes=True,
8584
) as prof:
8685
with record_function("model_inference"):
87-
output = model(**sample_inputs)
86+
_ = model(**sample_inputs)
8887
events = prof.key_averages()
8988

9089
total_count = 0
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
248d46ebcf5bc02d3e72953ea430b5e18175b0419dbdbcd2479202497f58319d
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"framework": "torch",
3+
"num_devices_required": 1,
4+
"num_nodes_required": 1,
5+
"source": "timm",
6+
"heuristic_tag": "computer_vision"
7+
}

samples/timm/resnet18/input_meta.py

Whitespace-only changes.
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
from sympy import Symbol
2+
3+
S0 = Symbol("S0")
4+
S1 = Symbol("S1")
5+
6+
dynamic_dim_constraint_symbols = [S0, S1]
7+
8+
dynamic_dim_constraint_symbol2example_value = {S0: 1, S1: 224}
9+
10+
dynamic_dim_constraint_relations = []
11+
12+
dynamic_dim_constraint_input_shapes = [
13+
([64], "L_self_modules_bn1_buffers_running_mean_"),
14+
([64], "L_self_modules_bn1_buffers_running_var_"),
15+
([64], "L_self_modules_bn1_parameters_bias_"),
16+
([64], "L_self_modules_bn1_parameters_weight_"),
17+
([64, 3, 7, 7], "L_self_modules_conv1_parameters_weight_"),
18+
([1000], "L_self_modules_fc_parameters_bias_"),
19+
([1000, 512], "L_self_modules_fc_parameters_weight_"),
20+
([64], "L_self_modules_layer1_modules_0_modules_bn1_buffers_running_mean_"),
21+
([64], "L_self_modules_layer1_modules_0_modules_bn1_buffers_running_var_"),
22+
([64], "L_self_modules_layer1_modules_0_modules_bn1_parameters_bias_"),
23+
([64], "L_self_modules_layer1_modules_0_modules_bn1_parameters_weight_"),
24+
([64], "L_self_modules_layer1_modules_0_modules_bn2_buffers_running_mean_"),
25+
([64], "L_self_modules_layer1_modules_0_modules_bn2_buffers_running_var_"),
26+
([64], "L_self_modules_layer1_modules_0_modules_bn2_parameters_bias_"),
27+
([64], "L_self_modules_layer1_modules_0_modules_bn2_parameters_weight_"),
28+
(
29+
[64, 64, 3, 3],
30+
"L_self_modules_layer1_modules_0_modules_conv1_parameters_weight_",
31+
),
32+
(
33+
[64, 64, 3, 3],
34+
"L_self_modules_layer1_modules_0_modules_conv2_parameters_weight_",
35+
),
36+
([64], "L_self_modules_layer1_modules_1_modules_bn1_buffers_running_mean_"),
37+
([64], "L_self_modules_layer1_modules_1_modules_bn1_buffers_running_var_"),
38+
([64], "L_self_modules_layer1_modules_1_modules_bn1_parameters_bias_"),
39+
([64], "L_self_modules_layer1_modules_1_modules_bn1_parameters_weight_"),
40+
([64], "L_self_modules_layer1_modules_1_modules_bn2_buffers_running_mean_"),
41+
([64], "L_self_modules_layer1_modules_1_modules_bn2_buffers_running_var_"),
42+
([64], "L_self_modules_layer1_modules_1_modules_bn2_parameters_bias_"),
43+
([64], "L_self_modules_layer1_modules_1_modules_bn2_parameters_weight_"),
44+
(
45+
[64, 64, 3, 3],
46+
"L_self_modules_layer1_modules_1_modules_conv1_parameters_weight_",
47+
),
48+
(
49+
[64, 64, 3, 3],
50+
"L_self_modules_layer1_modules_1_modules_conv2_parameters_weight_",
51+
),
52+
([128], "L_self_modules_layer2_modules_0_modules_bn1_buffers_running_mean_"),
53+
([128], "L_self_modules_layer2_modules_0_modules_bn1_buffers_running_var_"),
54+
([128], "L_self_modules_layer2_modules_0_modules_bn1_parameters_bias_"),
55+
([128], "L_self_modules_layer2_modules_0_modules_bn1_parameters_weight_"),
56+
([128], "L_self_modules_layer2_modules_0_modules_bn2_buffers_running_mean_"),
57+
([128], "L_self_modules_layer2_modules_0_modules_bn2_buffers_running_var_"),
58+
([128], "L_self_modules_layer2_modules_0_modules_bn2_parameters_bias_"),
59+
([128], "L_self_modules_layer2_modules_0_modules_bn2_parameters_weight_"),
60+
(
61+
[128, 64, 3, 3],
62+
"L_self_modules_layer2_modules_0_modules_conv1_parameters_weight_",
63+
),
64+
(
65+
[128, 128, 3, 3],
66+
"L_self_modules_layer2_modules_0_modules_conv2_parameters_weight_",
67+
),
68+
(
69+
[128, 64, 1, 1],
70+
"L_self_modules_layer2_modules_0_modules_downsample_modules_0_parameters_weight_",
71+
),
72+
(
73+
[128],
74+
"L_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_mean_",
75+
),
76+
(
77+
[128],
78+
"L_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_var_",
79+
),
80+
(
81+
[128],
82+
"L_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_bias_",
83+
),
84+
(
85+
[128],
86+
"L_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_weight_",
87+
),
88+
([128], "L_self_modules_layer2_modules_1_modules_bn1_buffers_running_mean_"),
89+
([128], "L_self_modules_layer2_modules_1_modules_bn1_buffers_running_var_"),
90+
([128], "L_self_modules_layer2_modules_1_modules_bn1_parameters_bias_"),
91+
([128], "L_self_modules_layer2_modules_1_modules_bn1_parameters_weight_"),
92+
([128], "L_self_modules_layer2_modules_1_modules_bn2_buffers_running_mean_"),
93+
([128], "L_self_modules_layer2_modules_1_modules_bn2_buffers_running_var_"),
94+
([128], "L_self_modules_layer2_modules_1_modules_bn2_parameters_bias_"),
95+
([128], "L_self_modules_layer2_modules_1_modules_bn2_parameters_weight_"),
96+
(
97+
[128, 128, 3, 3],
98+
"L_self_modules_layer2_modules_1_modules_conv1_parameters_weight_",
99+
),
100+
(
101+
[128, 128, 3, 3],
102+
"L_self_modules_layer2_modules_1_modules_conv2_parameters_weight_",
103+
),
104+
([256], "L_self_modules_layer3_modules_0_modules_bn1_buffers_running_mean_"),
105+
([256], "L_self_modules_layer3_modules_0_modules_bn1_buffers_running_var_"),
106+
([256], "L_self_modules_layer3_modules_0_modules_bn1_parameters_bias_"),
107+
([256], "L_self_modules_layer3_modules_0_modules_bn1_parameters_weight_"),
108+
([256], "L_self_modules_layer3_modules_0_modules_bn2_buffers_running_mean_"),
109+
([256], "L_self_modules_layer3_modules_0_modules_bn2_buffers_running_var_"),
110+
([256], "L_self_modules_layer3_modules_0_modules_bn2_parameters_bias_"),
111+
([256], "L_self_modules_layer3_modules_0_modules_bn2_parameters_weight_"),
112+
(
113+
[256, 128, 3, 3],
114+
"L_self_modules_layer3_modules_0_modules_conv1_parameters_weight_",
115+
),
116+
(
117+
[256, 256, 3, 3],
118+
"L_self_modules_layer3_modules_0_modules_conv2_parameters_weight_",
119+
),
120+
(
121+
[256, 128, 1, 1],
122+
"L_self_modules_layer3_modules_0_modules_downsample_modules_0_parameters_weight_",
123+
),
124+
(
125+
[256],
126+
"L_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_mean_",
127+
),
128+
(
129+
[256],
130+
"L_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_var_",
131+
),
132+
(
133+
[256],
134+
"L_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_bias_",
135+
),
136+
(
137+
[256],
138+
"L_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_weight_",
139+
),
140+
([256], "L_self_modules_layer3_modules_1_modules_bn1_buffers_running_mean_"),
141+
([256], "L_self_modules_layer3_modules_1_modules_bn1_buffers_running_var_"),
142+
([256], "L_self_modules_layer3_modules_1_modules_bn1_parameters_bias_"),
143+
([256], "L_self_modules_layer3_modules_1_modules_bn1_parameters_weight_"),
144+
([256], "L_self_modules_layer3_modules_1_modules_bn2_buffers_running_mean_"),
145+
([256], "L_self_modules_layer3_modules_1_modules_bn2_buffers_running_var_"),
146+
([256], "L_self_modules_layer3_modules_1_modules_bn2_parameters_bias_"),
147+
([256], "L_self_modules_layer3_modules_1_modules_bn2_parameters_weight_"),
148+
(
149+
[256, 256, 3, 3],
150+
"L_self_modules_layer3_modules_1_modules_conv1_parameters_weight_",
151+
),
152+
(
153+
[256, 256, 3, 3],
154+
"L_self_modules_layer3_modules_1_modules_conv2_parameters_weight_",
155+
),
156+
([512], "L_self_modules_layer4_modules_0_modules_bn1_buffers_running_mean_"),
157+
([512], "L_self_modules_layer4_modules_0_modules_bn1_buffers_running_var_"),
158+
([512], "L_self_modules_layer4_modules_0_modules_bn1_parameters_bias_"),
159+
([512], "L_self_modules_layer4_modules_0_modules_bn1_parameters_weight_"),
160+
([512], "L_self_modules_layer4_modules_0_modules_bn2_buffers_running_mean_"),
161+
([512], "L_self_modules_layer4_modules_0_modules_bn2_buffers_running_var_"),
162+
([512], "L_self_modules_layer4_modules_0_modules_bn2_parameters_bias_"),
163+
([512], "L_self_modules_layer4_modules_0_modules_bn2_parameters_weight_"),
164+
(
165+
[512, 256, 3, 3],
166+
"L_self_modules_layer4_modules_0_modules_conv1_parameters_weight_",
167+
),
168+
(
169+
[512, 512, 3, 3],
170+
"L_self_modules_layer4_modules_0_modules_conv2_parameters_weight_",
171+
),
172+
(
173+
[512, 256, 1, 1],
174+
"L_self_modules_layer4_modules_0_modules_downsample_modules_0_parameters_weight_",
175+
),
176+
(
177+
[512],
178+
"L_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_mean_",
179+
),
180+
(
181+
[512],
182+
"L_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_var_",
183+
),
184+
(
185+
[512],
186+
"L_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_bias_",
187+
),
188+
(
189+
[512],
190+
"L_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_weight_",
191+
),
192+
([512], "L_self_modules_layer4_modules_1_modules_bn1_buffers_running_mean_"),
193+
([512], "L_self_modules_layer4_modules_1_modules_bn1_buffers_running_var_"),
194+
([512], "L_self_modules_layer4_modules_1_modules_bn1_parameters_bias_"),
195+
([512], "L_self_modules_layer4_modules_1_modules_bn1_parameters_weight_"),
196+
([512], "L_self_modules_layer4_modules_1_modules_bn2_buffers_running_mean_"),
197+
([512], "L_self_modules_layer4_modules_1_modules_bn2_buffers_running_var_"),
198+
([512], "L_self_modules_layer4_modules_1_modules_bn2_parameters_bias_"),
199+
([512], "L_self_modules_layer4_modules_1_modules_bn2_parameters_weight_"),
200+
(
201+
[512, 512, 3, 3],
202+
"L_self_modules_layer4_modules_1_modules_conv1_parameters_weight_",
203+
),
204+
(
205+
[512, 512, 3, 3],
206+
"L_self_modules_layer4_modules_1_modules_conv2_parameters_weight_",
207+
),
208+
([S0, 3, S1, S1], "L_x_"),
209+
([], "s1"),
210+
]

0 commit comments

Comments (0)