Skip to content

Commit 3d24089

Browse files
committed
Update the process of typical_sequence_decomposer
1 parent 4e01709 commit 3d24089

File tree

3 files changed

+145
-42
lines changed

3 files changed

+145
-42
lines changed
Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,57 @@
#!/bin/bash
# Decompose models at typical-sequence split points, validate the decomposed
# graphs, and plot the ESt benchmark results.

# Repository root, derived from the installed graph_net package location.
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
DECOMPOSE_PATH=/tmp/decompose_workspace

mkdir -p "$DECOMPOSE_PATH"

model_list="$GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt"

# Step 1: compute split points for every model in the list.
python3 -m graph_net.torch.typical_sequence_split_points \
    --model-list "$model_list" \
    --device "cuda" \
    --window-size 10 \
    --fold-policy default \
    --fold-times 10 \
    --output-json "$DECOMPOSE_PATH/split_results.json"

# Step 2: decompose each model using the split results.
decompose_config_json_str=$(cat <<EOF
{
  "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/naive_graph_decomposer.py",
  "handler_class_name": "RangeDecomposerExtractor",
  "handler_config": {
    "model_path_prefix": "$GRAPH_NET_ROOT",
    "output_dir": "$DECOMPOSE_PATH",
    "split_results_path": "$DECOMPOSE_PATH/split_results.json",
    "group_head_and_tail": true,
    "chain_style": true
  }
}
EOF
)
# printf + quoting preserves the JSON byte-for-byte; an unquoted `echo $var`
# would collapse whitespace and split words before base64-encoding.
DECOMPOSE_CONFIG=$(printf '%s' "$decompose_config_json_str" | base64 -w 0)

python3 -m graph_net.model_path_handler \
    --model-path-list "$model_list" \
    --handler-config="$DECOMPOSE_CONFIG" \
    --use-subprocess

# Step 3: validate the decomposed models.
test_compiler_config_json_str=$(cat <<EOF
{
  "decomposed_root": "$DECOMPOSE_PATH"
}
EOF
)
TEST_COMPILER_CONFIG=$(printf '%s' "$test_compiler_config_json_str" | base64 -w 0)

python3 -m graph_net.torch.test_compiler \
    --allow-list "$model_list" \
    --compiler range_decomposer_validator \
    --device cuda \
    --config "$TEST_COMPILER_CONFIG" \
    --model-path-prefix "$GRAPH_NET_ROOT" \
    > "$DECOMPOSE_PATH/validation.log" 2>&1

# Step 4: plot ESt curves from the validation log.
python3 -m graph_net.plot_ESt \
    --benchmark-path "$DECOMPOSE_PATH/validation.log" \
    --output-dir "$DECOMPOSE_PATH"

graph_net/torch/naive_graph_decomposer.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
import os
22
import torch
3+
import json
34
from graph_net.torch.decompose_util import convert_to_submodules_graph
45
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
56
import graph_net.imp_util as imp_util
67
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
78
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
89

910

