Skip to content

Commit 2a49397

Browse files
committed
merge
2 parents dc0a176 + cb5d2b2 commit 2a49397

File tree

4 files changed: +111 -31 lines changed

graph_net/tools/typical_sequence_decompose.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@ EOF
2424
)
2525

2626
python3 -m graph_net.torch.typical_sequence_split_points \
27-
--enable-resume \
2827
--model-list "$model_list" \
2928
--op-names-path-prefix "$DECOMPOSE_WORKSPACE" \
3029
--device "cuda" \
3130
--window-size 10 \
3231
--fold-policy default \
3332
--fold-times 10 \
33+
--min-seq-ops 4 \
34+
--max-seq-ops 16 \
35+
--subgraph-ranges-json "$DECOMPOSE_WORKSPACE/subgraph_ranges.json" \
3436
--output-json "$DECOMPOSE_WORKSPACE/split_results.json"
3537

3638
python3 -m graph_net.model_path_handler \
@@ -40,10 +42,11 @@ python3 -m graph_net.model_path_handler \
4042
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_decomposer.py",
4143
"handler_class_name": "RangeDecomposerExtractor",
4244
"handler_config": {
43-
"resume": true,
45+
"resume": false,
4446
"model_path_prefix": "$GRAPH_NET_ROOT",
4547
"output_dir": "$DECOMPOSE_WORKSPACE",
4648
"split_results_path": "$DECOMPOSE_WORKSPACE/split_results.json",
49+
"subgraph_ranges_path": "$DECOMPOSE_WORKSPACE/subgraph_ranges.json",
4750
"group_head_and_tail": true,
4851
"chain_style": false
4952
}

