
Commit 9ce981b
Commit message: merge
2 parents: 55929e1 + 201d5b9

6 files changed: +35 additions, -77 deletions
Lines changed: 12 additions & 19 deletions
@@ -1,27 +1,20 @@
-samples/ultralytics/yolov3-tinyu
-samples/torchgeometric/RECT_L
-samples/transformers-auto-model/opus-mt-en-gv
 samples/timm/convnextv2_base.fcmae_ft_in1k
-samples/torchvision/vgg16_bn
-samples/timm/regnety_080_tv.tv2_in1k
+samples/timm/hgnet_tiny.paddle_in1k
 samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k
+samples/timm/regnety_080_tv.tv2_in1k
+samples/timm/res2net50_14w_8s.in1k
+samples/torchaudio/wavlm_base
+samples/torchgeometric/RECT_L
+samples/torchvision/vgg16_bn
+samples/transformers-auto-model/bge-small-en-v1.5
 samples/transformers-auto-model/distilbert_distilbert-base-multilingual-cased
 samples/transformers-auto-model/OFA-Sys_chinese-clip-vit-large-patch14
-samples/transformers-auto-model/opus-mt-tc-bible-big-deu_eng_fra_por_spa-bat
-samples/transformers-auto-model/opus-mt-en-tw
-samples/timm/hgnet_tiny.paddle_in1k
 samples/transformers-auto-model/opus-mt-ase-es
-samples/timm/resnetv2_18.ra4e3600r224_in1k
-samples/transformers-auto-model/TinyLlama_TinyLlama-1.1B-Chat-v0.4
-samples/timm/resnetaa50d.din12k
-samples/transformers-auto-model/bge-small-en-v1.5
-samples/timm/res2net50_14w_8s.in1k
-samples/torchaudio/wavlm_base
-samples/transformers-auto-model/opus-mt-en-sal
+samples/transformers-auto-model/opus-mt-en-gv
 samples/transformers-auto-model/opus-mt-en-phi
+samples/transformers-auto-model/opus-mt-en-sal
+samples/transformers-auto-model/opus-mt-en-tw
 samples/transformers-auto-model/opus-mt-fi-niu
+samples/transformers-auto-model/opus-mt-tc-bible-big-deu_eng_fra_por_spa-bat
 samples/transformers-auto-model/opus-mt-tc-bible-big-gmw-deu_eng_fra_por_spa
-samples/transformers-auto-model/opus-mt-NORTHEU-NORTHEU
-samples/transformers-auto-model/sentence-transformers_paraphrase-distilroberta-base-v1
-samples/transformers-auto-model/TrustSafeAI_RADARVicuna7B
-
+samples/ultralytics/yolov3-tinyu

graph_net/test/naive_decomposer_and_post_extract_process_test.sh

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 #!/bin/bash
-# bash graph_net/test/naive_decomposer_and_post_extract_process_test.sh
 
 GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
     os.path.dirname(graph_net.__file__))")

graph_net/test/typical_sequence_decomposer_test.sh

File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion
@@ -57,4 +57,4 @@ python3 -m graph_net.torch.test_compiler \
 
 python3 -m graph_net.plot_ESt \
   --benchmark-path "$DECOMPOSE_PATH/validation.log" \
-  --output-dir "$DECOMPOSE_PATH"
+  --output-dir "$DECOMPOSE_PATH"

graph_net/torch/fx_graph_parse_util.py

Lines changed: 11 additions & 2 deletions
@@ -122,19 +122,28 @@ def _rename_placeholder(name, pattern2replacement):
     return name
 
 
-def parse_sole_graph_module(module, inputs):
+def parse_sole_graph_module_without_varify(module, inputs):
     traced_module = None
     traced_sample_inputs = None
 
     def my_backend(gm, sample_inputs):
         nonlocal traced_module
-        traced_module = gm
         nonlocal traced_sample_inputs
+        assert traced_module is None
+        assert traced_sample_inputs is None
+        traced_module = gm
         traced_sample_inputs = sample_inputs
         return gm.forward
 
     torch.compile(module, backend=my_backend)(*inputs)
     assert traced_module is not None
+    return traced_module, traced_sample_inputs
+
+
+def parse_sole_graph_module(module, inputs):
+    traced_module, traced_sample_inputs = parse_sole_graph_module_without_varify(
+        module, inputs
+    )
 
     def get_input_names_from_signature():
         return inspect.signature(module.forward).parameters
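
For context, this refactor splits raw Dynamo capture out of parse_sole_graph_module so callers can obtain the traced GraphModule without the follow-up processing that parse_sole_graph_module adds (not shown in full in this hunk). A minimal usage sketch, assuming only the signature visible in this diff; the toy module and inputs below are illustrative, not taken from the repository:

import torch
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module_without_varify

# Illustrative eager module and inputs; anything Dynamo can capture as a single graph works.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
inputs = (torch.randn(2, 8),)

# Returns the single FX GraphModule captured via torch.compile plus the sample
# inputs Dynamo handed to the custom backend, skipping the extra steps that
# parse_sole_graph_module performs afterwards.
traced_module, traced_sample_inputs = parse_sole_graph_module_without_varify(model, inputs)
print(traced_module.graph)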

graph_net/torch/graph_decomposer.py

Lines changed: 0 additions & 1 deletion
@@ -262,7 +262,6 @@ def forward(self, *args):
         if not self.extracted:
             if self.need_extract(self.submodule, args):
                 self.builtin_extractor(self.submodule, args)
-                self._post_extract_process()
             self.extracted = True
         return self.submodule(*args)
 

graph_net/torch/typical_sequence_split_points.py

Lines changed: 11 additions & 53 deletions
@@ -3,13 +3,11 @@
 import os
 from pathlib import Path
 from typing import Any, Dict, List
-
 import torch
 import torch.nn as nn
-import tempfile
-import graph_net.imp_util
-from graph_net.torch import utils as graph_utils
 from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
+from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
+from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module_without_varify
 
 
 class TypicalSequenceExtractor:
@@ -28,9 +26,12 @@ def _extract_operators_from_graph(
 
             if node.op == "call_module":
                 target_name = type(named_modules[node.target]).__name__
-            else:
+            elif node.op == "call_method":
+                target_name = f"Tensor.{node.target}"
+            elif node.op == "call_function":
                 target_name = getattr(node.target, "__name__", str(node.target))
-
+            else:
+                raise NotImplementedError()
             operator_list.append(
                 {
                     "op_type": node.op,
@@ -48,39 +49,6 @@ def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor])
         return gm.forward
 
 
-class TypicalSequenceModelLoader:
-    def load_class_from_file(self, model_path: str, device: str) -> Any:
-        file_path = os.path.join(model_path, "model.py")
-
-        if not os.path.exists(file_path):
-            raise FileNotFoundError(f"Model file not found: {file_path}")
-
-        with open(file_path, "r", encoding="utf-8") as f:
-            model_code = f.read()
-        model_code = graph_utils.modify_code_by_device(model_code, device)
-
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".py", encoding="utf-8"
-        ) as temp_file:
-            temp_file.write(model_code)
-            module = graph_net.imp_util.load_module(temp_file.name)
-            model_class = getattr(module, "GraphModule", None)
-
-        return model_class
-
-    def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]:
-        inputs_params = graph_utils.load_converted_from_text(f"{model_path}")
-        params = inputs_params["weight_info"]
-        for tensor_meta in params.values():
-            if hasattr(tensor_meta, "device"):
-                tensor_meta.device = device
-        input_dict = {
-            k: graph_utils.replay_tensor(v).to(torch.device(device))
-            for k, v in params.items()
-        }
-        return input_dict
-
-
 class SplitAnalyzer:
     def __init__(
         self, window_size: int = 10, fold_policy: str = "default", fold_times: int = 0
@@ -109,20 +77,11 @@ def _resolve_token_to_ops(
 
     def _extract_ops_via_compile(
         self, model_path: str, device: str = "cpu"
     ) -> List[str]:
-        loader = TypicalSequenceModelLoader()
-        print(f"Loading model from {model_path} on {device}...")
-        try:
-            model_class = loader.load_class_from_file(model_path, device)
-            model = model_class().to(torch.device(device))
-            model.eval()
-            input_dict = loader.get_input_dict(model_path, device)
-        except Exception as e:
-            print(f"Error loading/preparing model {model_path}: {e}")
-            return []
-
+        print(f"extracting ops from {model_path}")
         extractor = TypicalSequenceExtractor()
-        compiled_model = torch.compile(model, backend=extractor.extract_compiler)
-        compiled_model(**input_dict)
+        model, inputs = get_torch_module_and_inputs(model_path)
+        compiled_model, _ = parse_sole_graph_module_without_varify(model, inputs)
+        extractor.extract_compiler(compiled_model, inputs)
         ops_info = extractor.extract_node
 
         return [op["target_name"] for op in ops_info]
@@ -174,7 +133,6 @@ def analyze(
 
         if not inputs_seqs:
             return {}
-
        rp_parser = RpExprParser(
             window_size=self.window_size,
             fold_policy=self.fold_policy,
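
For context on the node.op dispatch added in _extract_operators_from_graph above: in torch.fx, a call_method node stores the method name as a plain string, while a call_function node stores the callable itself and therefore exposes __name__; that is why the former is rendered as Tensor.<name> and the latter via getattr(node.target, "__name__", ...). A small, hedged illustration with a toy module, not taken from this repository:

import torch
import torch.fx


class Demo(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)    # traced as a call_function node (target is the callable torch.relu)
        return y.flatten(1)  # traced as a call_method node (target is the string "flatten")


# Expected node ops in order: placeholder, call_function, call_method, output
for node in torch.fx.symbolic_trace(Demo()).graph.nodes:
    print(node.op, node.target)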
