Skip to content

Commit a47d29a

Browse files
committed
split
1 parent b59dc33 commit a47d29a

File tree

2 files changed

+295
-258
lines changed

2 files changed

+295
-258
lines changed

graph_net/torch/backend/range_decomposer_backend.py

Lines changed: 0 additions & 258 deletions
Original file line numberDiff line numberDiff line change
@@ -24,269 +24,11 @@ def encode_config(config: Dict[str, Any]) -> str:
2424
return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
2525

2626

27-
class GraphExtractor:
    """Record the operators of an FX graph when used as a ``torch.compile`` backend.

    ``extract_compiler`` has the ``(gm, inputs)`` backend signature: it stores
    one record per call node of the received graph in ``extract_node`` and
    returns the graph's own ``forward`` so the model still runs unchanged.
    """

    def __init__(self):
        # Operator records produced by the most recent ``extract_compiler`` call.
        self.extract_node = []

    def _extract_operators_from_graph(
        self, gm: nn.Module, example_inputs: List[torch.Tensor] = None
    ) -> List[Dict[str, Any]]:
        """Return one record per call node in ``gm.graph``.

        Args:
            gm: Object exposing ``named_modules()`` and ``graph.nodes``.
            example_inputs: Accepted for signature compatibility; not read.

        Returns:
            A list of dicts with keys ``op_type``, ``target``, ``name`` and
            ``target_name``, in graph order. Nodes whose ``op`` is not
            ``call_method``, ``call_function`` or ``call_module`` are skipped.
        """
        named_modules = dict(gm.named_modules())
        operator_list = []

        for node in gm.graph.nodes:
            if node.op not in ("call_method", "call_function", "call_module"):
                continue

            # Default label; also the final label for ``call_method`` nodes and
            # for ``call_module`` nodes whose target is not a known submodule.
            target_name = str(node.target)

            if node.op == "call_module":
                module_instance = named_modules.get(node.target)
                if module_instance is not None:
                    target_name = type(module_instance).__name__
            elif node.op == "call_function" and callable(node.target):
                # Not every callable carries ``__name__`` (functools.partial,
                # callable instances); keep the string form for those instead
                # of raising AttributeError.
                target_name = getattr(node.target, "__name__", target_name)

            operator_list.append(
                {
                    "op_type": node.op,
                    "target": node.target,
                    "name": node.name,
                    "target_name": target_name,
                }
            )

        return operator_list

    def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        """Backend entry point: record ``gm``'s operators and return ``gm.forward``.

        NOTE(review): ``extract_node`` is overwritten on every call, so if the
        backend is invoked more than once for a model only the last graph's
        operators are kept — confirm this is intended.
        """
        self.extract_node = self._extract_operators_from_graph(gm, inputs)
        return gm.forward
65-
66-
67-
class ModelLoader:
    """Load the ``GraphModule`` class and its input tensors from a model directory."""

    def load_class_from_file(self, model_path: str, device: str) -> Any:
        """Execute ``<model_path>/model.py`` and return its ``GraphModule`` class.

        The source text is passed through ``graph_utils.modify_code_by_device``
        before being compiled and executed in a fresh module registered in
        ``sys.modules`` under the file's stem.

        Args:
            model_path: Directory containing ``model.py``.
            device: Device string forwarded to ``modify_code_by_device``.

        Returns:
            The ``GraphModule`` attribute of the executed module.

        Raises:
            FileNotFoundError: ``model.py`` does not exist in ``model_path``.
            ImportError: The executed module defines no ``GraphModule``.
        """
        file_path = os.path.join(model_path, "model.py")
        # Fail fast before doing any other work on a missing file.
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Model file not found: {file_path}")

        file = Path(file_path).resolve()
        module_name = file.stem

        with open(file_path, "r", encoding="utf-8") as f:
            model_code = f.read()

        model_code = graph_utils.modify_code_by_device(model_code, device)

        spec = importlib.util.spec_from_loader(module_name, loader=None)
        module = importlib.util.module_from_spec(spec)

        # Every model file shares the same stem, so this registration replaces
        # whatever was loaded before. Remember the previous entry so a failed
        # exec does not leave a half-initialised module behind.
        previous_module = sys.modules.get(module_name)
        sys.modules[module_name] = module

        try:
            compiled_code = compile(model_code, filename=file, mode="exec")
            # SECURITY: this runs arbitrary code from model.py; only point the
            # loader at trusted model directories.
            exec(compiled_code, module.__dict__)
        except BaseException:
            if previous_module is None:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous_module
            raise

        model_class = getattr(module, "GraphModule", None)
        if model_class is None:
            raise ImportError(f"Class 'GraphModule' not found in {file_path}")

        return model_class

    def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]:
        """Build the keyword inputs for the model stored at ``model_path``.

        Reads the ``weight_info`` mapping returned by
        ``graph_utils.load_converted_from_text``, sets ``device`` on every
        entry that has such an attribute, and replays each entry into a
        tensor moved to ``device``.
        """
        inputs_params = graph_utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        for tensor_meta in params.values():
            if hasattr(tensor_meta, "device"):
                tensor_meta.device = device

        target_device = torch.device(device)
        return {
            key: graph_utils.replay_tensor(meta).to(target_device)
            for key, meta in params.items()
        }
