Commit b59dc33

add backend
1 parent af92b86 commit b59dc33

2 files changed: +346 −2 lines changed
graph_net/torch/backend/range_decomposer_backend.py

Lines changed: 337 additions & 0 deletions
@@ -0,0 +1,337 @@
import base64
import importlib.util
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List

import torch
import torch.nn as nn

import graph_net
from graph_net.torch import utils as graph_utils
from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser


def encode_config(config: Dict[str, Any]) -> str:
    # Serialize the decorator config as base64-encoded JSON so it can be
    # passed safely as a single command-line argument.
    json_str = json.dumps(config)
    return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")


class GraphExtractor:
    """torch.compile backend that records the operator sequence of the FX graph."""

    def __init__(self):
        self.extract_node = []

    def _extract_operators_from_graph(
        self, gm: nn.Module, example_inputs: List[torch.Tensor] = None
    ) -> List[Dict[str, Any]]:
        operator_list = []
        named_modules = dict(gm.named_modules())

        for node in gm.graph.nodes:
            if node.op in ("call_method", "call_function", "call_module"):
                target_name = str(node.target)

                if node.op == "call_module":
                    module_instance = named_modules.get(node.target)
                    if module_instance is not None:
                        target_name = type(module_instance).__name__
                elif node.op == "call_function":
                    if isinstance(node.target, Callable):
                        target_name = node.target.__name__
                elif node.op == "call_method":
                    target_name = str(node.target)

                operator_info = {
                    "op_type": node.op,
                    "target": node.target,
                    "name": node.name,
                    "target_name": target_name,
                }
                operator_list.append(operator_info)

        return operator_list

    def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        # Record the operators, then let the graph run eagerly.
        operator = self._extract_operators_from_graph(gm, inputs)
        self.extract_node = operator
        return gm.forward


class ModelLoader:
    def load_class_from_file(self, model_path: str, device: str) -> Any:
        file_path = os.path.join(model_path, "model.py")
        file = Path(file_path).resolve()
        module_name = file.stem

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Model file not found: {file_path}")

        with open(file_path, "r", encoding="utf-8") as f:
            model_code = f.read()

        model_code = graph_utils.modify_code_by_device(model_code, device)

        spec = importlib.util.spec_from_loader(module_name, loader=None)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module

        compiled_code = compile(model_code, filename=file, mode="exec")
        exec(compiled_code, module.__dict__)

        model_class = getattr(module, "GraphModule", None)
        if model_class is None:
            raise ImportError(f"Class 'GraphModule' not found in {file_path}")

        return model_class

    def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]:
        inputs_params = graph_utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        for tensor_meta in params.values():
            if hasattr(tensor_meta, "device"):
                tensor_meta.device = device
        input_dict = {
            k: graph_utils.replay_tensor(v).to(torch.device(device))
            for k, v in params.items()
        }
        return input_dict


class RangeDecomposerBackend:
    def __init__(self):
        self.window_size = 10
        self.graph_net_root = Path(graph_net.__file__).parent
        self.workspace_root = Path.cwd() / "naive_decompose_workspace"

    def _resolve_token_to_ops(
        self, tid, num_primitives, token_id2primitive_id, symbol_map
    ) -> List[str]:
        # Recursively expand a token into the primitive operator names it covers.
        if tid < num_primitives:
            return [token_id2primitive_id[tid]]
        if tid in symbol_map:
            sub_tokens = symbol_map[tid].tolist()
            ops = []
            for t in sub_tokens:
                ops.extend(
                    self._resolve_token_to_ops(
                        t, num_primitives, token_id2primitive_id, symbol_map
                    )
                )
            return ops
        return [f"Unknown({tid})"]

    def _extract_ops_via_compile(
        self, model_path: str, device: str = "cpu"
    ) -> List[str]:
        loader = ModelLoader()
        print(f"Loading model from {model_path} on {device}...")
        try:
            model_class = loader.load_class_from_file(model_path, device)
            model = model_class().to(torch.device(device))
            model.eval()
            input_dict = loader.get_input_dict(model_path, device)
        except Exception as e:
            print(f"Error loading/preparing model {model_path}: {e}")
            return []

        extractor = GraphExtractor()
        compiled_model = torch.compile(model, backend=extractor.extract_compiler)

        with torch.no_grad():
            compiled_model(**input_dict)

        ops_info = extractor.extract_node
        if not ops_info:
            print(f"Warning: No operators extracted from {model_path}.")
            return []
        return [op["target_name"] for op in ops_info]

    def _calculate_token_lengths(
        self, rp_expr, num_primitives, symbol_map
    ) -> Dict[int, int]:
        # Memoized length, in primitive ops, of every token (composite
        # pattern tokens expand to the sum of their sub-token lengths).
        token2len = {}

        def get_len(tid):
            if tid in token2len:
                return token2len[tid]
            if tid < num_primitives:
                token2len[tid] = 1
                return 1
            if tid in symbol_map:
                sub_tokens = symbol_map[tid].tolist()
                length = sum(get_len(t) for t in sub_tokens)
                token2len[tid] = length
                return length
            token2len[tid] = 1
            return 1

        for sym_id in rp_expr.symbol_token_ids:
            get_len(sym_id)
        return token2len

    def _analyze_and_get_splits(self, args) -> Dict[str, Dict]:
        input_file = Path(args.model_path)
        if not input_file.exists():
            print(f"Error: Input file {input_file} does not exist.")
            return {}

        with open(input_file, "r") as f:
            model_paths = [
                Path(line.strip())
                for line in f
                if line.strip() and not line.strip().startswith("#")
            ]

        if not model_paths:
            print("No valid model paths found.")
            return {}

        inputs_seqs = []
        valid_models = []

        for p in model_paths:
            seq = self._extract_ops_via_compile(p, args.device)
            if seq:
                inputs_seqs.append(seq)
                valid_models.append((p.name, p))

        if not inputs_seqs:
            return {}

        rp_parser = RpExprParser(
            window_size=self.window_size, fold_policy="default", fold_times=0
        )
        rp_expr, token_id2primitive_id = rp_parser(inputs_seqs)
        rp_expr.try_unwrap_body_of_sole_symbol_token()
        rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id)
        num_primitives = len(token_id2primitive_id)
        symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors))
        token2len = self._calculate_token_lengths(rp_expr, num_primitives, symbol_map)

        results = {}

        for i, (model_name, original_path) in enumerate(valid_models):
            if i >= len(rp_expr.body_rp_expr):
                break

            target_body_tensor = rp_expr.body_rp_expr[i]
            seq_tokens = target_body_tensor.tolist()

            full_model_ops = []
            for t in seq_tokens:
                full_model_ops.extend(
                    self._resolve_token_to_ops(
                        t, num_primitives, token_id2primitive_id, symbol_map
                    )
                )

            # Place split points at the boundaries of repeated patterns.
            current_idx = 0
            split_points_set = set()
            total_len = sum(token2len.get(t, 1) for t in seq_tokens)

            for token_id in seq_tokens:
                length = token2len.get(token_id, 1)
                is_pattern = token_id >= num_primitives

                if is_pattern:
                    if current_idx > 0:
                        split_points_set.add(current_idx)
                    end_idx = current_idx + length
                    if end_idx < total_len:
                        split_points_set.add(end_idx)

                current_idx += length

            sorted_splits = sorted(split_points_set)

            self._print_analysis(
                model_name, original_path, sorted_splits, total_len, full_model_ops
            )

            results[model_name] = {
                "path": str(original_path),
                "split_points": sorted_splits,
            }

        return results

    def _print_analysis(self, name, path, splits, total_len, full_ops):
        print("=" * 60)
        print(f"Model: {name}")
        print(f"Path: {path}")
        print(f"Splits: {splits}")
        print("-" * 60)

        last_split = 0
        for split in splits + [total_len]:
            segment_len = split - last_split

            start_safe = min(last_split, len(full_ops))
            end_safe = min(split, len(full_ops))
            segment_ops = full_ops[start_safe:end_safe]

            ops_display = str(segment_ops)
            if len(segment_ops) > 5:
                ops_display = f"[{segment_ops[0]}, ..., {segment_ops[-1]}]"

            print(
                f" Range [{last_split:3d}, {split:3d}), Len: {segment_len:3d} | Ops: {ops_display}"
            )
            last_split = split
        print("\n")

    def __call__(self, args):
        model_data_map = self._analyze_and_get_splits(args)

        for model_name, info in model_data_map.items():
            model_path = info["path"]
            split_points = info["split_points"]

            model_output_dir = self.workspace_root / f"{model_name}_decomposed"
            model_output_dir.mkdir(parents=True, exist_ok=True)

            config_dict = {
                "decorator_path": str(self.graph_net_root / "torch/extractor.py"),
                "decorator_config": {
                    "name": model_name,
                    "custom_extractor_path": str(
                        self.graph_net_root / "torch/naive_graph_decomposer.py"
                    ),
                    "custom_extractor_config": {
                        "output_dir": str(model_output_dir),
                        "split_positions": split_points,
                        "group_head_and_tail": True,
                        "filter_path": str(
                            self.graph_net_root / "torch/naive_subgraph_filter.py"
                        ),
                        "filter_config": {},
                    },
                },
            }

            encoded_config = encode_config(config_dict)

            cmd = [
                sys.executable,
                "-m",
                "graph_net.torch.run_model",
                "--model-path",
                model_path,
                "--decorator-config",
                encoded_config,
            ]

            try:
                subprocess.run(cmd, check=True)
                print(f" [Success] Saved to {model_output_dir}")
            except subprocess.CalledProcessError as e:
                print(f" [Error] Process failed: {e}")
            except Exception as e:
                print(f" [Error] Unexpected: {e}")

graph_net/torch/test_compiler.py

Lines changed: 9 additions & 2 deletions
@@ -23,6 +23,7 @@
 from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend
 from graph_net.torch.backend.nope_backend import NopeBackend
 from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend
+from graph_net.torch.backend.range_decomposer_backend import RangeDecomposerBackend
 from graph_net.torch.backend.range_decomposer_validator_backend import (
     RangeDecomposerValidatorBackend,
 )
@@ -39,6 +40,7 @@
     "bladedisc": BladeDISCBackend(),
     "nope": NopeBackend(),
     "unstable_to_stable": UnstableToStableBackend(),
+    "range_decomposer": RangeDecomposerBackend(),
     "range_decomposer_validator": RangeDecomposerValidatorBackend(),
 }
 
@@ -385,11 +387,16 @@ def test_multi_models(args):
 
 
 def main(args):
-    assert os.path.isdir(args.model_path)
-
     initalize_seed = 123
     set_seed(random_seed=initalize_seed)
 
+    if args.compiler == "range_decomposer":
+        compiler = get_compiler_backend(args)
+        compiler(args)
+        return
+
+    assert os.path.isdir(args.model_path)
+
     if path_utils.is_single_model_dir(args.model_path):
         test_single_model(args)
     else:
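
With the new early-return branch in main(), this backend is evidently reached by pointing --model-path at a list file rather than a single model directory. A hypothetical standalone driver is sketched below; the flag names are assumptions based only on the attributes RangeDecomposerBackend reads from args (model_path and device), not on the actual CLI defined elsewhere in test_compiler.py:

# Hypothetical driver (a sketch, not part of this commit).
import argparse

from graph_net.torch.backend.range_decomposer_backend import RangeDecomposerBackend

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", dest="model_path", required=True,
                    help="Text file with one model directory per line; '#' lines are comments")
parser.add_argument("--device", default="cpu")
args = parser.parse_args()

# Analyzes operator sequences, derives split points, then shells out to
# graph_net.torch.run_model once per model.
RangeDecomposerBackend()(args)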
