Commit e5abd5f

Commit message: fix
Parent: 457c4af

4 files changed: +129 -67 lines changed
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
+
+MODEL1="$GRAPH_NET_ROOT/samples/torchvision/resnet18"
+MODEL2="$GRAPH_NET_ROOT/samples/torchvision/resnet34"
+MODEL_LIST_FILE=$(mktemp)
+echo "$MODEL1" > "$MODEL_LIST_FILE"
+echo "$MODEL2" >> "$MODEL_LIST_FILE"
+
+python3 -m graph_net.torch.typical_sequence_split_points \
+    --model-list "$MODEL_LIST_FILE" \
+    --device "cuda" \
+    --window-size 10 \
+    --output-json "$GRAPH_NET_ROOT/split_results.json"
+
+rm -f "$MODEL_LIST_FILE"
+
+
+MODEL_PATH_IN_SAMPLES=/torchvision/resnet18
+MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")
+
+decomposer_config_json_str=$(cat <<EOF
+{
+  "split_results_path": "$GRAPH_NET_ROOT/split_results.json",
+  "workspace_path": "$GRAPH_NET_ROOT/decompose_workspace",
+  "chain_style": "True"
+}
+EOF
+)
+DECOMPOSER_CONFIG=$(echo $decomposer_config_json_str | base64 -w 0)
+
+python3 -m graph_net.torch.test_compiler --model-path $GRAPH_NET_ROOT/samples/$MODEL_PATH_IN_SAMPLES --compiler range_decomposer --device cuda --config=$DECOMPOSER_CONFIG
+
+
+DECOMPOSE_PATH=$GRAPH_NET_ROOT/decompose_workspace
+cp -r "$GRAPH_NET_ROOT/samples/$MODEL_PATH_IN_SAMPLES" "$DECOMPOSE_PATH/"
+
+python3 -m graph_net.torch.test_compiler \
+    --model-path $DECOMPOSE_PATH/$MODEL_NAME \
+    --compiler range_decomposer_validator \
+    --device cuda > "$DECOMPOSE_PATH/log.log" 2>&1
+
+python3 -m graph_net.plot_ESt \
+    --benchmark-path $DECOMPOSE_PATH/log.log \
+    --output-dir $DECOMPOSE_PATH \
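
For reference, the --config value that this script passes to test_compiler is nothing more than a base64-encoded JSON object. Below is a minimal Python sketch of the same round trip, mirroring the script's base64 -w 0 step and the encode_config/convert_to_dict helpers shown in the backend diff that follows; the GRAPH_NET_ROOT value here is a placeholder, not the path the script actually derives.

import base64
import json

# Placeholder path; the script derives the real value from the installed graph_net package.
GRAPH_NET_ROOT = "/path/to/GraphNet"

decomposer_config = {
    "split_results_path": f"{GRAPH_NET_ROOT}/split_results.json",
    "workspace_path": f"{GRAPH_NET_ROOT}/decompose_workspace",
    "chain_style": "True",
}

# Same encoding as `base64 -w 0` in the script and encode_config() in the backend.
encoded = base64.b64encode(json.dumps(decomposer_config).encode("utf-8")).decode("utf-8")

# convert_to_dict() in the backend reverses it.
decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
assert decoded == decomposer_config
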
Lines changed: 70 additions & 55 deletions
@@ -1,79 +1,94 @@
-import argparse
 import base64
-import importlib.util
-import inspect
-import itertools
 import json
-import os
 import subprocess
 import sys
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Dict

 import torch
-import torch.nn as nn
-
 import graph_net
-from graph_net.torch import utils as graph_utils
-from graph_net.torch.rp_expr.longest_rp_expr_parser import LongestRpExprParser
-from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
+
+
+def convert_to_dict(config_str):
+    if config_str is None:
+        return {}
+    config_str = base64.b64decode(config_str).decode("utf-8")
+    config = json.loads(config_str)
+    assert isinstance(config, dict), f"config should be a dict. {config_str=}"
+    return config


 def encode_config(config: Dict[str, Any]) -> str:
     json_str = json.dumps(config)
     return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")


+def load_json(file_path):
+    with open(file_path, "r", encoding="utf-8") as file:
+        data_dict = json.load(file)
+    return data_dict
+
+
 class RangeDecomposerBackend:
     def __init__(self):
         self.graph_net_root = Path(graph_net.__file__).parent
-        self.workspace_root = Path.cwd() / "naive_decompose_workspace"

-    def __call__(self, args):
-        model_data_map = self._analyze_and_get_splits(args)
+    def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
+        config = convert_to_dict(self.config)
+        workspace_path = Path(config["workspace_path"])
+        chain_style = config["chain_style"]

-        for model_name, info in model_data_map.items():
-            model_path = info["path"]
-            split_points = info["split_points"]
+        model_file_path = Path(model.__class__.__graph_net_file_path__)
+        model_name = model_file_path.parent.name

-            model_output_dir = self.workspace_root / f"{model_name}_decomposed"
-            model_output_dir.mkdir(parents=True, exist_ok=True)
+        model_info = load_json(config["split_results_path"])[model_name]
+        model_path = model_info["path"]
+        split_points = model_info["split_points"]

-            config_dict = {
-                "decorator_path": str(self.graph_net_root / "torch/extractor.py"),
-                "decorator_config": {
-                    "name": model_name,
-                    "custom_extractor_path": str(
-                        self.graph_net_root / "torch/naive_graph_decomposer.py"
+        model_output_dir = workspace_path / f"{model_name}_decomposed"
+        model_output_dir.mkdir(parents=True, exist_ok=True)
+
+        config_dict = {
+            "decorator_path": str(self.graph_net_root / "torch/extractor.py"),
+            "decorator_config": {
+                "name": model_name,
+                "custom_extractor_path": str(
+                    self.graph_net_root / "torch/naive_graph_decomposer.py"
+                ),
+                "custom_extractor_config": {
+                    "output_dir": str(model_output_dir),
+                    "split_positions": split_points,
+                    "group_head_and_tail": True,
+                    "filter_path": str(
+                        self.graph_net_root / "torch/naive_subgraph_filter.py"
                     ),
-                    "custom_extractor_config": {
-                        "output_dir": str(model_output_dir),
-                        "split_positions": split_points,
-                        "group_head_and_tail": True,
-                        "filter_path": str(
-                            self.graph_net_root / "torch/naive_subgraph_filter.py"
-                        ),
-                        "filter_config": {},
-                    },
+                    "filter_config": {},
+                    "chain_style": chain_style,
                 },
-            }
-
-            encoded_config = encode_config(config_dict)
-
-            cmd = [
-                sys.executable,
-                "-m",
-                "graph_net.torch.run_model",
-                "--model-path",
-                model_path,
-                "--decorator-config",
-                encoded_config,
-            ]
-
-            try:
-                subprocess.run(cmd, check=True)
-                print(f" [Success] Saved to {model_output_dir}")
-            except subprocess.CalledProcessError as e:
-                print(f" [Error] Process failed: {e}")
-            except Exception as e:
-                print(f" [Error] Unexpected: {e}")
+            },
+        }
+
+        encoded_config = encode_config(config_dict)
+
+        cmd = [
+            sys.executable,
+            "-m",
+            "graph_net.torch.run_model",
+            "--model-path",
+            model_path,
+            "--decorator-config",
+            encoded_config,
+        ]
+
+        try:
+            subprocess.run(cmd, check=True)
+            print(f"[Success] Saved to {model_output_dir}")
+        except subprocess.CalledProcessError as e:
+            print(f"[Error] Process failed: {e}")
+        except Exception as e:
+            print(f"[Error] Unexpected: {e}")
+        return model
+
+    def synchronize(self):
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
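
The rewritten __call__ above assumes that the JSON file at config["split_results_path"] maps each model name to an entry carrying "path" and "split_points" keys (see the lookup via load_json). A hypothetical example of that shape, with made-up paths and split positions; the real file is produced by graph_net.torch.typical_sequence_split_points.

# Illustrative shape only; values below are invented, not taken from a real run.
split_results = {
    "resnet18": {
        "path": "/path/to/GraphNet/samples/torchvision/resnet18",
        "split_points": [4, 9, 15],
    },
    "resnet34": {
        "path": "/path/to/GraphNet/samples/torchvision/resnet34",
        "split_points": [6, 13, 21],
    },
}
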

graph_net/torch/test_compiler.py

Lines changed: 13 additions & 8 deletions
@@ -96,7 +96,10 @@ def load_class_from_file(

 def get_compiler_backend(args) -> GraphCompilerBackend:
     assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
-    return registry_backend[args.compiler]
+    backend = registry_backend[args.compiler]
+    if args.config is not None:
+        backend.config = args.config
+    return backend


 def get_model(args):
@@ -396,16 +399,11 @@ def test_multi_models(args):


 def main(args):
+    assert os.path.isdir(args.model_path)
+
     initalize_seed = 123
     set_seed(random_seed=initalize_seed)

-    if args.compiler == "range_decomposer":
-        compiler = get_compiler_backend(args)
-        compiler(args)
-        return
-
-    assert os.path.isdir(args.model_path)
-
     if path_utils.is_single_model_dir(args.model_path):
         test_single_model(args)
     else:
@@ -454,5 +452,12 @@ def main(args):
         default=None,
         help="Path to samples list, each line contains a sample path",
     )
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=False,
+        default=None,
+        help="Path to configuration file.",
+    )
     args = parser.parse_args()
     main(args=args)
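
With the range_decomposer special case removed from main(), every registered compiler, including range_decomposer, is expected to go through the ordinary backend path that test_compiler exercises: an optional config attribute set by get_compiler_backend() when --config is given, a __call__(model) that returns a model, and a synchronize() method. The sketch below is a hypothetical minimal backend illustrating that contract; it is not part of this commit.

import torch

class NoOpBackend:
    config = None  # get_compiler_backend() overwrites this when --config is passed

    def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
        # A real backend would compile or transform the model here.
        return model

    def synchronize(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
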

graph_net/torch/typical_sequence_split_points.py

Lines changed: 0 additions & 4 deletions
@@ -1,8 +1,6 @@
 import argparse
-import importlib.util
 import json
 import os
-import sys
 from pathlib import Path
 from typing import Any, Callable, Dict, List

@@ -53,8 +51,6 @@ def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor])
 class TypicalSequenceModelLoader:
     def load_class_from_file(self, model_path: str, device: str) -> Any:
         file_path = os.path.join(model_path, "model.py")
-        file = Path(file_path).resolve()
-        module_name = file.stem

         if not os.path.exists(file_path):
             raise FileNotFoundError(f"Model file not found: {file_path}")