105-
106-
10727
class RangeDecomposerBackend:
10828
def __init__(self, window_size: int = 10):
    """Set up analysis parameters and filesystem roots.

    Args:
        window_size: Stored as ``self.window_size``; defaults to the
            previously hard-coded value of 10.
    """
    self.window_size = window_size
    # Directory that contains the installed ``graph_net`` package.
    self.graph_net_root = Path(graph_net.__file__).parent
    # Workspace directory located under the current working directory.
    self.workspace_root = Path.cwd() / "naive_decompose_workspace"
11231

113-
def _resolve_token_to_ops(
114-
self, tid, num_primitives, token_id2primitive_id, symbol_map
115-
) -> List[str]:
116-
if tid < num_primitives:
117-
return [token_id2primitive_id[tid]]
118-
if tid in symbol_map:
119-
sub_tokens = symbol_map[tid].tolist()
120-
ops = []
121-
for t in sub_tokens:
122-
ops.extend(
123-
self._resolve_token_to_ops(
124-
t, num_primitives, token_id2primitive_id, symbol_map
125-
)
126-
)
127-
return ops
128-
return [f"Unknown({tid})"]
129-
130-
def _extract_ops_via_compile(
    self, model_path: str, device: str = "cpu"
) -> List[str]:
    """Run the model once under ``torch.compile`` and return its op names.

    The model class and inputs are loaded with ``ModelLoader``; the model is
    compiled with ``GraphExtractor.extract_compiler`` as backend and called
    once under ``torch.no_grad()``.

    Args:
        model_path: Model directory handed to ``ModelLoader``.
        device: Device string used for the model and its inputs.

    Returns:
        The ``target_name`` of every extracted operator, or ``[]`` when the
        model cannot be loaded, cannot be run, or yields no operators. In
        each of those cases a message is printed.
    """
    loader = ModelLoader()
    print(f"Loading model from {model_path} on {device}...")
    try:
        model_class = loader.load_class_from_file(model_path, device)
        model = model_class().to(torch.device(device))
        model.eval()
        input_dict = loader.get_input_dict(model_path, device)
    except Exception as e:
        print(f"Error loading/preparing model {model_path}: {e}")
        return []

    extractor = GraphExtractor()
    # Treat a failing compile/forward pass like a failing load: report it and
    # return an empty list, so one broken model does not abort the batch.
    try:
        compiled_model = torch.compile(model, backend=extractor.extract_compiler)
        with torch.no_grad():
            compiled_model(**input_dict)
    except Exception as e:
        print(f"Error running compiled model {model_path}: {e}")
        return []

    ops_info = extractor.extract_node
    if not ops_info:
        print(f"Warning: No operators extracted from {model_path}.")
        return []
    return [op["target_name"] for op in ops_info]
