Skip to content

Commit cd9810a

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/GraphNet into fix
2 parents 4cb45e7 + 0404da5 commit cd9810a

File tree

3 files changed

+161
-13
lines changed

3 files changed

+161
-13
lines changed

graph_net/sample_pass/fusible_subgraph_ranges_generator.py

Lines changed: 135 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
import json
55
from itertools import groupby
6+
from dataclasses import dataclass
67

78

89
class FusibleSubgraphRangesGenerator(SamplePass, ResumableSamplePassMixin):
@@ -29,7 +30,9 @@ def sample_handled(self, rel_model_path: str) -> bool:
2930

3031
def resume(self, rel_model_path: str):
3132
analyzer = self._make_analyzer(rel_model_path)
32-
output_obj = analyzer.analyze()
33+
output_obj = {
34+
"subgraph_ranges": analyzer.analyze(),
35+
}
3336
self._save_output(rel_model_path, output_obj)
3437

3538
def _save_output(self, rel_model_path, output_obj):
@@ -82,27 +85,125 @@ def __init__(
8285
self.start_offset_in_original_graph = start_offset_in_original_graph
8386

8487
def analyze(self):
85-
num_kernels_and_num_ops_list: list[
86-
(int, list[int])
87-
] = self._make_num_kernels_and_num_ops_list()
88-
num_kernels_and_num_ops_list = sorted(
89-
num_kernels_and_num_ops_list, key=lambda pair: pair[0]
90-
)
91-
num_ops_lists = [
92-
sorted(num_ops_list)
88+
analysis_ctx = self._make_analysis_ctx()
89+
num_kernels_and_num_ops_list = analysis_ctx.num_kernels_and_num_ops_list
90+
# The tail num_kernels equals the head num_kernels for each num_ops_list
91+
naive_proposal_fused_num_ops_lists = [
92+
sorted(set(num_ops_list))
9393
for _, num_ops_list in num_kernels_and_num_ops_list
9494
if len(set(num_ops_list)) > 1
9595
]
96+
proposal_fused_num_ops_lists = self._merge_all_decreasing_num_ops_lists(
97+
analysis_ctx, naive_proposal_fused_num_ops_lists
98+
)
99+
return self._create_subgraph_ranges_from_proposal(
100+
analysis_ctx,
101+
proposal_fused_num_ops_lists,
102+
)
103+
104+
def _merge_all_decreasing_num_ops_lists(self, analysis_ctx, num_ops_lists):
    """Apply single-step merges until the number of lists stops shrinking.

    Each iteration delegates to ``_merge_one_decreasing_num_ops_lists`` and
    stops at the first call whose result has as many lists as its input.

    Args:
        analysis_ctx: Context forwarded unchanged to the single-step merge.
        num_ops_lists: Candidate lists of op counts.

    Returns:
        The result of the last single-step merge call.

    Raises:
        AssertionError: If a single-step merge returns more lists than it
            was given.
    """
    # Termination: the loop only continues when the list count strictly
    # decreases, and a count cannot drop below zero, so no iteration cap
    # is needed.
    while True:
        merged = self._merge_one_decreasing_num_ops_lists(
            analysis_ctx, num_ops_lists
        )
        assert len(merged) <= len(
            num_ops_lists
        ), f"{len(merged)=} {len(num_ops_lists)=}"
        if len(merged) == len(num_ops_lists):
            return merged
        num_ops_lists = merged
118+
119+
def _merge_one_decreasing_num_ops_lists(self, analysis_ctx, num_ops_lists):
    """Concatenate the first mergeable pair of adjacent lists, if there is one.

    The position to merge comes from
    ``_detect_mergable_decreasing_position``.

    Args:
        analysis_ctx: Context forwarded unchanged to the position detector.
        num_ops_lists: Candidate lists of op counts.

    Returns:
        ``num_ops_lists`` itself when no position is reported; otherwise a
        new list in which the lists at ``merge_pos`` and ``merge_pos + 1``
        are replaced by their concatenation.

    Raises:
        AssertionError: If the reported position has no successor list.
    """
    merge_pos = self._detect_mergable_decreasing_position(
        analysis_ctx, num_ops_lists
    )
    if merge_pos is None:
        return num_ops_lists
    # The position must refer to a list that has a right-hand neighbour.
    assert 0 <= merge_pos < len(num_ops_lists) - 1, f"{merge_pos=}"
    left = num_ops_lists[merge_pos]
    right = num_ops_lists[merge_pos + 1]
    return [
        *num_ops_lists[:merge_pos],
        [*left, *right],
        *num_ops_lists[merge_pos + 2 :],
    ]
132+
133+
def _detect_mergable_decreasing_position(self, analysis_ctx, num_ops_lists):
    """Return the index of the first list that can be merged with its successor.

    The pair ``(i, i + 1)`` is mergeable when the kernel count looked up for
    the last element of list ``i`` is greater than or equal to the kernel
    count looked up for the first element of list ``i + 1``.

    Args:
        analysis_ctx: Object providing ``num_kernels4num_ops(num_ops)``.
        num_ops_lists: Candidate lists of op counts.

    Returns:
        The index ``i`` of the first mergeable pair, or ``None`` if no
        adjacent pair qualifies (including when fewer than two lists exist).

    Raises:
        AssertionError: If a list inspected as the left side of a pair has
            fewer than two elements.
    """
    num_kernels_of = analysis_ctx.num_kernels4num_ops
    adjacent_pairs = zip(num_ops_lists, num_ops_lists[1:])
    for i, (cur_list, next_list) in enumerate(adjacent_pairs):
        assert len(cur_list) > 1
        # Compare the tail of the current list with the head of the next one.
        if num_kernels_of(cur_list[-1]) >= num_kernels_of(next_list[0]):
            return i
    return None
145+
146+
def _create_subgraph_ranges_from_proposal(
147+
self, analysis_ctx, proposal_fused_num_ops_lists
148+
):
149+
# filter valid num_ops_list
150+
151+
def is_a_range(int_list):
152+
assert len(int_list) > 1
153+
return (int_list[-1] + 1) - int_list[0] == len(int_list)
154+
155+
def have_any_increasing(num_ops_list: list[int]):
156+
for i, cur_num_ops in enumerate(num_ops_list):
157+
if i == 0:
158+
continue
159+
cur_num_kernels = analysis_ctx.num_kernels4num_ops(cur_num_ops)
160+
last_num_kernels = analysis_ctx.num_kernels4num_ops(num_ops_list[i - 1])
161+
if cur_num_kernels > last_num_kernels:
162+
return True
163+
return False
164+
165+
def head_eq_tail(num_ops_list: list[int]):
166+
return analysis_ctx.num_kernels4num_ops(
167+
num_ops_list[0]
168+
) == analysis_ctx.num_kernels4num_ops(num_ops_list[-1])
169+
170+
def head_gt_tail(num_ops_list: list[int]):
171+
return analysis_ctx.num_kernels4num_ops(
172+
num_ops_list[0]
173+
) > analysis_ctx.num_kernels4num_ops(num_ops_list[-1])
174+
175+
def valid_fused_ops(num_ops_list: list[int]):
176+
if head_gt_tail(num_ops_list):
177+
return True
178+
if head_eq_tail(num_ops_list):
179+
return not have_any_increasing(num_ops_list)
180+
return False
181+
182+
proposal_fused_num_ops_lists = [
183+
sorted(set(num_ops_list)) for num_ops_list in proposal_fused_num_ops_lists
184+
]
185+
num_ops_lists = [
186+
num_ops_list
187+
for num_ops_list in proposal_fused_num_ops_lists
188+
if len(num_ops_list) > 1
189+
if is_a_range(num_ops_list)
190+
if valid_fused_ops(num_ops_list)
191+
]
96192
fusible_subgraph_ranges = [
97193
(start, end)
98194
for num_ops_list in num_ops_lists
99195
for start in [num_ops_list[0] - 1]
100196
for end in [num_ops_list[-1]]
101197
]
198+
102199
# sorted by `start`
103-
fusible_subgraph_ranges = sorted(
104-
fusible_subgraph_ranges, key=lambda pair: pair[0]
105-
)
200+
def range_sort_key(pair):
201+
start, end = pair
202+
# smaller `start` first
203+
# bigger `end` first
204+
return (start, -end)
205+
206+
fusible_subgraph_ranges = sorted(fusible_subgraph_ranges, key=range_sort_key)
106207
# remove shadowed
107208
fusible_subgraph_ranges = [
108209
fusible_subgraph_ranges[i]
@@ -112,6 +213,15 @@ def analyze(self):
112213
]
113214
return fusible_subgraph_ranges
114215

216+
def _make_analysis_ctx(self):
    """Build the ``AnalysisContext`` used by the analysis helpers.

    Returns:
        An ``AnalysisContext`` populated from
        ``_make_num_kernels_and_num_ops_list()`` and
        ``_make_num_ops2num_kernels()``.
    """
    return AnalysisContext(
        num_kernels_and_num_ops_list=self._make_num_kernels_and_num_ops_list(),
        num_ops2num_kernels=self._make_num_ops2num_kernels(),
    )
221+
222+
def _make_num_ops2num_kernels(self):
    """Map each entry of ``num_subgraph_ops_list`` to its kernel count.

    Entries are paired positionally with ``num_subgraph_kernels_list``.

    Returns:
        A dict from op count to kernel count.
    """
    # NOTE: zip() stops at the shorter list, and a repeated op count keeps
    # the kernel count of its last occurrence.
    return dict(zip(self.num_subgraph_ops_list, self.num_subgraph_kernels_list))
224+
115225
def _make_num_kernels_and_num_ops_list(self):
116226
num_kernels_and_num_ops = zip(
117227
self.num_subgraph_kernels_list,
@@ -121,7 +231,10 @@ def _make_num_kernels_and_num_ops_list(self):
121231
def get_num_kernels(pair):
122232
return pair[0]
123233

124-
num_kernels_and_num_ops = sorted(num_kernels_and_num_ops, key=get_num_kernels)
234+
def get_num_ops(pair):
235+
return pair[1]
236+
237+
num_kernels_and_num_ops = sorted(num_kernels_and_num_ops, key=get_num_ops)
125238
grouped_num_kernels_and_num_ops = groupby(
126239
num_kernels_and_num_ops, key=get_num_kernels
127240
)
@@ -130,3 +243,12 @@ def get_num_kernels(pair):
130243
for num_kernels, group in grouped_num_kernels_and_num_ops
131244
]
132245
return num_kernels_and_num_ops_list
246+
247+
248+
@dataclass
class AnalysisContext:
    """Precomputed lookup data shared by the subgraph-range analysis helpers."""

    # Pairs of (num_kernels, list of num_ops values associated with it).
    num_kernels_and_num_ops_list: list[tuple[int, list[int]]]
    # Lookup table from a num_ops value to its num_kernels value.
    num_ops2num_kernels: dict[int, int]

    def num_kernels4num_ops(self, num_ops: int) -> int:
        """Return the kernel count recorded for ``num_ops``.

        Raises:
            KeyError: If ``num_ops`` is not present in ``num_ops2num_kernels``.
        """
        return self.num_ops2num_kernels[num_ops]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/bin/bash
# Runs graph_net.model_path_handler with the OpNamesExtractor handler over the
# cumsum_num_kernels sample list and writes the results under DECOMPOSE_PATH.

# Abort on the first failure (e.g. graph_net not importable) instead of
# continuing with an empty GRAPH_NET_ROOT; keep command tracing enabled.
set -euo pipefail
set -x

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
DECOMPOSE_PATH=/tmp/decompose_workspace
# Alternative output location:
# DECOMPOSE_PATH=$GRAPH_NET_ROOT/decompose_test_level5_100

mkdir -p "$DECOMPOSE_PATH"

model_list="$GRAPH_NET_ROOT/graph_net/test/dev_model_list/cumsum_num_kernels_sample_list.txt"

# The handler configuration is passed as single-line base64-encoded JSON.
# NOTE: `base64 -w 0` is the GNU coreutils spelling for "no line wrapping".
handler_config=$(base64 -w 0 <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/typical_sequence_split_points.py",
    "handler_class_name": "OpNamesExtractor",
    "handler_config": {
        "resume": true,
        "model_path_prefix": "$GRAPH_NET_ROOT",
        "output_dir": "$DECOMPOSE_PATH"
    }
}
EOF
)

# Quote every expansion so paths containing spaces survive word splitting.
python3 -m graph_net.model_path_handler \
    --model-path-list "$model_list" \
    --handler-config="$handler_config"
26+

samples/fcn_resnet50.zip

17.1 KB
Binary file not shown.

0 commit comments

Comments
 (0)