diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index 05bbfe9fb..ba133f900 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -357,11 +357,11 @@ def prune_pytree_flatten_unflatten(
     [tensors and specs] ==> (in-coming) pytree.unflatten ==> "preserved module"
     """
 
-    def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
+    def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node, str]:
         # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `nodes`.
         for node in mod.graph.nodes:
             if node.op == "call_module" and node.target == fqn:
-                return mod, node
+                return mod, node, fqn
         assert "." in fqn, f"can't find {fqn} in the graph of {mod}"
         curr, fqn = fqn.split(".", maxsplit=1)
         mod = getattr(mod, curr)
@@ -369,10 +369,19 @@ def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
 
     # remove tree_unflatten from the in_fqns (in-coming nodes)
     for fqn in in_fqns:
-        submodule, node = _get_graph_node(module, fqn)
+        submodule, node, submod_name = _get_graph_node(module, fqn)
+
         # kt_regroup node will have either one arg or one kwarg
-        assert len(node.args) == 1 or len(node.kwargs) == 1
         use_args = len(node.args) == 1
+        assert use_args or len(node.kwargs) == 1
+
+        # In case the kt_regroup module is partitioned into a submodule, we need
+        # to check the parent module for the tree_unflatten node.
+        if use_args and cast(Node, node.args[0]).op == "placeholder":
+            submodule, node, _ = _get_graph_node(
+                module, fqn.replace("." + submod_name, "")
+            )
+            assert len(node.args) == 1
         getitem_getitem = cast(
             Node, node.args[0] if use_args else list(node.kwargs.values())[0]
         )
@@ -403,7 +412,7 @@ def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
 
     # remove tree_flatten_spec from the out_fqns (out-going nodes)
     for fqn in out_fqns:
-        submodule, node = _get_graph_node(module, fqn)
+        submodule, node, _ = _get_graph_node(module, fqn)
        users = list(node.users.keys())
         assert (
             len(users) == 1
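
Context for the second hunk, as a sketch rather than part of the patch: when the preserved kt_regroup call is partitioned into its own GraphModule, _get_graph_node lands on the call_module node inside the submodule's graph, where the node's single argument is a graph placeholder rather than the tree_unflatten output. The new branch therefore strips the leaf component from the FQN and repeats the lookup one level up, in the parent graph. The Python below is a minimal, self-contained illustration of that two-level lookup; the module names ("sub", "leaf") and the standalone get_graph_node helper are assumptions for the example, not torchrec APIs.

# Minimal sketch (not the torchrec pipeline): two hand-built fx GraphModules,
# where the inner graph's call_module node is fed by a placeholder, so the
# producer of its input only exists in the parent graph.
from typing import Tuple

from torch import nn
from torch.fx import Graph, GraphModule, Node


def get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node, str]:
    # Standalone stand-in for _get_graph_node: descend through the fqn until
    # the owning graph is found; return owner, node, and the leaf name.
    for node in mod.graph.nodes:
        if node.op == "call_module" and node.target == fqn:
            return mod, node, fqn
    assert "." in fqn, f"can't find {fqn} in the graph of {mod}"
    curr, fqn = fqn.split(".", maxsplit=1)
    return get_graph_node(getattr(mod, curr), fqn)


# Inner GraphModule: calls the leaf module "leaf"; the leaf's only input
# is a placeholder of the inner graph.
inner_graph = Graph()
x = inner_graph.placeholder("x")
inner_graph.output(inner_graph.call_module("leaf", args=(x,)))
inner_gm = GraphModule({"leaf": nn.ReLU()}, inner_graph)

# Outer GraphModule: calls the partitioned submodule "sub".
outer_graph = Graph()
a = outer_graph.placeholder("x")
outer_graph.output(outer_graph.call_module("sub", args=(a,)))
outer_gm = GraphModule({"sub": inner_gm}, outer_graph)

# The lookup lands inside the inner graph, where the node's arg is a
# placeholder, i.e. the producer of its input lives one level up.
submodule, node, submod_name = get_graph_node(outer_gm, "sub.leaf")
assert node.args[0].op == "placeholder"

# Retry in the parent, as the patch does, by stripping the leaf component.
parent_fqn = "sub.leaf".replace("." + submod_name, "")  # -> "sub"
_, parent_node, _ = get_graph_node(outer_gm, parent_fqn)
assert parent_node.target == "sub"

In the real pass, the parent-level node's argument is the tree_unflatten output whose getitem chain then gets rewired; the sketch only demonstrates why the three-tuple return (module, node, leaf name) is needed to restart the search in the parent graph.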