Skip to content

Commit 348b3e0

Browse files
authored
Merge branch 'develop' into one_hot119
2 parents 4e56dbb + 7268252 commit 348b3e0

30 files changed

+156
-495
lines changed

graph_net/test/chain_naive_graph_decomposer_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ EOF
2020
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
2121

2222
mkdir -p /tmp/naive_decompose_workspace
23-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
23+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(graph_net.__file__))")
4+
5+
if [ -z "$GRAPH_NET_DECOMPOSE_PATH" ]; then
6+
GRAPH_NET_DECOMPOSE_PATH="$(pwd)/graphnet_decompose"
7+
fi
8+
9+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
10+
MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")
11+
OUTPUT_DIR="${GRAPH_NET_DECOMPOSE_PATH:-$(pwd)}"
12+
cp -r "$GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES" "$OUTPUT_DIR/"
13+
14+
extractor_config_json_str=$(cat <<EOF
15+
{
16+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
17+
"custom_extractor_config": {
18+
"output_dir": "$OUTPUT_DIR/${MODEL_NAME}_decomposed",
19+
"split_positions": [8, 16, 32],
20+
"group_head_and_tail": true,
21+
"chain_style": true
22+
}
23+
}
24+
EOF
25+
)
26+
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
27+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name $MODEL_NAME --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
28+
29+
FILE_PATH=$GRAPH_NET_DECOMPOSE_PATH/decomposer
30+
mkdir -p "$(dirname "$FILE_PATH/log.log")"
31+
MODEL_PATH="$GRAPH_NET_DECOMPOSE_PATH/$MODEL_NAME"
32+
33+
python -m graph_net.torch.test_compiler \
34+
--model-path $MODEL_PATH \
35+
--compiler range_decomposer_validator \
36+
--device cuda > "$FILE_PATH/log.log" 2>&1
37+
38+
python -m graph_net.log2json \
39+
--log-file "$FILE_PATH/log.log" \
40+
--output-dir "$FILE_PATH/JSON_results/"
41+
42+
python -m graph_net.plot_ESt \
43+
--benchmark-path "$FILE_PATH/JSON_results/" \
44+
--output-dir "$FILE_PATH"
45+
46+
echo "=================================================="
47+
echo "Results saved in: $FILE_PATH/ES_result.png"
48+
echo ""
49+
echo "IMPORTANT: Please verify if the curve in ES_result.png is a straight line"
50+
echo "If the curve is NOT a straight line, please check the log file: $FILE_PATH/log.log"
51+
echo "=================================================="

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ EOF
2121
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
2222

2323
mkdir -p /tmp/naive_decompose_workspace
24-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
24+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import torch
2+
import torch.nn as nn
3+
import os
4+
import sys
5+
import inspect
6+
import importlib.util
7+
import itertools
8+
from typing import List, Tuple, Dict, Any, Callable
9+
10+
11+
class ComposedModel(nn.Module):
12+
def __init__(self, subgraph: List[nn.Module]):
13+
super().__init__()
14+
self.subgraphs = nn.ModuleList(subgraph)
15+
16+
def forward(self, **kwargs):
17+
subgraph_intput = {
18+
key.replace("L", "l_l", 1): value
19+
for key, value in kwargs.items()
20+
if key.startswith("L")
21+
}
22+
23+
output = None
24+
for subgraph in self.subgraphs:
25+
if output is None:
26+
output = subgraph(**subgraph_intput)
27+
else:
28+
output = subgraph(*output)
29+
30+
return output
31+
32+
33+
class RangeDecomposerValidatorBackend:
34+
def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
35+
class_name = "GraphModule"
36+
model_file = os.path.join(path, "model.py")
37+
38+
spec = importlib.util.spec_from_file_location(class_name, model_file)
39+
module = importlib.util.module_from_spec(spec)
40+
spec.loader.exec_module(module)
41+
42+
ModelClass = getattr(module, class_name)
43+
instance = ModelClass().to(device)
44+
return instance
45+
46+
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
47+
model_file_path = model.__class__.__graph_net_file_path__
48+
model_dir = os.path.dirname(model_file_path)
49+
decomposed_parent_dir = model_dir + "_decomposed"
50+
subgraph_paths = []
51+
for name in sorted(os.listdir(decomposed_parent_dir)):
52+
full_path = os.path.join(decomposed_parent_dir, name)
53+
if os.path.isdir(full_path) and name[-1].isdigit():
54+
subgraph_paths.append(full_path)
55+
56+
print(
57+
f"[RangeDecomposerValidatorBackend] Found subgraphs: {[os.path.basename(p) for p in subgraph_paths]}"
58+
)
59+
60+
device = model.__class__.__graph_net_device__
61+
subgraph_instances = []
62+
63+
for path in subgraph_paths:
64+
instance = self._load_model_instance(path, device)
65+
subgraph_instances.append(instance)
66+
dir_name = os.path.basename(path)
67+
print(
68+
f"[RangeDecomposerValidatorBackend] Loaded and instantiated '{dir_name}'"
69+
)
70+
71+
composed_model = ComposedModel(subgraph_instances)
72+
return composed_model.eval()
73+
74+
def synchronize(self):
75+
if torch.cuda.is_available():
76+
torch.cuda.synchronize()

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,28 @@ def _impl_unstable_to_stable_one_hot(self, gm):
204204

205205
return gm
206206

207-
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)
207+
def _impl_unstable_to_stable_set_grad_enabled(self, gm):
208+
"""
209+
Convert torch._C._set_grad_enabled and torch._C.set_grad_enabled to torch.set_grad_enabled
210+
"""
211+
212+
def replace_in_graph(graph_mod):
213+
for node in graph_mod.graph.nodes:
214+
if node.op == "call_function":
215+
if "set_grad_enabled" in str(node.target):
216+
node.target = torch.set_grad_enabled
217+
graph_mod.recompile()
218+
219+
modules = [gm]
220+
modules += [
221+
m
222+
for _, m in gm.named_modules()
223+
if isinstance(m, torch.fx.GraphModule) and m is not gm
224+
]
225+
for m in modules:
226+
replace_in_graph(m)
227+
228+
return gm
208229

209230
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)
210231

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
143143
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
144144
(r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("),
145145
(r"torch\._C\._nn\.one_hot\(", "torch.nn.functional.one_hot("),
146-
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)
146+
(r"torch\._C\._set_grad_enabled\(", "torch.set_grad_enabled("),
147+
(r"torch\._C\.set_grad_enabled\(", "torch.set_grad_enabled("),
147148
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)
148149
# replace this line with modification code for task 123 (torch._C._nn.pad)
149150
# replace this line with modification code for task 125 (torch._C._nn.gelu)

graph_net/torch/test_compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend
2424
from graph_net.torch.backend.nope_backend import NopeBackend
2525
from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend
26-
from todo_works.range_decomposer_validator.range_decomposer_validator import (
26+
from graph_net.torch.backend.range_decomposer_validator_backend import (
2727
RangeDecomposerValidatorBackend,
2828
)
2929
from graph_net.test_compiler_util import generate_allclose_configs
@@ -69,6 +69,8 @@ def load_class_from_file(
6969
exec(compiled_code, module.__dict__)
7070

7171
model_class = getattr(module, class_name, None)
72+
setattr(model_class, "__graph_net_file_path__", file_path)
73+
setattr(model_class, "__graph_net_device__", device)
7274
return model_class
7375

7476

todo_works/range_decomposer_validator/__init__.py

Whitespace-only changes.

todo_works/range_decomposer_validator/range_decomposer_validator.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

todo_works/range_decomposer_validator/test/simple_CNN/graph_hash.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)