Skip to content

Commit c6f01fa

Browse files
author
jorjortuajing
committed
add validator
1 parent c65f7fa commit c6f01fa

File tree

4 files changed

+135
-32
lines changed

4 files changed

+135
-32
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/bin/bash
2+
3+
if [ -z "$GRAPH_NET_BENCHMARK_PATH" ]; then
4+
GRAPH_NET_BENCHMARK_PATH="$(pwd)"
5+
fi
6+
7+
FILE_PATH=$GRAPH_NET_BENCHMARK_PATH/decomposer
8+
mkdir -p "$(dirname "$FILE_PATH/log.log")"
9+
10+
MODEL_PATH="./todo_works/range_decomposer_validator/test/simple_CNN"
11+
12+
python -m graph_net.torch.test_compiler \
13+
--model-path $MODEL_PATH \
14+
--compiler range_decomposer_validator \
15+
--device cuda > "$FILE_PATH/log.log" 2>&1
16+
17+
if [ $? -ne 0 ]; then
18+
echo "Error: decomposer_validator execution failed"
19+
echo "Please check the log file: $FILE_PATH/log.log"
20+
exit 1
21+
fi
22+
23+
python -m graph_net.log2json \
24+
--log-file "$FILE_PATH/log.log" \
25+
--output-dir "$FILE_PATH/JSON_results/"
26+
27+
python -m graph_net.plot_ESt \
28+
--benchmark-path "$FILE_PATH/JSON_results/" \
29+
--output-dir "$FILE_PATH"
30+
31+
echo "=================================================="
32+
echo "Results saved in: $FILE_PATH/ES_result.png"
33+
echo ""
34+
echo "IMPORTANT: Please verify if the curve in ES_result.png is a straight line"
35+
echo "If the curve is NOT a straight line, please check the log file: $FILE_PATH/log.log"
36+
echo "=================================================="

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
44
os.path.dirname(graph_net.__file__))")
55

