Skip to content

Commit be924b0

Browse files
authored
Resolve data flow mismatch between adjacent subgraphs in chain decomposition. (#424)
* init 'symbolic_dimension_reifier' field in graph_net.json * refactor model_path_handler * Resolve data flow mismatch between adjacent subgraphs in chain decomposition.
1 parent 7e02a5d commit be924b0

File tree

3 files changed

+48
-41
lines changed

3 files changed

+48
-41
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
samples/transformers-auto-model/dbmdz_electra-large-discriminator-finetuned-conll03-english

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ config_json_str=$(cat <<EOF
1111
"handler_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
1212
"handler_class_name": "NaiveDecomposerExtractor",
1313
"handler_config": {
14+
"model_path_prefix": "$GRAPH_NET_ROOT/../",
1415
"output_dir": "/tmp/naive_decompose_workspace",
1516
"split_positions": [8, 16, 32],
1617
"chain_style": true,
@@ -21,4 +22,5 @@ EOF
2122
)
2223
CONFIG=$(echo $config_json_str | base64 -w 0)
2324

24-
python3 -m graph_net.model_path_handler --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
25+
# python3 -m graph_net.model_path_handler --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
26+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG

graph_net/torch/decompose_util.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ def convert_to_submodules_graph(
1717
chain_style=True: decompose gm into g0 * g1 * g2 * g3
1818
"""
1919
gm = copy.deepcopy(gm)
20-
num_placeholders = len(
21-
[node for node in gm.graph.nodes if node.op == "placeholder"]
22-
)
2320
submodules_body_nodes = [
2421
node
2522
for node in gm.graph.nodes
@@ -207,6 +204,32 @@ def _get_submodule_inputs_and_outputs(
207204
start_node_idx: int,
208205
end_node_idx: int,
209206
chain_style=False,
207+
):
208+
if not chain_style:
209+
(
210+
minimal_input_nodes,
211+
minimal_output_nodes,
212+
) = _get_minimal_submodule_inputs_and_outputs(
213+
gm=gm, start_node_idx=start_node_idx, end_node_idx=end_node_idx
214+
)
215+
return minimal_input_nodes, minimal_output_nodes, []
216+
else:
217+
node_list = list(gm.graph.nodes)
218+
input_nodes, _ = _get_minimal_submodule_inputs_and_outputs(
219+
gm=gm, start_node_idx=start_node_idx, end_node_idx=len(node_list)
220+
)
221+
output_nodes, _ = _get_minimal_submodule_inputs_and_outputs(
222+
gm=gm, start_node_idx=end_node_idx, end_node_idx=len(node_list)
223+
)
224+
identity_nodes_set = set(input_nodes) & set(output_nodes)
225+
identity_nodes = [node for node in input_nodes if node in identity_nodes_set]
226+
return input_nodes, output_nodes, identity_nodes
227+
228+
229+
def _get_minimal_submodule_inputs_and_outputs(
230+
gm: torch.fx.GraphModule,
231+
start_node_idx: int,
232+
end_node_idx: int,
210233
):
211234
count_ctx = NodeProducedOrConsumedCountCtx(
212235
defaultdict(int),
@@ -215,33 +238,34 @@ def _get_submodule_inputs_and_outputs(
215238
)
216239
node_list = list(gm.graph.nodes)
217240

218-
def _hashable(obj):
219-
if isinstance(obj, slice):
220-
return ("__slice__", obj.start, obj.stop, obj.step)
221-
elif isinstance(obj, (list, tuple)):
222-
return tuple(_hashable(x) for x in obj)
241+
def get_args_node(arg):
242+
if isinstance(arg, torch.fx.Node):
243+
yield arg
244+
elif isinstance(arg, (tuple, list)):
245+
for x in arg:
246+
yield from get_args_node(x)
247+
elif isinstance(arg, slice):
248+
yield arg.start
249+
yield arg.stop
250+
yield arg.step
223251
else:
224-
return obj
252+
assert isinstance(arg, (int, bool, float, str, type(None))), f"{type(arg)=}"
225253

226-
def get_related_node(node):
254+
def get_args_node_and_self_node(node):
227255
for arg in node.args:
228-
if isinstance(arg, tuple):
229-
for x in arg:
230-
yield _hashable(x)
231-
else:
232-
yield _hashable(arg)
233-
yield _hashable(node)
256+
yield from get_args_node(arg)
257+
yield node
234258

235259
for node in node_list[0:start_node_idx]:
236-
for related_node in get_related_node(node):
260+
for related_node in get_args_node_and_self_node(node):
237261
count_ctx.node2before_input[related_node] += 1
238262

239263
for node in node_list[start_node_idx:end_node_idx]:
240-
for related_node in get_related_node(node):
264+
for related_node in get_args_node_and_self_node(node):
241265
count_ctx.node2body[related_node] += 1
242266

243267
for node in node_list[end_node_idx:]:
244-
for related_node in get_related_node(node):
268+
for related_node in get_args_node_and_self_node(node):
245269
count_ctx.node2after_output[related_node] += 1
246270

247271
input_nodes = [
@@ -257,24 +281,4 @@ def get_related_node(node):
257281
if count_ctx.node2body[node] > 0
258282
if count_ctx.node2after_output[node] > 0
259283
]
260-
if not chain_style:
261-
identity_nodes = []
262-
else:
263-
identity_nodes = [
264-
node
265-
for node in node_list
266-
if count_ctx.node2before_input[node] > 0
267-
if count_ctx.node2body[node] == 0
268-
if count_ctx.node2after_output[node] > 0
269-
]
270-
input_nodes_set = set(input_nodes)
271-
input_nodes = [
272-
*input_nodes,
273-
*[node for node in identity_nodes if node not in input_nodes_set],
274-
]
275-
output_nodes_set = set(output_nodes)
276-
output_nodes = [
277-
*output_nodes,
278-
*[node for node in identity_nodes if node not in output_nodes_set],
279-
]
280-
return input_nodes, output_nodes, identity_nodes
284+
return input_nodes, output_nodes

0 commit comments

Comments
 (0)