Commit 75c3e61

post extract process feature
1 parent 00d5b4b commit 75c3e61

Showing 4 changed files with 212 additions and 35 deletions.
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
#!/bin/bash
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")

# Add the project root to the Python path
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(graph_net.__file__))")

# input model path
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME

decorator_config_json_str=$(cat <<EOF
{
  "decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
  "decorator_config": {
    "name": "$MODEL_NAME",
    "custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
    "custom_extractor_config": {
      "output_dir": "/work/.BCloud/countkernels/",
      "split_positions": [8, 16, 32],
      "group_head_and_tail": true,
      "filter_path": "$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
      "filter_config": {},
      "post_extract_process_path": "$GRAPH_NET_ROOT/torch/post_extract_process.py",
      "post_extract_process_config": {
        "decorator_path": "$GRAPH_NET_ROOT/torch/shape_prop.py",
        "decorator_class_name": "ShapePropagate"
      }
    }
  }
}
EOF
)
DECORATOR_CONFIG=$(echo "$decorator_config_json_str" | base64 -w 0)

python3 -m graph_net.torch.run_model --model-path "$GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES" --decorator-config="$DECORATOR_CONFIG"
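For reference, the --decorator-config value is just the JSON object above encoded as base64. A minimal, stdlib-only sketch of the round trip (the config values here are shortened placeholders; the decoding mirrors _convert_to_dict in the new post-extract-process module added below):

import base64
import json

config = {
    "decorator_path": "graph_net/torch/extractor.py",
    "decorator_config": {"name": "resnet18"},
}
encoded = base64.b64encode(json.dumps(config).encode("utf-8")).decode("utf-8")

decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
assert decoded == config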

graph_net/torch/naive_graph_decomposer.py

Lines changed: 29 additions & 0 deletions
@@ -32,6 +32,8 @@ def make_config(
         output_dir="./tmp/naive_decomposer_dir",
         filter_path=None,
         filter_config=None,
+        post_extract_process_path=None,
+        post_extract_process_config=None,
     ):
         for pos in split_positions:
             assert isinstance(
@@ -44,6 +46,8 @@ def make_config(
             "output_dir": output_dir,
             "filter_path": filter_path,
             "filter_config": filter_config if filter_config is not None else {},
+            "post_extract_process_path": post_extract_process_path,
+            "post_extract_process_config": post_extract_process_config,
         }

     def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
@@ -71,6 +75,7 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
         self.seq_no = seq_no
         self.extracted = False
         name = f"{parent_graph_extractor.name}_{self.seq_no}"
+        self.modelname = name
         self.builtin_extractor = BuiltinGraphExtractor(
             name=name,
             dynamic=False,
@@ -79,21 +84,45 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
             workspace_path=self.parent_graph_extractor.config["output_dir"],
         )
         self.filter = self.make_filter(self.parent_graph_extractor.config)
+        self.post_extract_process = self.make_post_extract_process(
+            self.parent_graph_extractor.config
+        )

     def forward(self, *args):
         if not self.extracted:
             if self.need_extract(self.submodule, args):
                 self.builtin_extractor(self.submodule, args)
+                self.get_post_extract_process(self.submodule, args)
             self.extracted = True
         return self.submodule(*args)

     def need_extract(self, gm, sample_inputs):
         if self.filter is None:
             return True
         return self.filter(gm, sample_inputs)

+    def get_post_extract_process(self, gm, sample_inputs):
+        # Run the optional post-extract hook on the directory that the builtin
+        # extractor just wrote for this submodule.
+        if self.post_extract_process is None:
+            return None
+        model_path = os.path.join(
+            self.parent_graph_extractor.config["output_dir"], self.modelname
+        )
+        return self.post_extract_process(model_path)
+
     def make_filter(self, config):
         if config["filter_path"] is None:
             return None
         module = imp_util.load_module(config["filter_path"])
         return module.GraphFilter(config["filter_config"])
+
+    def make_post_extract_process(self, config):
+        if config["post_extract_process_path"] is None:
+            return None
+        module = imp_util.load_module(config["post_extract_process_path"])
+        return module.PostExtractProcess(config["post_extract_process_config"])
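The new hook only assumes a small plug-in contract: the module referenced by post_extract_process_path must expose a class named PostExtractProcess that is constructed with post_extract_process_config and later called with the directory of one extracted subgraph. A minimal illustrative skeleton under that assumption (the body is a placeholder, not the shipped implementation):

import os


class PostExtractProcess:
    def __init__(self, config):
        self.config = config

    def __call__(self, model_path=None):
        # model_path is output_dir/<parent_name>_<seq_no>, written by the
        # builtin extractor before this hook runs.
        if model_path is None or not os.path.isdir(model_path):
            return False
        print(f"post-processing extracted subgraph at {model_path}")
        return True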
Lines changed: 0 additions & 35 deletions
@@ -1,43 +1,8 @@
-from torch.profiler import profile, record_function, ProfilerActivity
-
-
 class GraphFilter:
     def __init__(self, config):
         self.config = config

     def __call__(self, gm, sample_inputs):
         print("GraphFilter")
         # print(f"GraphFilter\n{gm.code}")
-        kernels_num = count_kernels(gm, sample_inputs)
-        print("number of kernels is ", kernels_num)
         return True
-
-
-def count_kernels(model, tensors) -> int:
-    """
-    Count the number of CUDA kernel launches performed during a model's forward pass.
-
-    Args:
-        model (torch.nn.Module)
-        tensors
-
-    Returns:
-        int: The number of kernel launches recorded by PyTorch profiler.
-
-    Behavior:
-        - Runs the model once inside a PyTorch profiler context.
-        - Identifies the event with key = 'cudaLaunchKernel', which corresponds
-          to the number of CUDA kernel launches.
-    """
-    # Use PyTorch Profiler
-    with profile(
-        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
-        record_shapes=True,
-    ) as prof:
-        with record_function("model_inference"):
-            output = model(*tensors)
-    # print(prof.key_averages().table())  # print a table of profiler results
-    events = prof.key_averages()
-    for e in events:
-        if e.key == "cudaLaunchKernel":
-            return e.count
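For context, the removed count_kernels helper followed the standard profiler pattern sketched below; this usage example (toy module and CUDA device assumed, both illustrative) shows how cudaLaunchKernel events are counted, the same idea the new compile_and_count_kernels applies to the compiled module:

import torch
from torch.profiler import ProfilerActivity, profile, record_function

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).cuda()
inputs = (torch.randn(4, 16, device="cuda"),)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("model_inference"):
        model(*inputs)

# Each cudaLaunchKernel event corresponds to one CUDA kernel launch.
launches = sum(e.count for e in prof.key_averages() if e.key == "cudaLaunchKernel")
print("cudaLaunchKernel events:", launches)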
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
from graph_net.torch import utils
import argparse
import importlib.util
import inspect
import shutil
import torch
import logging
from pathlib import Path
from typing import Type, Any
import sys
import json
import base64
from contextlib import contextmanager

from torch.profiler import profile, record_function, ProfilerActivity

class PostExtractProcess:
    def __init__(self, config):
        self.config = config

    def __call__(self, model_path=None):
        print("PostExtractProcess")
        if model_path is None:
            return False

        # Re-encode this hook's own config so it can be consumed by
        # _get_decorator below (same format as run_model's --decorator-config).
        json_string = json.dumps(self.config)
        json_bytes = json_string.encode("utf-8")
        b64_encoded_bytes = base64.b64encode(json_bytes)
        decorator_config = b64_encoded_bytes.decode("utf-8")
        args = argparse.Namespace(decorator_config=decorator_config)

        # model
        model_class = load_class_from_file(
            f"{model_path}/model.py", class_name="GraphModule"
        )
        assert model_class is not None
        model = model_class()
        print(f"{model_path=}")

        model = _get_decorator(args)(model)

        inputs_params = utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}

        compiled_num_of_kernels = compile_and_count_kernels(model, state_dict)
        print("compiled: num_of_kernels =", compiled_num_of_kernels)
        if compiled_num_of_kernels == 1:
            print("Graph is fully fusionable")
            return True
        else:
            print(f"Graph is not fully fusionable ({compiled_num_of_kernels} kernels)")
            shutil.rmtree(model_path)
            return False

def _convert_to_dict(config_str):
    if config_str is None:
        return {}
    config_str = base64.b64decode(config_str).decode("utf-8")
    config = json.loads(config_str)
    assert isinstance(config, dict), f"config should be a dict. {config_str=}"
    return config


def _get_decorator(args):
    if args.decorator_config is None:
        return lambda model: model
    decorator_config = _convert_to_dict(args.decorator_config)
    if "decorator_path" not in decorator_config:
        return lambda model: model
    class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
    decorator_class = load_class_from_file(
        decorator_config["decorator_path"],
        class_name=class_name,
    )
    return decorator_class(decorator_config.get("decorator_config", {}))


def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    unnamed = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(unnamed)
    model_class = getattr(unnamed, class_name, None)
    return model_class

def compile_and_count_kernels(gm, sample_inputs) -> int:
    """
    Compile a graph module and count the CUDA kernel launches in one forward pass.

    Args:
        gm: the graph module to check
        sample_inputs: dict of tensors passed as keyword arguments

    Returns:
        int: the number of CUDA kernel launches, or -1 if torch.compile did not
        take effect.

    Behavior:
        - Compiles the model with torch.compile and warms it up once.
        - Runs the compiled model once inside a PyTorch profiler context.
        - Counts the 'cuLaunchKernel' events, which correspond to CUDA kernel
          launches.
    """
    gm.eval()
    compiled_gm = torch.compile(gm)
    _ = compiled_gm(**sample_inputs)  # warm-up run to trigger compilation

    # Use the PyTorch profiler on the already-compiled model
    with profile(
        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
        record_shapes=True,
    ) as prof:
        with record_function("model_inference"):
            output = compiled_gm(**sample_inputs)
    print(prof.key_averages().table())  # print a table of profiler results
    events = prof.key_averages()
    compile_worked = any(e.key == "TorchDynamo Cache Lookup" for e in events)
    if not compile_worked:
        print("Compile failed")
        return -1
    for e in events:
        if e.key == "cuLaunchKernel":
            return e.count
    return 0  # no kernel launches recorded
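End to end, the decomposer instantiates this class once per extracted subgraph and calls it with the subgraph directory. A hedged sketch of invoking the same check by hand; the import path and the directory name are assumptions inferred from the shell script and decomposer code above:

from graph_net.torch.post_extract_process import PostExtractProcess  # assumed module path

config = {
    "decorator_path": "graph_net/torch/shape_prop.py",  # shell config uses an absolute path
    "decorator_class_name": "ShapePropagate",
}
check = PostExtractProcess(config)
kept = check("/work/.BCloud/countkernels/resnet18_0")  # hypothetical subgraph directory
print("kept" if kept else "removed (not fully fusionable)")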
