|
1 | | -import argparse |
2 | 1 | import base64 |
3 | | -import importlib.util |
4 | | -import inspect |
5 | | -import itertools |
6 | 2 | import json |
7 | | -import os |
8 | 3 | import subprocess |
9 | 4 | import sys |
10 | 5 | from pathlib import Path |
11 | | -from typing import Any, Callable, Dict, List, Tuple |
| 6 | +from typing import Any, Dict |
12 | 7 |
|
13 | 8 | import torch |
14 | | -import torch.nn as nn |
15 | | - |
16 | 9 | import graph_net |
17 | | -from graph_net.torch import utils as graph_utils |
18 | | -from graph_net.torch.rp_expr.longest_rp_expr_parser import LongestRpExprParser |
19 | | -from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser |
| 10 | + |
| 11 | + |
| 12 | +def convert_to_dict(config_str): |
| 13 | + if config_str is None: |
| 14 | + return {} |
| 15 | + config_str = base64.b64decode(config_str).decode("utf-8") |
| 16 | + config = json.loads(config_str) |
| 17 | + assert isinstance(config, dict), f"config should be a dict. {config_str=}" |
| 18 | + return config |
20 | 19 |
|
21 | 20 |
|
22 | 21 | def encode_config(config: Dict[str, Any]) -> str: |
23 | 22 | json_str = json.dumps(config) |
24 | 23 | return base64.b64encode(json_str.encode("utf-8")).decode("utf-8") |
25 | 24 |
|
26 | 25 |
|
| 26 | +def load_json(file_path): |
| 27 | + with open(file_path, "r", encoding="utf-8") as file: |
| 28 | + data_dict = json.load(file) |
| 29 | + return data_dict |
| 30 | + |
| 31 | + |
27 | 32 | class RangeDecomposerBackend: |
28 | 33 | def __init__(self): |
29 | 34 | self.graph_net_root = Path(graph_net.__file__).parent |
30 | | - self.workspace_root = Path.cwd() / "naive_decompose_workspace" |
31 | 35 |
|
32 | | - def __call__(self, args): |
33 | | - model_data_map = self._analyze_and_get_splits(args) |
| 36 | + def __call__(self, model: torch.nn.Module) -> torch.nn.Module: |
| 37 | + config = convert_to_dict(self.config) |
| 38 | + workspace_path = Path(config["workspace_path"]) |
| 39 | + chain_style = config["chain_style"] |
34 | 40 |
|
35 | | - for model_name, info in model_data_map.items(): |
36 | | - model_path = info["path"] |
37 | | - split_points = info["split_points"] |
| 41 | + model_file_path = Path(model.__class__.__graph_net_file_path__) |
| 42 | + model_name = model_file_path.parent.name |
38 | 43 |
|
39 | | - model_output_dir = self.workspace_root / f"{model_name}_decomposed" |
40 | | - model_output_dir.mkdir(parents=True, exist_ok=True) |
| 44 | + model_info = load_json(config["split_results_path"])[model_name] |
| 45 | + model_path = model_info["path"] |
| 46 | + split_points = model_info["split_points"] |
41 | 47 |
|
42 | | - config_dict = { |
43 | | - "decorator_path": str(self.graph_net_root / "torch/extractor.py"), |
44 | | - "decorator_config": { |
45 | | - "name": model_name, |
46 | | - "custom_extractor_path": str( |
47 | | - self.graph_net_root / "torch/naive_graph_decomposer.py" |
| 48 | + model_output_dir = workspace_path / f"{model_name}_decomposed" |
| 49 | + model_output_dir.mkdir(parents=True, exist_ok=True) |
| 50 | + |
| 51 | + config_dict = { |
| 52 | + "decorator_path": str(self.graph_net_root / "torch/extractor.py"), |
| 53 | + "decorator_config": { |
| 54 | + "name": model_name, |
| 55 | + "custom_extractor_path": str( |
| 56 | + self.graph_net_root / "torch/naive_graph_decomposer.py" |
| 57 | + ), |
| 58 | + "custom_extractor_config": { |
| 59 | + "output_dir": str(model_output_dir), |
| 60 | + "split_positions": split_points, |
| 61 | + "group_head_and_tail": True, |
| 62 | + "filter_path": str( |
| 63 | + self.graph_net_root / "torch/naive_subgraph_filter.py" |
48 | 64 | ), |
49 | | - "custom_extractor_config": { |
50 | | - "output_dir": str(model_output_dir), |
51 | | - "split_positions": split_points, |
52 | | - "group_head_and_tail": True, |
53 | | - "filter_path": str( |
54 | | - self.graph_net_root / "torch/naive_subgraph_filter.py" |
55 | | - ), |
56 | | - "filter_config": {}, |
57 | | - }, |
| 65 | + "filter_config": {}, |
| 66 | + "chain_style": chain_style, |
58 | 67 | }, |
59 | | - } |
60 | | - |
61 | | - encoded_config = encode_config(config_dict) |
62 | | - |
63 | | - cmd = [ |
64 | | - sys.executable, |
65 | | - "-m", |
66 | | - "graph_net.torch.run_model", |
67 | | - "--model-path", |
68 | | - model_path, |
69 | | - "--decorator-config", |
70 | | - encoded_config, |
71 | | - ] |
72 | | - |
73 | | - try: |
74 | | - subprocess.run(cmd, check=True) |
75 | | - print(f" [Success] Saved to {model_output_dir}") |
76 | | - except subprocess.CalledProcessError as e: |
77 | | - print(f" [Error] Process failed: {e}") |
78 | | - except Exception as e: |
79 | | - print(f" [Error] Unexpected: {e}") |
| 68 | + }, |
| 69 | + } |
| 70 | + |
| 71 | + encoded_config = encode_config(config_dict) |
| 72 | + |
| 73 | + cmd = [ |
| 74 | + sys.executable, |
| 75 | + "-m", |
| 76 | + "graph_net.torch.run_model", |
| 77 | + "--model-path", |
| 78 | + model_path, |
| 79 | + "--decorator-config", |
| 80 | + encoded_config, |
| 81 | + ] |
| 82 | + |
| 83 | + try: |
| 84 | + subprocess.run(cmd, check=True) |
| 85 | + print(f"[Success] Saved to {model_output_dir}") |
| 86 | + except subprocess.CalledProcessError as e: |
| 87 | + print(f"[Error] Process failed: {e}") |
| 88 | + except Exception as e: |
| 89 | + print(f"[Error] Unexpected: {e}") |
| 90 | + return model |
| 91 | + |
| 92 | + def synchronize(self): |
| 93 | + if torch.cuda.is_available(): |
| 94 | + torch.cuda.synchronize() |
0 commit comments