Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,22 +357,31 @@ def prune_pytree_flatten_unflatten(
[tensors and specs] ==> (in-coming) pytree.unflatten ==> "preserved module"
"""

def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node, str]:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `nodes`.
for node in mod.graph.nodes:
if node.op == "call_module" and node.target == fqn:
return mod, node
return mod, node, fqn
assert "." in fqn, f"can't find {fqn} in the graph of {mod}"
curr, fqn = fqn.split(".", maxsplit=1)
mod = getattr(mod, curr)
return _get_graph_node(mod, fqn)

# remove tree_unflatten from the in_fqns (in-coming nodes)
for fqn in in_fqns:
submodule, node = _get_graph_node(module, fqn)
submodule, node, submod_name = _get_graph_node(module, fqn)

# kt_regroup node will have either one arg or one kwarg
assert len(node.args) == 1 or len(node.kwargs) == 1
use_args = len(node.args) == 1
assert use_args or len(node.kwargs) == 1

# Incase the kt_regroup module is partitioned to a submodule, we need
# to check the parent module for tree_unflatten node.
if use_args and cast(Node, node.args[0]).op == "placeholder":
submodule, node, _ = _get_graph_node(
module, fqn.replace("." + submod_name, "")
)
assert len(node.args) == 1

getitem_getitem = cast(
Node, node.args[0] if use_args else list(node.kwargs.values())[0]
Expand Down Expand Up @@ -403,7 +412,7 @@ def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:

# remove tree_flatten_spec from the out_fqns (out-going nodes)
for fqn in out_fqns:
submodule, node = _get_graph_node(module, fqn)
submodule, node, _ = _get_graph_node(module, fqn)
users = list(node.users.keys())
assert (
len(users) == 1
Expand Down
Loading