155-
156-
def _calculate_token_lengths(
157-
self, rp_expr, num_primitives, symbol_map
158-
) -> Dict[int, int]:
159-
token2len = {}
160-
161-
def get_len(tid):
162-
if tid in token2len:
163-
return token2len[tid]
164-
if tid < num_primitives:
165-
token2len[tid] = 1
166-
return 1
167-
if tid in symbol_map:
168-
sub_tokens = symbol_map[tid].tolist()
169-
length = sum(get_len(t) for t in sub_tokens)
170-
token2len[tid] = length
171-
return length
172-
token2len[tid] = 1
173-
return 1
174-
175-
for sym_id in rp_expr.symbol_token_ids:
176-
get_len(sym_id)
177-
return token2len
178-
179-
def _analyze_and_get_splits(self, args) -> Dict[str, Dict]:
180-
input_file = Path(args.model_path)
181-
if not input_file.exists():
182-
print(f"Error: Input file {input_file} does not exist.")
183-
return {}
184-
185-
with open(input_file, "r") as f:
186-
model_paths = [
187-
Path(line.strip())
188-
for line in f
189-
if line.strip() and not line.startswith("#")
190-
]
191-
192-
if not model_paths:
193-
print("No valid model paths found.")
194-
return {}
195-
196-
inputs_seqs = []
197-
valid_models = []
198-
199-
for p in model_paths:
200-
seq = self._extract_ops_via_compile(p, args.device)
201-
if seq:
202-
inputs_seqs.append(seq)
203-
valid_models.append((p.name, p))
204-
205-
if not inputs_seqs:
206-
return {}
207-
208-
rp_parser = RpExprParser(
209-
window_size=self.window_size, fold_policy="default", fold_times=0
210-
)
211-
rp_expr, token_id2primitive_id = rp_parser(inputs_seqs)
212-
rp_expr.try_unwrap_body_of_sole_symbol_token()
213-
rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id)
214-
num_primitives = len(token_id2primitive_id)
215-
symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors))
216-
token2len = self._calculate_token_lengths(rp_expr, num_primitives, symbol_map)
217-
218-
results = {}
219-
220-
for i, (model_name, original_path) in enumerate(valid_models):
221-
if i >= len(rp_expr.body_rp_expr):
222-
break
223-
224-
target_body_tensor = rp_expr.body_rp_expr[i]
225-
seq_tokens = target_body_tensor.tolist()
226-
227-
full_model_ops = []
228-
for t in seq_tokens:
229-
full_model_ops.extend(
230-
self._resolve_token_to_ops(
231-
t, num_primitives, token_id2primitive_id, symbol_map
232-
)
233-
)
234-
235-
current_idx = 0
236-
split_points_set = set()
237-
total_len = sum(token2len.get(t, 1) for t in seq_tokens)
238-
239-
for token_id in seq_tokens:
240-
length = token2len.get(token_id, 1)
241-
is_pattern = token_id >= num_primitives
242-
243-
if is_pattern:
244-
if current_idx > 0:
245-
split_points_set.add(current_idx)
246-
end_idx = current_idx + length
247-
if end_idx < total_len:
248-
split_points_set.add(end_idx)
249-
250-
current_idx += length
251-
252-
sorted_splits = sorted(list(split_points_set))
253-
254-
self._print_analysis(
255-
model_name, original_path, sorted_splits, total_len, full_model_ops
256-
)
257-
258-
results[model_name] = {
259-
"path": str(original_path),
260-
"split_points": sorted_splits,
261-
}
262-
263-
return results
264-
265-
def _print_analysis(self, name, path, splits, total_len, full_ops):
266-
print("=" * 60)
267-
print(f"Model: {name}")
268-
print(f"Path: {path}")
269-
print(f"Splits: {splits}")
270-
print("-" * 60)
271-
272-
last_split = 0
273-
for split in splits + [total_len]:
274-
segment_len = split - last_split
275-
276-
start_safe = min(last_split, len(full_ops))
277-
end_safe = min(split, len(full_ops))
278-
segment_ops = full_ops[start_safe:end_safe]
279-
280-
ops_display = str(segment_ops)
281-
if len(segment_ops) > 5:
282-
ops_display = f"[{segment_ops[0]}, ..., {segment_ops[-1]}]"
283-
284-
print(
285-
f" Range [{last_split:3d}, {split:3d}), Len: {segment_len:3d} | Ops: {ops_display}"
286-
)
287-
last_split = split
288-
print("\n")
289-
29032
def __call__(self, args):
29133
model_data_map = self._analyze_and_get_splits(args)
29234

0 commit comments

Comments
 (0)