Skip to content

Commit 92a2d40

Browse files
committed
fix
1 parent ebdc676 commit 92a2d40

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+6579
-440
lines changed

graph_net/test/chain_naive_graph_decomposer_test.sh

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
44
os.path.dirname(graph_net.__file__))")
55

66
# input model path
7-
MODEL_PATH_IN_SAMPLES=/timm/resnet18
7+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
8+
MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")
9+
OUTPUT_DIR="${NAIVE_DECOMPOSE_WORKSPACE:-$(pwd)/naive_decompose_workspace}"
10+
811
extractor_config_json_str=$(cat <<EOF
912
{
1013
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
1114
"custom_extractor_config": {
12-
"output_dir": "/tmp/chain_naive_decompose_workspace",
15+
"output_dir": "$OUTPUT_DIR/${MODEL_NAME}_decomposed",
1316
"split_positions": [8, 16, 32],
1417
"group_head_and_tail": true,
1518
"chain_style": true
@@ -20,4 +23,4 @@ EOF
2023
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
2124

2225
mkdir -p /tmp/naive_decompose_workspace
23-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
26+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name $MODEL_NAME --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG

graph_net/test/decomposer_validator_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#!/bin/bash
22

33
if [ -z "$GRAPH_NET_BENCHMARK_PATH" ]; then
4-
GRAPH_NET_BENCHMARK_PATH="$(pwd)"
4+
GRAPH_NET_BENCHMARK_PATH="$(pwd)/graphnet_benchmark"
55
fi
66

77
FILE_PATH=$GRAPH_NET_BENCHMARK_PATH/decomposer
88
mkdir -p "$(dirname "$FILE_PATH/log.log")"
99

10-
MODEL_PATH="./todo_works/range_decomposer_validator/test/simple_CNN"
10+
MODEL_PATH="./todo_works/range_decomposer_validator/test/resnet18"
1111

1212
python -m graph_net.torch.test_compiler \
1313
--model-path $MODEL_PATH \

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,12 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
44
os.path.dirname(graph_net.__file__))")
55

66
# input model path
7-
MODEL_PATH_IN_SAMPLES=/timm/resnet18
8-
MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")
9-
OUTPUT_DIR="${NAIVE_DECOMPOSE_WORKSPACE:-$(pwd)/naive_decompose_workspace}"
10-
7+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
118
extractor_config_json_str=$(cat <<EOF
129
{
1310
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
1411
"custom_extractor_config": {
15-
"output_dir": "$OUTPUT_DIR/${MODEL_NAME}_decomposed",
12+
"output_dir": "/tmp/naive_decompose_workspace",
1613
"split_positions": [8, 16, 32],
1714
"group_head_and_tail": true,
1815
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
@@ -23,4 +20,5 @@ EOF
2320
)
2421
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
2522

26-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name $MODEL_NAME --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
23+
mkdir -p /tmp/naive_decompose_workspace
24+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG

todo_works/range_decomposer_validator/range_decomposer_validator.py renamed to graph_net/torch/backend/range_decomposer_validator_backend.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,8 @@ def __init__(self, graph: nn.Module, subgraph: List[nn.Module]):
1313
super().__init__()
1414
self.graph = graph
1515
self.subgraph = nn.ModuleList(subgraph)
16-
self.subgraph_param_names = [
17-
list(inspect.signature(sm.forward).parameters.keys())
18-
for sm in self.subgraph
19-
]
2016
self.extract_node = []
17+
self.graph_model = torch.compile(self.graph, backend=self.extract_compiler)
2118

2219
def _serialize_arg(self, arg: Any) -> Any:
2320
if isinstance(arg, torch.fx.Node):
@@ -39,7 +36,6 @@ def _extract_operators_from_graph(
3936
operator_info = {
4037
"op_type": node.op,
4138
"target": node.target,
42-
"name": node.name,
4339
"kwargs": self._serialize_arg(node.kwargs),
4440
}
4541

@@ -61,30 +57,26 @@ def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor])
6157
return gm.forward
6258

6359
def forward(self, **kwargs):
64-
current_args = kwargs
65-
compiled_model = torch.compile(self.graph, backend=self.extract_compiler)
66-
compiled_model(**current_args)
60+
self.graph_model(**kwargs)
6761
graph_node_list = list(itertools.chain.from_iterable(self.extract_node))
6862
self.extract_node = []
6963

70-
for i, (sm, param_names) in enumerate(
71-
zip(self.subgraph, self.subgraph_param_names)
72-
):
73-
call_kwargs = {}
74-
if i > 0:
75-
first_param_name = param_names[0]
76-
call_kwargs[first_param_name] = current_args
77-
remaining_params = param_names[1:]
78-
else:
79-
remaining_params = param_names
64+
subgraph_intput = {
65+
key.replace("L", "l_l", 1): value
66+
for key, value in kwargs.items()
67+
if key.startswith("L")
68+
}
8069

81-
for name in remaining_params:
82-
if name in kwargs:
83-
call_kwargs[name] = kwargs[name]
70+
output = None
71+
for subgraph_model in self.subgraph:
72+
compiled_model = torch.compile(
73+
subgraph_model, backend=self.extract_compiler
74+
)
8475

85-
compiled_model = torch.compile(sm, backend=self.extract_compiler)
86-
outputs = compiled_model(**call_kwargs)
87-
current_args = outputs[0]
76+
if output is None:
77+
output = compiled_model(**subgraph_intput)
78+
else:
79+
output = compiled_model(*output)
8880

8981
subgraph_node_list = list(itertools.chain.from_iterable(self.extract_node))
9082
self.extract_node = []
@@ -102,7 +94,7 @@ def forward(self, **kwargs):
10294
error_msg += f"Nodes in subgraph but not in graph: {diff_in_subgraph}"
10395
raise ValueError(error_msg)
10496

105-
return (current_args,)
97+
return output
10698

10799

108100
class RangeDecomposerValidatorBackend:
@@ -119,7 +111,7 @@ def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
119111
return instance
120112

121113
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
122-
model_file_path = model.__class__.__file_path__
114+
model_file_path = model.__class__.__graph_net_file_path__
123115
model_dir = os.path.dirname(model_file_path)
124116
decomposed_parent_dir = model_dir + "_decomposed"
125117
subgraph_paths = []
@@ -132,7 +124,7 @@ def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
132124
f"[RangeDecomposerValidatorBackend] Found subgraphs: {[os.path.basename(p) for p in subgraph_paths]}"
133125
)
134126

135-
device = model.__class__.__device__
127+
device = model.__class__.__graph_net_device__
136128
graph_instances = self._load_model_instance(model_dir, device)
137129
subgraph_instances = []
138130

graph_net/torch/test_compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend
2424
from graph_net.torch.backend.nope_backend import NopeBackend
2525
from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend
26-
from todo_works.range_decomposer_validator.range_decomposer_validator import (
26+
from graph_net.torch.backend.range_decomposer_validator_backend import (
2727
RangeDecomposerValidatorBackend,
2828
)
2929
from graph_net.test_compiler_util import generate_allclose_configs
@@ -69,8 +69,8 @@ def load_class_from_file(
6969
exec(compiled_code, module.__dict__)
7070

7171
model_class = getattr(module, class_name, None)
72-
setattr(model_class, "__file_path__", file_path)
73-
setattr(model_class, "__device__", device)
72+
setattr(model_class, "__graph_net_file_path__", file_path)
73+
setattr(model_class, "__graph_net_device__", device)
7474
return model_class
7575

7676

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
248d46ebcf5bc02d3e72953ea430b5e18175b0419dbdbcd2479202497f58319d
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"framework": "torch",
3+
"num_devices_required": 1,
4+
"num_nodes_required": 1,
5+
"source": "timm",
6+
"heuristic_tag": "computer_vision"
7+
}

todo_works/range_decomposer_validator/test/simple_CNN/input_meta.py renamed to todo_works/range_decomposer_validator/test/resnet18/input_meta.py

File renamed without changes.

todo_works/range_decomposer_validator/test/simple_CNN/input_tensor_constraints.py renamed to todo_works/range_decomposer_validator/test/resnet18/input_tensor_constraints.py

File renamed without changes.

0 commit comments

Comments
 (0)