Skip to content

Commit 201d5b9

Browse files
lixinqi and JewelRoam authored
Improve efficiency of test/typical_sequence_decomposer_test.sh (#438)
* debug_typical_sequence
* support model-path-prefix in splitting positions
* fix
* fix
* Improve efficiency of test/typical_sequence_decomposer_test.sh

---------

Co-authored-by: JewelRoam <[email protected]>
1 parent 51e558f commit 201d5b9

File tree

6 files changed

+61
-64
lines changed

6 files changed

+61
-64
lines changed
Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,20 @@
1+
samples/timm/convnextv2_base.fcmae_ft_in1k
2+
samples/timm/hgnet_tiny.paddle_in1k
3+
samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k
4+
samples/timm/regnety_080_tv.tv2_in1k
5+
samples/timm/res2net50_14w_8s.in1k
6+
samples/torchaudio/wavlm_base
7+
samples/torchgeometric/RECT_L
8+
samples/torchvision/vgg16_bn
9+
samples/transformers-auto-model/bge-small-en-v1.5
10+
samples/transformers-auto-model/distilbert_distilbert-base-multilingual-cased
11+
samples/transformers-auto-model/OFA-Sys_chinese-clip-vit-large-patch14
12+
samples/transformers-auto-model/opus-mt-ase-es
13+
samples/transformers-auto-model/opus-mt-en-gv
14+
samples/transformers-auto-model/opus-mt-en-phi
15+
samples/transformers-auto-model/opus-mt-en-sal
16+
samples/transformers-auto-model/opus-mt-en-tw
17+
samples/transformers-auto-model/opus-mt-fi-niu
18+
samples/transformers-auto-model/opus-mt-tc-bible-big-deu_eng_fra_por_spa-bat
19+
samples/transformers-auto-model/opus-mt-tc-bible-big-gmw-deu_eng_fra_por_spa
20+
samples/ultralytics/yolov3-tinyu

graph_net/test/naive_decomposer_and_post_extract_process_test.sh

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
#!/bin/bash
2-
# bash graph_net/test/naive_decomposer_and_post_extract_process_test.sh
32

43
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
54
os.path.dirname(graph_net.__file__))")

graph_net/test/typical_sequence_decomposer_test.sh

100644 → 100755
Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,16 @@
22

33
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
44
DECOMPOSE_PATH=/tmp/decompose_workspace
5+
# DECOMPOSE_PATH=$GRAPH_NET_ROOT/decompose_test_level5_100
56

67
mkdir -p "$DECOMPOSE_PATH"
78

8-
model_list="$GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt"
9+
# model_list="$GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt"
10+
model_list="$GRAPH_NET_ROOT/graph_net/test/dev_model_list/validation_error_model_list.txt"
911

1012
python3 -m graph_net.torch.typical_sequence_split_points \
1113
--model-list "$model_list" \
14+
--model-path-prefix "$GRAPH_NET_ROOT" \
1215
--device "cuda" \
1316
--window-size 10 \
1417
--fold-policy default \
@@ -54,4 +57,4 @@ python3 -m graph_net.torch.test_compiler \
5457

5558
python3 -m graph_net.plot_ESt \
5659
--benchmark-path "$DECOMPOSE_PATH/validation.log" \
57-
--output-dir "$DECOMPOSE_PATH"
60+
--output-dir "$DECOMPOSE_PATH"

graph_net/torch/fx_graph_parse_util.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -122,19 +122,28 @@ def _rename_placeholder(name, pattern2replacement):
122122
return name
123123

124124

125-
def parse_sole_graph_module(module, inputs):
125+
def parse_sole_graph_module_without_varify(module, inputs):
126126
traced_module = None
127127
traced_sample_inputs = None
128128

129129
def my_backend(gm, sample_inputs):
130130
nonlocal traced_module
131-
traced_module = gm
132131
nonlocal traced_sample_inputs
132+
assert traced_module is None
133+
assert traced_sample_inputs is None
134+
traced_module = gm
133135
traced_sample_inputs = sample_inputs
134136
return gm.forward
135137

136138
torch.compile(module, backend=my_backend)(*inputs)
137139
assert traced_module is not None
140+
return traced_module, traced_sample_inputs
141+
142+
143+
def parse_sole_graph_module(module, inputs):
144+
traced_module, traced_sample_inputs = parse_sole_graph_module_without_varify(
145+
module, inputs
146+
)
138147

139148
def get_input_names_from_signature():
140149
return inspect.signature(module.forward).parameters

graph_net/torch/graph_decomposer.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -269,7 +269,6 @@ def forward(self, *args):
269269
if not self.extracted:
270270
if self.need_extract(self.submodule, args):
271271
self.builtin_extractor(self.submodule, args)
272-
self._post_extract_process()
273272
self.extracted = True
274273
return self.submodule(*args)
275274

graph_net/torch/typical_sequence_split_points.py

Lines changed: 25 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,11 @@
33
import os
44
from pathlib import Path
55
from typing import Any, Dict, List
6-
76
import torch
87
import torch.nn as nn
9-
import tempfile
10-
import graph_net.imp_util
11-
from graph_net.torch import utils as graph_utils
128
from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
9+
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
10+
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module_without_varify
1311

1412

