Skip to content

Commit c522065

Browse files
committed
support --subgraph-ranges-json for typical_sequence_split_points.py
1 parent e5088b3 commit c522065

File tree

4 files changed

+111
-31
lines changed

4 files changed

+111
-31
lines changed

graph_net/tools/typical_sequence_decompose.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@ EOF
2424
)
2525

2626
python3 -m graph_net.torch.typical_sequence_split_points \
27-
--enable-resume \
2827
--model-list "$model_list" \
2928
--op-names-path-prefix "$DECOMPOSE_WORKSPACE" \
3029
--device "cuda" \
3130
--window-size 10 \
3231
--fold-policy default \
3332
--fold-times 10 \
33+
--min-seq-ops 4 \
34+
--max-seq-ops 16 \
35+
--subgraph-ranges-json "$DECOMPOSE_WORKSPACE/subgraph_ranges.json" \
3436
--output-json "$DECOMPOSE_WORKSPACE/split_results.json"
3537

3638
python3 -m graph_net.model_path_handler \
@@ -40,10 +42,11 @@ python3 -m graph_net.model_path_handler \
4042
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_decomposer.py",
4143
"handler_class_name": "RangeDecomposerExtractor",
4244
"handler_config": {
43-
"resume": true,
45+
"resume": false,
4446
"model_path_prefix": "$GRAPH_NET_ROOT",
4547
"output_dir": "$DECOMPOSE_WORKSPACE",
4648
"split_results_path": "$DECOMPOSE_WORKSPACE/split_results.json",
49+
"subgraph_ranges_path": "$DECOMPOSE_WORKSPACE/subgraph_ranges.json",
4750
"group_head_and_tail": true,
4851
"chain_style": false
4952
}

