Skip to content

Commit b1de6a0

Browse files
committed
fix
1 parent 92a2d40 commit b1de6a0

30 files changed

+34
-6631
lines changed

graph_net/test/chain_naive_graph_decomposer_test.sh

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,12 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
44
os.path.dirname(graph_net.__file__))")
55

66
# input model path
7-
MODEL_PATH_IN_SAMPLES=/timm/resnet18
8-
MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")
9-
OUTPUT_DIR="${NAIVE_DECOMPOSE_WORKSPACE:-$(pwd)/naive_decompose_workspace}"
10-
7+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
118
extractor_config_json_str=$(cat <<EOF
129
{
1310
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
1411
"custom_extractor_config": {
15-
"output_dir": "$OUTPUT_DIR/${MODEL_NAME}_decomposed",
12+
"output_dir": "/tmp/chain_naive_decompose_workspace",
1613
"split_positions": [8, 16, 32],
1714
"group_head_and_tail": true,
1815
"chain_style": true
@@ -23,4 +20,4 @@ EOF
2320
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
2421

2522
mkdir -p /tmp/naive_decompose_workspace
26-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name $MODEL_NAME --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
23+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG

graph_net/test/decomposer_validator_test.sh

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,40 @@
11
#!/bin/bash
22

3-
if [ -z "$GRAPH_NET_BENCHMARK_PATH" ]; then
4-
GRAPH_NET_BENCHMARK_PATH="$(pwd)/graphnet_benchmark"
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(graph_net.__file__))")
4+
5+
if [ -z "$GRAPH_NET_DECOMPOSE_PATH" ]; then
6+
GRAPH_NET_DECOMPOSE_PATH="$(pwd)/graphnet_decompose"
57
fi
68

7-
FILE_PATH=$GRAPH_NET_BENCHMARK_PATH/decomposer
9+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
10+
MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")
11+
OUTPUT_DIR="${GRAPH_NET_DECOMPOSE_PATH:-$(pwd)}"
12+
cp -r "$GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES" "$OUTPUT_DIR/"
13+
14+
extractor_config_json_str=$(cat <<EOF
15+
{
16+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
17+
"custom_extractor_config": {
18+
"output_dir": "$OUTPUT_DIR/${MODEL_NAME}_decomposed",
19+
"split_positions": [8, 16, 32],
20+
"group_head_and_tail": true,
21+
"chain_style": true
22+
}
23+
}
24+
EOF
25+
)
26+
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
27+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name $MODEL_NAME --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
28+
29+
FILE_PATH=$GRAPH_NET_DECOMPOSE_PATH/decomposer
830
mkdir -p "$(dirname "$FILE_PATH/log.log")"
9-
10-
MODEL_PATH="./todo_works/range_decomposer_validator/test/resnet18"
31+
MODEL_PATH="$GRAPH_NET_DECOMPOSE_PATH/$MODEL_NAME"
1132

1233
python -m graph_net.torch.test_compiler \
1334
--model-path $MODEL_PATH \
1435
--compiler range_decomposer_validator \
1536
--device cuda > "$FILE_PATH/log.log" 2>&1
1637

17-
if [ $? -ne 0 ]; then
18-
echo "Error: decomposer_validator execution failed"
19-
echo "Please check the log file: $FILE_PATH/log.log"
20-
exit 1
21-
fi
22-
2338
python -m graph_net.log2json \
2439
--log-file "$FILE_PATH/log.log" \
2540
--output-dir "$FILE_PATH/JSON_results/"

graph_net/torch/backend/range_decomposer_validator_backend.py

Lines changed: 5 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,54 +12,10 @@ class ComposedModel(nn.Module):
1212
def __init__(self, graph: nn.Module, subgraph: List[nn.Module]):
1313
super().__init__()
1414
self.graph = graph
15-
self.subgraph = nn.ModuleList(subgraph)
16-
self.extract_node = []
17-
self.graph_model = torch.compile(self.graph, backend=self.extract_compiler)
18-
19-
def _serialize_arg(self, arg: Any) -> Any:
20-
if isinstance(arg, torch.fx.Node):
21-
return arg.name
22-
if isinstance(arg, (list, tuple)):
23-
return type(arg)(self._serialize_arg(elem) for elem in arg)
24-
if isinstance(arg, dict):
25-
return {
26-
self._serialize_arg(k): self._serialize_arg(v) for k, v in arg.items()
27-
}
28-
return arg
29-
30-
def _extract_operators_from_graph(
31-
self, gm: nn.Module, example_inputs: List[torch.Tensor] = None
32-
) -> List[Dict[str, Any]]:
33-
operator_list = []
34-
for node in gm.graph.nodes:
35-
if node.op in ("call_method", "call_function", "call_module"):
36-
operator_info = {
37-
"op_type": node.op,
38-
"target": node.target,
39-
"kwargs": self._serialize_arg(node.kwargs),
40-
}
41-
42-
if isinstance(node.target, Callable):
43-
try:
44-
operator_info["target_name"] = node.target.__name__
45-
except AttributeError:
46-
operator_info["target_name"] = str(node.target)
47-
else:
48-
operator_info["target_name"] = str(node.target)
49-
50-
operator_list.append(operator_info)
51-
52-
return operator_list
53-
54-
def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
55-
operator = self._extract_operators_from_graph(gm, inputs)
56-
self.extract_node.append(operator)
57-
return gm.forward
15+
self.subgraphs = nn.ModuleList(subgraph)
5816

5917
def forward(self, **kwargs):
60-
self.graph_model(**kwargs)
61-
graph_node_list = list(itertools.chain.from_iterable(self.extract_node))
62-
self.extract_node = []
18+
self.graph(**kwargs)
6319

6420
subgraph_intput = {
6521
key.replace("L", "l_l", 1): value
@@ -68,31 +24,11 @@ def forward(self, **kwargs):
6824
}
6925

7026
output = None
71-
for subgraph_model in self.subgraph:
72-
compiled_model = torch.compile(
73-
subgraph_model, backend=self.extract_compiler
74-
)
75-
27+
for subgraph in self.subgraphs:
7628
if output is None:
77-
output = compiled_model(**subgraph_intput)
29+
output = subgraph(**subgraph_intput)
7830
else:
79-
output = compiled_model(*output)
80-
81-
subgraph_node_list = list(itertools.chain.from_iterable(self.extract_node))
82-
self.extract_node = []
83-
84-
if graph_node_list != subgraph_node_list:
85-
diff_in_graph = [
86-
item for item in graph_node_list if item not in subgraph_node_list
87-
]
88-
diff_in_subgraph = [
89-
item for item in subgraph_node_list if item not in graph_node_list
90-
]
91-
92-
error_msg = f"Subgraph segmentation verification failed\n"
93-
error_msg += f"Nodes in graph but not in subgraph: {diff_in_graph}\n"
94-
error_msg += f"Nodes in subgraph but not in graph: {diff_in_subgraph}"
95-
raise ValueError(error_msg)
31+
output = subgraph(*output)
9632

9733
return output
9834

todo_works/range_decomposer_validator/__init__.py

Whitespace-only changes.

todo_works/range_decomposer_validator/test/resnet18/graph_hash.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

todo_works/range_decomposer_validator/test/resnet18/graph_net.json

Lines changed: 0 additions & 7 deletions
This file was deleted.

todo_works/range_decomposer_validator/test/resnet18/input_meta.py

Whitespace-only changes.

todo_works/range_decomposer_validator/test/resnet18/input_tensor_constraints.py

Whitespace-only changes.

0 commit comments

Comments
 (0)