1513
class TypicalSequenceExtractor:
@@ -28,9 +26,12 @@ def _extract_operators_from_graph(
2826

2927
if node.op == "call_module":
3028
target_name = type(named_modules[node.target]).__name__
31-
else:
29+
elif node.op == "call_method":
30+
target_name = f"Tensor.{node.target}"
31+
elif node.op == "call_function":
3232
target_name = getattr(node.target, "__name__", str(node.target))
33-
33+
else:
34+
raise NotImplementedError()
3435
operator_list.append(
3536
{
3637
"op_type": node.op,
@@ -48,39 +49,6 @@ def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor])
4849
return gm.forward
4950

5051

51-
class TypicalSequenceModelLoader:
52-
def load_class_from_file(self, model_path: str, device: str) -> Any:
53-
file_path = os.path.join(model_path, "model.py")
54-
55-
if not os.path.exists(file_path):
56-
raise FileNotFoundError(f"Model file not found: {file_path}")
57-
58-
with open(file_path, "r", encoding="utf-8") as f:
59-
model_code = f.read()
60-
model_code = graph_utils.modify_code_by_device(model_code, device)
61-
62-
with tempfile.NamedTemporaryFile(
63-
mode="w", suffix=".py", encoding="utf-8"
64-
) as temp_file:
65-
temp_file.write(model_code)
66-
module = graph_net.imp_util.load_module(temp_file.name)
67-
model_class = getattr(module, "GraphModule", None)
68-
69-
return model_class
70-
71-
def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]:
72-
inputs_params = graph_utils.load_converted_from_text(f"{model_path}")
73-
params = inputs_params["weight_info"]
74-
for tensor_meta in params.values():
75-
if hasattr(tensor_meta, "device"):
76-
tensor_meta.device = device
77-
input_dict = {
78-
k: graph_utils.replay_tensor(v).to(torch.device(device))
79-
for k, v in params.items()
80-
}
81-
return input_dict
82-
83-
8452
class SplitAnalyzer:
8553
def __init__(
8654
self, window_size: int = 10, fold_policy: str = "default", fold_times: int = 0
@@ -109,20 +77,11 @@ def _resolve_token_to_ops(
10977
def _extract_ops_via_compile(
11078
self, model_path: str, device: str = "cpu"
11179
) -> List[str]:
112-
loader = TypicalSequenceModelLoader()
113-
print(f"Loading model from {model_path} on {device}...")
114-
try:
115-
model_class = loader.load_class_from_file(model_path, device)
116-
model = model_class().to(torch.device(device))
117-
model.eval()
118-
input_dict = loader.get_input_dict(model_path, device)
119-
except Exception as e:
120-
print(f"Error loading/preparing model {model_path}: {e}")
121-
return []
122-
80+
print(f"extracting ops from {model_path}")
12381
extractor = TypicalSequenceExtractor()
124-
compiled_model = torch.compile(model, backend=extractor.extract_compiler)
125-
compiled_model(**input_dict)
82+
model, inputs = get_torch_module_and_inputs(model_path)
83+
compiled_model, _ = parse_sole_graph_module_without_varify(model, inputs)
84+
extractor.extract_compiler(compiled_model, inputs)
12685
ops_info = extractor.extract_node
12786

12887
return [op["target_name"] for op in ops_info]
@@ -150,11 +109,13 @@ def get_len(tid):
150109
get_len(sym_id)
151110
return token2len
152111

153-
def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
112+
def analyze(
113+
self, model_path_prefix: str, model_paths_file: str, device: str
114+
) -> Dict[str, Dict]:
154115
input_file = Path(model_paths_file)
155116

156117
with open(input_file, "r") as f:
157-
model_paths = [
118+
rel_model_paths = [
158119
Path(line.strip())
159120
for line in f
160121
if line.strip() and not line.startswith("#")
@@ -163,15 +124,15 @@ def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
163124
inputs_seqs = []
164125
valid_models = []
165126

166-
for p in model_paths:
167-
seq = self._extract_ops_via_compile(str(p), device)
127+
for p in rel_model_paths:
128+
model_full_path = os.path.join(model_path_prefix, p)
129+
seq = self._extract_ops_via_compile(model_full_path, device)
168130
if seq:
169131
inputs_seqs.append(seq)
170132
valid_models.append((p.name, p))
171133

172134
if not inputs_seqs:
173135
return {}
174-
175136
rp_parser = RpExprParser(
176137
window_size=self.window_size,
177138
fold_policy=self.fold_policy,
@@ -264,7 +225,7 @@ def main(args):
264225
fold_policy=args.fold_policy,
265226
fold_times=args.fold_times,
266227
)
267-
results = analyzer.analyze(args.model_list, args.device)
228+
results = analyzer.analyze(args.model_path_prefix, args.model_list, args.device)
268229
if args.output_json:
269230
with open(args.output_json, "w") as f:
270231
json.dump(results, f, indent=4)
@@ -280,6 +241,12 @@ def main(args):
280241
required=True,
281242
help="Path to a text file containing paths to models (one per line).",
282243
)
244+
parser.add_argument(
245+
"--model-path-prefix",
246+
type=str,
247+
default="./",
248+
help="Prefix to add to each model path in the list.",
249+
)
283250
parser.add_argument(
284251
"--device",
285252
type=str,

0 commit comments

Comments
 (0)