graph_net/torch/decompose_util.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
def convert_to_submodules_graph(
99
gm: torch.fx.GraphModule,
1010
split_positions: list[int],
11+
subgraph_ranges: list[(int, int)] = None,
1112
submodule_hook=None,
1213
submodule_name_prefix="extracted_submodule",
1314
chain_style=False,
@@ -26,21 +27,42 @@ def convert_to_submodules_graph(
2627
"output",
2728
}
2829
]
29-
split_positions = (
30-
[0, *split_positions, len(submodules_body_nodes)]
31-
if group_head_and_tail
32-
else split_positions
30+
31+
def get_range_idx2range_by_split_positions():
32+
nonlocal split_positions
33+
split_positions = (
34+
[0, *split_positions, len(submodules_body_nodes)]
35+
if group_head_and_tail
36+
else split_positions
37+
)
38+
split_positions = [
39+
max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
40+
]
41+
return [
42+
(start, end)
43+
for i in range(len(split_positions) - 1)
44+
for start in [split_positions[i]]
45+
for end in [split_positions[i + 1]]
46+
if end > start
47+
]
48+
49+
def get_range_idx2range_by_subgraph_ranges():
50+
assert submodules_body_nodes is not None
51+
num_nodes = len(submodules_body_nodes)
52+
for i in range(len(subgraph_ranges)):
53+
start, end = subgraph_ranges[i]
54+
assert start >= 0
55+
assert start < end
56+
assert end <= num_nodes
57+
# check disjoint
58+
assert i == 0 or start >= subgraph_ranges[i - 1][1], f"{i=}"
59+
return subgraph_ranges
60+
61+
range_idx2range = (
62+
get_range_idx2range_by_split_positions()
63+
if (chain_style or submodules_body_nodes is None)
64+
else get_range_idx2range_by_subgraph_ranges()
3365
)
34-
split_positions = [
35-
max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
36-
]
37-
range_idx2range = [
38-
(start, end)
39-
for i in range(len(split_positions) - 1)
40-
for start in [split_positions[i]]
41-
for end in [split_positions[i + 1]]
42-
if end > start
43-
]
4466
range_idx2submodule_body_nodes = [
4567
submodules_body_nodes[start:end] for start, end in range_idx2range
4668
]

graph_net/torch/graph_decomposer.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def _make_config(
170170
self,
171171
resume: bool = False,
172172
split_results_path=None,
173+
subgraph_ranges_path=None,
173174
group_head_and_tail=False,
174175
chain_style=False,
175176
output_dir="./tmp/naive_decomposer_dir",
@@ -180,13 +181,18 @@ def _make_config(
180181
):
181182
# Accept the config when at least one of the two inputs points at an
# existing ``.json`` file. A path left at ``None`` counts as missing:
# ``os.path.isfile(None)`` raises TypeError, which used to pre-empt the
# intended ValueError below.
has_valid_json_input = any(
    path is not None and os.path.isfile(path) and str(path).endswith(".json")
    for path in (split_results_path, subgraph_ranges_path)
)
if not has_valid_json_input:
    raise ValueError(
        "split_results_path or subgraph_ranges_path should be a valid JSON "
        f"file path, but got {split_results_path=}, {subgraph_ranges_path=}"
    )
187192
return {
188193
"resume": resume,
189194
"split_results_path": split_results_path,
195+
"subgraph_ranges_path": subgraph_ranges_path,
190196
"group_head_and_tail": group_head_and_tail,
191197
"chain_style": chain_style,
192198
"output_dir": output_dir,
@@ -195,8 +201,19 @@ def _make_config(
195201
"model_path_prefix": model_path_prefix,
196202
}
197203

198-
def _is_model_handled(self, rel_model_path, split_positions):
199-
num_subgraphs = len(split_positions) + 1
204+
def _is_model_handled(self, rel_model_path, split_positions, subgraph_ranges):
    """Check a model's decomposition against the expected subgraph count.

    The expected count is ``len(split_positions) + 1`` when the config
    enables ``chain_style`` and ``len(subgraph_ranges)`` otherwise.

    Returns whatever ``self._has_enough_subgraphs`` returns for
    ``rel_model_path`` and that count.
    """
    # Read the flag with the same default-tolerant lookup used when the
    # decomposition config is assembled in ``__call__``.
    if self.config.get("chain_style", False):
        num_subgraphs = len(split_positions) + 1
    else:
        num_subgraphs = len(subgraph_ranges)
    return self._has_enough_subgraphs(
        rel_model_path,
        num_subgraphs=num_subgraphs,
    )
215+
216+
def _has_enough_subgraphs(self, rel_model_path, num_subgraphs):
200217
decomposed_model_path = Path(self.config["output_dir"]) / rel_model_path
201218
num_decomposed = len(list(decomposed_model_path.rglob("model.py")))
202219
if num_decomposed > 0 and num_subgraphs != num_decomposed:
@@ -209,21 +226,26 @@ def __call__(self, rel_model_path):
209226
self.config["output_dir"]
210227
)
211228
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
212-
split_results = load_json(self.config["split_results_path"])
229+
split_positions_json = load_json(self.config["split_results_path"])
230+
subgraph_ranges_json = load_json(self.config["subgraph_ranges_path"])
213231
if (
214-
split_results[rel_model_path]["split_positions"] is None
215-
or len(split_results[rel_model_path]["split_positions"]) == 0
232+
split_positions_json[rel_model_path]["split_positions"] is None
233+
or len(split_positions_json[rel_model_path]["split_positions"]) == 0
234+
or subgraph_ranges_json[rel_model_path]["subgraph_ranges"] is None
235+
or len(subgraph_ranges_json[rel_model_path]["subgraph_ranges"]) == 0
216236
):
217237
sys.stderr.write(f"Error: {rel_model_path} has no split positions.\n")
218238
return
219-
split_positions = split_results[rel_model_path]["split_positions"]
239+
split_positions = split_positions_json[rel_model_path]["split_positions"]
240+
subgraph_ranges = subgraph_ranges_json[rel_model_path]["subgraph_ranges"]
220241
if self.config["resume"] and self._is_model_handled(
221-
rel_model_path, split_positions
242+
rel_model_path, split_positions, subgraph_ranges
222243
):
223244
return
224245
torch.cuda.empty_cache()
225246
config = {
226247
"split_positions": split_positions,
248+
"subgraph_ranges": subgraph_ranges,
227249
"group_head_and_tail": self.config.get("group_head_and_tail", False),
228250
"chain_style": self.config.get("chain_style", False),
229251
}

graph_net/torch/typical_sequence_split_points.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ def get_debug_sprintf():
201201
symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors))
202202
token2len = self._calculate_token_lengths(rp_expr, num_primitives, symbol_map)
203203