graph_net/torch/decompose_util.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def cuda_gc(enabled: bool = True):
2020
def convert_to_submodules_graph(
2121
gm: torch.fx.GraphModule,
2222
split_positions: list[int],
23+
subgraph_ranges: list[(int, int)] = None,
2324
submodule_hook=None,
2425
submodule_name_prefix="extracted_submodule",
2526
chain_style=False,
@@ -38,21 +39,42 @@ def convert_to_submodules_graph(
3839
"output",
3940
}
4041
]
41-
split_positions = (
42-
[0, *split_positions, len(submodules_body_nodes)]
43-
if group_head_and_tail
44-
else split_positions
42+
43+
def get_range_idx2range_by_split_positions():
44+
nonlocal split_positions
45+
split_positions = (
46+
[0, *split_positions, len(submodules_body_nodes)]
47+
if group_head_and_tail
48+
else split_positions
49+
)
50+
split_positions = [
51+
max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
52+
]
53+
return [
54+
(start, end)
55+
for i in range(len(split_positions) - 1)
56+
for start in [split_positions[i]]
57+
for end in [split_positions[i + 1]]
58+
if end > start
59+
]
60+
61+
def get_range_idx2range_by_subgraph_ranges():
62+
assert submodules_body_nodes is not None
63+
num_nodes = len(submodules_body_nodes)
64+
for i in range(len(subgraph_ranges)):
65+
start, end = subgraph_ranges[i]
66+
assert start >= 0
67+
assert start < end
68+
assert end <= num_nodes
69+
# check disjoint
70+
assert i == 0 or start >= subgraph_ranges[i - 1][1], f"{i=}"
71+
return subgraph_ranges
72+
73+
# Choose how the body nodes are partitioned into submodules. Explicit
# subgraph_ranges are used only when the caller supplied them and chain style
# is off; otherwise fall back to split_positions. The guard has to test
# subgraph_ranges (whose default is None): the ranges helper calls
# len(subgraph_ranges) and would raise TypeError for callers that never pass
# the new argument.
range_idx2range = (
    get_range_idx2range_by_split_positions()
    if (chain_style or subgraph_ranges is None)
    else get_range_idx2range_by_subgraph_ranges()
)
46-
split_positions = [
47-
max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
48-
]
49-
range_idx2range = [
50-
(start, end)
51-
for i in range(len(split_positions) - 1)
52-
for start in [split_positions[i]]
53-
for end in [split_positions[i + 1]]
54-
if end > start
55-
]
5678
range_idx2submodule_body_nodes = [
5779
submodules_body_nodes[start:end] for start, end in range_idx2range
5880
]

graph_net/torch/graph_decomposer.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def _make_config(
171171
self,
172172
resume: bool = False,
173173
split_results_path=None,
174+
subgraph_ranges_path=None,
174175
group_head_and_tail=False,
175176
chain_style=False,
176177
output_dir="./tmp/naive_decomposer_dir",
@@ -181,13 +182,18 @@ def _make_config(
181182
):
182183
if os.path.isfile(split_results_path) and split_results_path.endswith(".json"):
183184
pass
185+
elif os.path.isfile(subgraph_ranges_path) and subgraph_ranges_path.endswith(
186+
".json"
187+
):
188+
pass
184189
else:
185190
raise ValueError(
186191
f"split_results_path should be a valid JSON file path, but got {split_results_path=}"
187192
)
188193
return {
189194
"resume": resume,
190195
"split_results_path": split_results_path,
196+
"subgraph_ranges_path": subgraph_ranges_path,
191197
"group_head_and_tail": group_head_and_tail,
192198
"chain_style": chain_style,
193199
"output_dir": output_dir,
@@ -196,8 +202,19 @@ def _make_config(
196202
"model_path_prefix": model_path_prefix,
197203
}
198204

199-
def _is_model_handled(self, rel_model_path, split_positions):
200-
num_subgraphs = len(split_positions) + 1
205+
def _is_model_handled(self, rel_model_path, split_positions, subgraph_ranges):
206+
if self.config["chain_style"]:
207+
return self._has_enough_subgraphs(
208+
rel_model_path,
209+
num_subgraphs=len(split_positions) + 1,
210+
)
211+
else:
212+
return self._has_enough_subgraphs(
213+
rel_model_path,
214+
num_subgraphs=len(subgraph_ranges),
215+
)
216+
217+
def _has_enough_subgraphs(self, rel_model_path, num_subgraphs):
201218
decomposed_model_path = Path(self.config["output_dir"]) / rel_model_path
202219
num_decomposed = len(list(decomposed_model_path.rglob("model.py")))
203220
if num_decomposed > 0 and num_subgraphs != num_decomposed:
@@ -210,16 +227,20 @@ def __call__(self, rel_model_path):
210227
self.config["output_dir"]
211228
)
212229
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
213-
split_results = load_json(self.config["split_results_path"])
230+
split_positions_json = load_json(self.config["split_results_path"])
231+
subgraph_ranges_json = load_json(self.config["subgraph_ranges_path"])
214232
if (
215-
split_results[rel_model_path]["split_positions"] is None
216-
or len(split_results[rel_model_path]["split_positions"]) == 0
233+
split_positions_json[rel_model_path]["split_positions"] is None
234+
or len(split_positions_json[rel_model_path]["split_positions"]) == 0
235+
or subgraph_ranges_json[rel_model_path]["subgraph_ranges"] is None
236+
or len(subgraph_ranges_json[rel_model_path]["subgraph_ranges"]) == 0
217237
):
218238
sys.stderr.write(f"Error: {rel_model_path} has no split positions.\n")
219239
return
220-
split_positions = split_results[rel_model_path]["split_positions"]
240+
split_positions = split_positions_json[rel_model_path]["split_positions"]
241+
subgraph_ranges = subgraph_ranges_json[rel_model_path]["subgraph_ranges"]
221242
if self.config["resume"] and self._is_model_handled(
222-
rel_model_path, split_positions
243+
rel_model_path, split_positions, subgraph_ranges
223244
):
224245
return
225246

@@ -235,6 +256,7 @@ def __call__(self, rel_model_path):
235256
gm,
236257
submodule_hook=self.get_naive_decomposer_extractor(rel_model_path),
237258
split_positions=split_positions,
259+
subgraph_ranges=subgraph_ranges,
238260
group_head_and_tail=self.config.get("group_head_and_tail", False),
239261
chain_style=self.config.get("chain_style", False),
240262
)

graph_net/torch/typical_sequence_split_points.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ def get_debug_sprintf():
205205
symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors))
206206
token2len = self._calculate_token_lengths(rp_expr, num_primitives, symbol_map)
207207

