Skip to content

Commit 779e6f3

Browse files
pytorchbot and zou3519 authored
Add flag to fx.passes.split_module to normalize input names (pytorch#157793)
Add flag to fx.passes.split_module to normalize input names (pytorch#157733) This is useful for vLLM, which runs AOTAutograd directly on graphs after they have been split. I created a new flag for this instead of reusing `keep_original_node_name` (please let me know if you think I should reuse this). The reasoning is: - The names of the placeholder nodes are different from the targets of the placeholder nodes. The targets are the actual input names. - Backwards compatibility: this API has been out for ~4 years, it looks public, and it has extensive public use. For example, this change would actually be BC-breaking to vLLM (they rely on the subgraph input names being different at the moment). Test Plan: - new tests Pull Request resolved: pytorch#157733 Approved by: https://github.com/ezyang (cherry picked from commit b9afdd9) Co-authored-by: rzou <[email protected]>
1 parent e52eeb7 commit 779e6f3

File tree

3 files changed

+65
-12
lines changed

3 files changed

+65
-12
lines changed

test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.no
6464
torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
6565
torch.fx.passes.reinplace.reinplace(gm, *sample_args)
6666
torch.fx.passes.runtime_assert.insert_deferred_runtime_asserts(gm: torch.fx.graph_module.GraphModule, shape_env: Any, name: str, export: bool = False) -> None
67-
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False)
67+
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True)
6868
torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
6969
torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)
7070
torch.fx.proxy.Proxy.keys(self)

test/test_fx_experimental.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,46 @@ def mod_partition(node: Node):
791791

792792
self.assertEqual(orig_out, submodules_out)
793793

794+
def test_split_module_input_names(self):
795+
class Mod(torch.nn.Module):
796+
def forward(self, x, a0, a1, b0, b1, c0, c1):
797+
x = x + (a0 ** 2) + (a1 / 2)
798+
x = x + (b0 ** 2) + (b1 / 2)
799+
x = x + (c0 ** 2) + (c1 / 2)
800+
return x
801+
802+
mod = Mod()
803+
traced = torch.fx.symbolic_trace(mod)
804+
805+
seen = 0
806+
807+
def split(n):
808+
nonlocal seen
809+
result = seen // 4
810+
seen += 1
811+
return result
812+
813+
split = split_module(traced, mod, split, keep_original_input_name=False)
814+
815+
# All the submodules should take in the inputs in the same order.
816+
args = [torch.tensor(2.), torch.tensor(3.), torch.tensor(4.)]
817+
output0 = split.submod_0(*args)
818+
output1 = split.submod_1(*args)
819+
output2 = split.submod_2(*args)
820+
self.assertEqual(output0, output1)
821+
self.assertEqual(output1, output2)
822+
823+
# Each submodule should have normalized input names
824+
def check_ph(gm):
825+
nodes = list(gm.graph.nodes)
826+
self.assertEqual(nodes[0].target, "arg_0")
827+
self.assertEqual(nodes[1].target, "arg_1")
828+
self.assertEqual(nodes[2].target, "arg_2")
829+
830+
check_ph(split.submod_0)
831+
check_ph(split.submod_1)
832+
check_ph(split.submod_2)
833+
794834
def test_split_module_dead_code(self):
795835
class ModWithDeadCode(torch.nn.Module):
796836
def forward(self, x):

torch/fx/passes/split_module.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def split_module(
5858
qualname_map: Optional[dict[str, str]] = None,
5959
keep_original_order: Optional[bool] = False,
6060
keep_original_node_name: Optional[bool] = False,
61+
keep_original_input_name: bool = True,
6162
):
6263
"""
6364
Creates subgraphs out of main graph
@@ -76,7 +77,10 @@ def split_module(
7677
names in the original module.
7778
keep_original_order: Optional[bool]: keep the original order of the GraphModule
7879
or use the Topological order of the new constructed GraphModule
79-
80+
keep_original_node_name: Optional[bool]: If the partitioned graphs should
81+
have the same node names as the original graph.
82+
keep_original_input_name: bool: If the partitioned graphs should
83+
have the same input names as the original graph.
8084
8185
Returns:
8286
GraphModule: the module after split.
@@ -419,11 +423,28 @@ def instantiate_node_partition_mapping(node):
419423
for partition_name in sorted_partitions:
420424
partition = partitions[partition_name]
421425
new_inputs: dict[str, None] = {}
426+
427+
counter = 0
428+
422429
for inp in partition.inputs:
423430
orig_node = orig_nodes[inp]
424431
# We don't pass in get_attr nodes as inputs to the partition, but
425432
# instead set them as targets and use getattr within the module
426433

434+
def add_placeholder():
435+
if keep_original_input_name:
436+
name = inp
437+
else:
438+
nonlocal counter
439+
name = f"arg_{counter}"
440+
counter += 1
441+
placeholder = partition.graph.placeholder(
442+
name,
443+
type_expr=orig_nodes[inp].type,
444+
)
445+
new_inputs[inp] = None
446+
return placeholder
447+
427448
if orig_node.op == "get_attr":
428449
assert isinstance(orig_node.target, str)
429450

@@ -432,17 +453,9 @@ def instantiate_node_partition_mapping(node):
432453
placeholder = partition.graph.get_attr(orig_node.target)
433454
partition.targets[orig_node.target] = orig_attr
434455
else:
435-
placeholder = partition.graph.placeholder(
436-
inp,
437-
type_expr=orig_nodes[inp].type,
438-
)
439-
new_inputs[inp] = None
456+
placeholder = add_placeholder()
440457
else:
441-
placeholder = partition.graph.placeholder(
442-
inp,
443-
type_expr=orig_nodes[inp].type,
444-
)
445-
new_inputs[inp] = None
458+
placeholder = add_placeholder()
446459
placeholder.meta = orig_nodes[inp].meta.copy()
447460
partition.environment[orig_nodes[inp]] = placeholder
448461
partition.inputs = new_inputs

0 commit comments

Comments
 (0)