66
# input model path
7-
MODEL_PATH_IN_SAMPLES=/timm/resnet18
7+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
8+
MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")
9+
OUTPUT_DIR="${NAIVE_DECOMPOSE_WORKSPACE:-$(pwd)/naive_decompose_workspace}"
10+
811
extractor_config_json_str=$(cat <<EOF
912
{
1013
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
1114
"custom_extractor_config": {
12-
"output_dir": "/tmp/naive_decompose_workspace",
15+
"output_dir": "$OUTPUT_DIR/${MODEL_NAME}_decomposed",
1316
"split_positions": [8, 16, 32],
1417
"group_head_and_tail": true,
1518
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
@@ -20,5 +23,4 @@ EOF
2023
)
2124
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
2225

23-
mkdir -p /tmp/naive_decompose_workspace
24-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
26+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name $MODEL_NAME --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG

graph_net/torch/test_compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def load_class_from_file(
6969
exec(compiled_code, module.__dict__)
7070

7171
model_class = getattr(module, class_name, None)
72+
setattr(model_class, "__file_path__", file_path)
73+
setattr(model_class, "__device__", device)
7274
return model_class
7375

7476

todo_works/range_decomposer_validator/range_decomposer_validator.py

Lines changed: 91 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,106 @@
44
import sys
55
import inspect
66
import importlib.util
7-
from typing import List, Dict
7+
import itertools
8+
from typing import List, Tuple, Dict, Any, Callable
89

910

1011
class ComposedModel(nn.Module):
11-
def __init__(self, submodules: List[nn.Module]):
12+
def __init__(self, graph: nn.Module, subgraph: List[nn.Module]):
1213
super().__init__()
13-
self.submodules = nn.ModuleList(submodules)
14-
self.submodule_param_names = [
14+
self.graph = graph
15+
self.subgraph = nn.ModuleList(subgraph)
16+
self.subgraph_param_names = [
1517
list(inspect.signature(sm.forward).parameters.keys())
16-
for sm in self.submodules
18+
for sm in self.subgraph
1719
]
20+
self.extract_node = []
21+
22+
def _serialize_arg(self, arg: Any) -> Any:
23+
if isinstance(arg, torch.fx.Node):
24+
return arg.name
25+
if isinstance(arg, (list, tuple)):
26+
return type(arg)(self._serialize_arg(elem) for elem in arg)
27+
if isinstance(arg, dict):
28+
return {
29+
self._serialize_arg(k): self._serialize_arg(v) for k, v in arg.items()
30+
}
31+
return arg
32+
33+
def _extract_operators_from_graph(
34+
self, gm: nn.Module, example_inputs: List[torch.Tensor] = None
35+
) -> List[Dict[str, Any]]:
36+
operator_list = []
37+
for node in gm.graph.nodes:
38+
if node.op in ("call_method", "call_function", "call_module"):
39+
operator_info = {
40+
"op_type": node.op,
41+
"target": node.target,
42+
"name": node.name,
43+
"kwargs": self._serialize_arg(node.kwargs),
44+
}
45+
46+
if isinstance(node.target, Callable):
47+
try:
48+
operator_info["target_name"] = node.target.__name__
49+
except AttributeError:
50+
operator_info["target_name"] = str(node.target)
51+
else:
52+
operator_info["target_name"] = str(node.target)
53+
54+
operator_list.append(operator_info)
55+
56+
return operator_list
57+
58+
def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
59+
operator = self._extract_operators_from_graph(gm, inputs)
60+
self.extract_node.append(operator)
61+
return gm.forward
1862

1963
def forward(self, **kwargs):
2064
current_args = kwargs
65+
compiled_model = torch.compile(self.graph, backend=self.extract_compiler)
66+
compiled_model(**current_args)
67+
graph_node_list = list(itertools.chain.from_iterable(self.extract_node))
68+
self.extract_node = []
69+
2170
for i, (sm, param_names) in enumerate(
22-
zip(self.submodules, self.submodule_param_names)
71+
zip(self.subgraph, self.subgraph_param_names)
2372
):
24-
# 准备当前子图的输入字典
2573
call_kwargs = {}
2674
if i > 0:
27-
# 对于后续子图,第一个参数是上一个子图的输出
2875
first_param_name = param_names[0]
29-
call_kwargs[first_param_name] = current_args # current_args 此时是上一个子图的输出
76+
call_kwargs[first_param_name] = current_args
77+
remaining_params = param_names[1:]
78+
else:
79+
remaining_params = param_names
3080

31-
# 从主输入字典中筛选出当前子图需要的权重参数
32-
for name in param_names:
33-
if name in current_args:
34-
call_kwargs[name] = current_args[name]
81+
for name in remaining_params:
82+
if name in kwargs:
83+
call_kwargs[name] = kwargs[name]
3584

36-
outputs = sm(**call_kwargs)
37-
# 假设每个子图只有一个输出,并且返回的是一个元组
85+
compiled_model = torch.compile(sm, backend=self.extract_compiler)
86+
outputs = compiled_model(**call_kwargs)
3887
current_args = outputs[0]
3988

89+
subgraph_node_list = list(itertools.chain.from_iterable(self.extract_node))
90+
self.extract_node = []
91+
92+
if graph_node_list != subgraph_node_list:
93+
diff_in_graph = [
94+
item for item in graph_node_list if item not in subgraph_node_list
95+
]
96+
diff_in_subgraph = [
97+
item for item in subgraph_node_list if item not in graph_node_list
98+
]
99+
100+
error_msg = f"Subgraph segmentation verification failed\n"
101+
error_msg += f"Nodes in graph but not in subgraph: {diff_in_graph}\n"
102+
error_msg += f"Nodes in subgraph but not in graph: {diff_in_subgraph}"
103+
raise ValueError(error_msg)
104+
else:
105+
print("")
106+
40107
return (current_args,)
41108

42109

@@ -54,36 +121,32 @@ def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
54121
return instance
55122

56123
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
57-
model_file_path = inspect.getfile(
58-
model.__class__
59-
) # e.g., /test/simple_CNN/model.py
60-
model_dir = os.path.dirname(model_file_path) # e.g., /test/simple_CNN
61-
62-
decomposed_parent_dir = (
63-
model_dir + "_decomposed"
64-
) # e.g., /test/simple_CNN_decomposed
124+
model_file_path = model.__class__.__file_path__
125+
model_dir = os.path.dirname(model_file_path)
126+
decomposed_parent_dir = model_dir + "_decomposed"
65127
subgraph_paths = []
66128
for name in sorted(os.listdir(decomposed_parent_dir)):
67129
full_path = os.path.join(decomposed_parent_dir, name)
68-
if os.path.isdir(full_path) and name.startswith("subgraph_"):
130+
if os.path.isdir(full_path) and name[-1].isdigit():
69131
subgraph_paths.append(full_path)
70132

71133
print(
72134
f"[RangeDecomposerValidatorBackend] Found subgraphs: {[os.path.basename(p) for p in subgraph_paths]}"
73135
)
74136

75-
submodule_instances = []
76-
device = next(model.parameters()).device # 从传入的model获取device信息
137+
device = model.__class__.__device__
138+
graph_instances = self._load_model_instance(model_dir, device)
139+
subgraph_instances = []
77140

78141
for path in subgraph_paths:
79142
instance = self._load_model_instance(path, device)
80-
submodule_instances.append(instance)
143+
subgraph_instances.append(instance)
81144
dir_name = os.path.basename(path)
82145
print(
83146
f"[RangeDecomposerValidatorBackend] Loaded and instantiated '{dir_name}'"
84147
)
85148

86-
composed_model = ComposedModel(submodule_instances)
149+
composed_model = ComposedModel(graph_instances, subgraph_instances)
87150
return composed_model.eval()
88151

89152
def synchronize(self):

0 commit comments

Comments
 (0)