208-
results = {}
208+
split_positions_json = {}
209+
subgraph_ranges_json = {}
209210

210211
for i, (model_name, original_path) in enumerate(valid_models):
211212
if i >= len(rp_expr.body_rp_expr):
@@ -226,12 +227,14 @@ def get_debug_sprintf():
226227

227228
total_len = sum(token2len.get(t, 1) for t in seq_tokens)
228229

230+
subgraph_ranges = list(
231+
tree.FilterSubTreeRangeBySize(self.min_seq_ops, self.max_seq_ops)
232+
)
233+
229234
sorted_splits = sorted(
230235
set(
231236
split_pos
232-
for start, end in tree.FilterSubTreeRangeBySize(
233-
self.min_seq_ops, self.max_seq_ops
234-
)
237+
for start, end in subgraph_ranges
235238
for split_pos in (start, end)
236239
if end - start > 1
237240
)
@@ -241,13 +244,31 @@ def get_debug_sprintf():
241244
model_name, str(original_path), sorted_splits, total_len, full_model_ops
242245
)
243246

244-
results[str(original_path)] = {
247+
split_positions_json[str(original_path)] = {
245248
"model_name": model_name,
246249
"split_positions": sorted_splits,
247250
"total_length": total_len,
248251
}
249252

250-
return results
253+
# Drop ranges of length <= 1 and duplicates, then order the remainder.
candidate_ranges = sorted(
    {
        (range_start, range_end)
        for range_start, range_end in subgraph_ranges
        if range_end - range_start > 1
    }
)

# Make sorted_subgraph_ranges a disjoint set: keep a range only if it starts
# at or after the end of the last range that was actually kept. Comparing
# with the previous candidate instead lets a range through whenever the
# candidate before it was itself dropped, e.g. (0, 10), (2, 5), (5, 8) would
# yield the overlapping pair (0, 10), (5, 8).
sorted_subgraph_ranges = []
for range_start, range_end in candidate_ranges:
    if (
        not sorted_subgraph_ranges
        or range_start >= sorted_subgraph_ranges[-1][1]
    ):
        sorted_subgraph_ranges.append((range_start, range_end))
264+
265+
subgraph_ranges_json[str(original_path)] = {
266+
"model_name": model_name,
267+
"subgraph_ranges": sorted_subgraph_ranges,
268+
"total_length": total_len,
269+
}
270+
271+
return split_positions_json, subgraph_ranges_json
251272

252273
def _print_analysis(self, name, path, splits, total_len, full_ops):
253274
print("=" * 60)
@@ -296,10 +317,16 @@ def main(args):
296317
min_seq_ops=args.min_seq_ops,
297318
max_seq_ops=args.max_seq_ops,
298319
)
299-
results = analyzer.analyze(args.op_names_path_prefix, args.model_list, args.device)
320+
split_positions_json, subgraph_ranges_json = analyzer.analyze(
321+
args.op_names_path_prefix, args.model_list, args.device
322+
)
300323
if args.output_json:
301324
with open(args.output_json, "w") as f:
302-
json.dump(results, f, indent=4)
325+
json.dump(split_positions_json, f, indent=4)
326+
print(f"{args.subgraph_ranges_json=}")
327+
if args.subgraph_ranges_json:
328+
with open(args.subgraph_ranges_json, "w") as f:
329+
json.dump(subgraph_ranges_json, f, indent=4)
303330

304331

305332
if __name__ == "__main__":
@@ -339,6 +366,12 @@ def main(args):
339366
default=0,
340367
help="How many times to fold tokens. If 0, then no folding is done.",
341368
)
369+
parser.add_argument(
370+
"--subgraph-ranges-json",
371+
type=str,
372+
default="subgraph_ranges.json",
373+
help="Path to save the subgraph ranges in JSON format.",
374+
)
342375
parser.add_argument(
343376
"--output-json",
344377
type=str,

0 commit comments

Comments
 (0)