204-
results = {}
204+
split_positions_json = {}
205+
subgraph_ranges_json = {}
205206

206207
for i, (model_name, original_path) in enumerate(valid_models):
207208
if i >= len(rp_expr.body_rp_expr):
@@ -222,12 +223,14 @@ def get_debug_sprintf():
222223

223224
total_len = sum(token2len.get(t, 1) for t in seq_tokens)
224225

226+
subgraph_ranges = list(
227+
tree.FilterSubTreeRangeBySize(self.min_seq_ops, self.max_seq_ops)
228+
)
229+
225230
sorted_splits = sorted(
226231
set(
227232
split_pos
228-
for start, end in tree.FilterSubTreeRangeBySize(
229-
self.min_seq_ops, self.max_seq_ops
230-
)
233+
for start, end in subgraph_ranges
231234
for split_pos in (start, end)
232235
if end - start > 1
233236
)
@@ -237,13 +240,31 @@ def get_debug_sprintf():
237240
model_name, str(original_path), sorted_splits, total_len, full_model_ops
238241
)
239242

240-
results[str(original_path)] = {
243+
split_positions_json[str(original_path)] = {
241244
"model_name": model_name,
242245
"split_positions": sorted_splits,
243246
"total_length": total_len,
244247
}
245248

246-
return results
249+
# Keep only ranges spanning more than one op, ordered by (start, end).
candidate_ranges = sorted(
    {(start, end) for start, end in subgraph_ranges if end - start > 1}
)

# Make sorted_subgraph_ranges a disjoint set. Each candidate is compared
# with the last range actually kept; comparing with its predecessor in the
# unfiltered list lets a range through when that predecessor was itself
# dropped, e.g. (0, 16), (2, 6), (6, 10) would keep (0, 16) and (6, 10).
sorted_subgraph_ranges = []
for start, end in candidate_ranges:
    if not sorted_subgraph_ranges or start >= sorted_subgraph_ranges[-1][1]:
        sorted_subgraph_ranges.append((start, end))
260+
261+
subgraph_ranges_json[str(original_path)] = {
262+
"model_name": model_name,
263+
"subgraph_ranges": sorted_subgraph_ranges,
264+
"total_length": total_len,
265+
}
266+
267+
return split_positions_json, subgraph_ranges_json
247268

248269
def _print_analysis(self, name, path, splits, total_len, full_ops):
249270
print("=" * 60)
@@ -292,10 +313,16 @@ def main(args):
292313
min_seq_ops=args.min_seq_ops,
293314
max_seq_ops=args.max_seq_ops,
294315
)
295-
results = analyzer.analyze(args.op_names_path_prefix, args.model_list, args.device)
316+
split_positions_json, subgraph_ranges_json = analyzer.analyze(
317+
args.op_names_path_prefix, args.model_list, args.device
318+
)
296319
if args.output_json:
297320
with open(args.output_json, "w") as f:
298-
json.dump(results, f, indent=4)
321+
json.dump(split_positions_json, f, indent=4)
322+
print(f"{args.subgraph_ranges_json=}")
323+
if args.subgraph_ranges_json:
324+
with open(args.subgraph_ranges_json, "w") as f:
325+
json.dump(subgraph_ranges_json, f, indent=4)
299326

300327

301328
if __name__ == "__main__":
@@ -335,6 +362,12 @@ def main(args):
335362
default=0,
336363
help="How many times to fold tokens. If 0, then no folding is done.",
337364
)
365+
parser.add_argument(
366+
"--subgraph-ranges-json",
367+
type=str,
368+
default="subgraph_ranges.json",
369+
help="Path to save the subgraph ranges in JSON format.",
370+
)
338371
parser.add_argument(
339372
"--output-json",
340373
type=str,

0 commit comments

Comments
 (0)