def load_json(file_path):
    """Parse the JSON document at *file_path* and return the decoded object."""
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
15+
16+
1017
class GraphExtractor:
1118
"""
1219
Used by graph_net.torch.run_model
@@ -151,6 +158,83 @@ def fn(submodule, seq_no):
151158
return fn
152159

153160

class RangeDecomposerExtractor:
    """
    Used by graph_net.model_path_handler

    Splits a model's FX graph into submodules at the split positions recorded
    in a split-results JSON file (keyed by model basename), then runs the
    rewritten graph once so each submodule is extracted via the hook.
    """

    def __init__(self, config: dict = None):
        """Build the extractor from a raw config dict (validated by _make_config)."""
        if config is None:
            config = {}
        self.config = self._make_config(**config)

    def _make_config(
        self,
        split_results_path=None,
        group_head_and_tail=False,
        chain_style=False,
        output_dir="./tmp/naive_decomposer_dir",
        filter_path=None,
        filter_config=None,
        post_extract_process_path=None,
        post_extract_process_class_name=None,
        post_extract_process_config=None,
        model_path_prefix="",
        **kwargs,
    ):
        """Validate and normalize the handler config into a plain dict.

        Raises:
            ValueError: if split_results_path is not an existing ``.json`` file
                (including the default ``None``).
        """
        # Check the type/suffix before touching the filesystem: the previous
        # order called os.path.isfile(None) for the default config, raising
        # TypeError instead of the intended ValueError.
        if (
            not isinstance(split_results_path, str)
            or not split_results_path.endswith(".json")
            or not os.path.isfile(split_results_path)
        ):
            raise ValueError(
                f"split_results_path should be a valid JSON file path, but got {split_results_path=}"
            )
        if post_extract_process_config is None:
            post_extract_process_config = {}
        return {
            "split_results_path": split_results_path,
            "group_head_and_tail": group_head_and_tail,
            "chain_style": chain_style,
            "output_dir": output_dir,
            "filter_path": filter_path,
            "filter_config": filter_config if filter_config is not None else {},
            "post_extract_process_path": post_extract_process_path,
            "post_extract_process_class_name": post_extract_process_class_name,
            "post_extract_process_config": post_extract_process_config,
            "model_path_prefix": model_path_prefix,
        }

    def __call__(self, rel_model_path):
        """Decompose the model at *rel_model_path* (relative to model_path_prefix).

        Looks up the model's split points by basename in the split-results
        JSON, rewrites the parsed graph into chained submodules, and executes
        the rewritten graph once to trigger extraction.
        """
        model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
        split_results = load_json(self.config["split_results_path"])
        # NOTE(review): split_results is assumed to map model basenames to
        # {"split_points": [...]} — verify against typical_sequence_split_points output.
        split_positions = split_results[os.path.basename(rel_model_path)][
            "split_points"
        ]
        config = {
            "split_positions": split_positions,
            "group_head_and_tail": self.config.get("group_head_and_tail", False),
            "chain_style": self.config.get("chain_style", False),
        }
        module, inputs = get_torch_module_and_inputs(model_path)
        gm = parse_sole_graph_module(module, inputs)
        rewritten_gm: torch.fx.GraphModule = convert_to_submodules_graph(
            gm,
            submodule_hook=self.get_naive_decomposer_extractor(model_path),
            **config,
        )
        # Run once: extraction happens as a side effect of submodule execution.
        rewritten_gm(*inputs)

    def get_naive_decomposer_extractor(self, model_path):
        """Return a hook that wraps each submodule in an extractor module."""

        def fn(submodule, seq_no):
            return NaiveDecomposerExtractorModule(
                config=self.config,
                parent_graph_name=os.path.basename(model_path),
                submodule=submodule,
                seq_no=seq_no,
            )

        return fn
236+
237+
154238
class NaiveDecomposerExtractorModule(torch.nn.Module):
155239
def __init__(
156240
self,

graph_net/torch/typical_sequence_split_points.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,12 @@ def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor
8282

8383

8484
class SplitAnalyzer:
85-
def __init__(self, window_size: int = 10):
85+
def __init__(
86+
self, window_size: int = 10, fold_policy: str = "default", fold_times: int = 0
87+
):
8688
self.window_size = window_size
89+
self.fold_policy = fold_policy
90+
self.fold_times = fold_times
8791

8892
def _resolve_token_to_ops(
8993
self, tid, num_primitives, token_id2primitive_id, symbol_map
@@ -169,7 +173,9 @@ def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
169173
return {}
170174

171175
rp_parser = RpExprParser(
172-
window_size=self.window_size, fold_policy="default", fold_times=0
176+
window_size=self.window_size,
177+
fold_policy=self.fold_policy,
178+
fold_times=self.fold_times,
173179
)
174180
rp_expr, token_id2primitive_id = rp_parser(inputs_seqs)
175181
rp_expr.try_unwrap_body_of_sole_symbol_token()
@@ -253,7 +259,11 @@ def _print_analysis(self, name, path, splits, total_len, full_ops):
253259

254260

255261
def main(args):
256-
analyzer = SplitAnalyzer(window_size=args.window_size)
262+
analyzer = SplitAnalyzer(
263+
window_size=args.window_size,
264+
fold_policy=args.fold_policy,
265+
fold_times=args.fold_times,
266+
)
257267
results = analyzer.analyze(args.model_list, args.device)
258268
if args.output_json:
259269
with open(args.output_json, "w") as f:
@@ -279,6 +289,18 @@ def main(args):
279289
parser.add_argument(
280290
"--window-size", type=int, default=10, help="Window size for RP Parser."
281291
)
292+
parser.add_argument(
293+
"--fold-policy",
294+
type=str,
295+
default="default",
296+
help="Policy for split analysis, one of 'default' or 'longest'",
297+
)
298+
parser.add_argument(
299+
"--fold-times",
300+
type=int,
301+
default=0,
302+
help="How many times to fold tokens. If 0, then no folding is done.",
303+
)
282304
parser.add_argument(
283305
"--output-json",
284306
type=str,

0 commit comments

Comments
 (0)