Skip to content

Commit 7d1b976

Browse files
zhxchen17pytorchmergebot
authored andcommitted
[export] Make dict_keys_getitem tracable. (pytorch#166776)
Summary: dict_keys_getitem can show up in the bytecode but it's using dict.keys() which is not fx tracable. fx.wrap should make it as a standalone function in the graph to be invoked later with real inputs. Test Plan: pytest test/export/test_experimental.py Pull Request resolved: pytorch#166776 Approved by: https://github.com/jamesjwu ghstack dependencies: pytorch#166775
1 parent 27cfdd9 commit 7d1b976

File tree

4 files changed

+47
-1
lines changed

4 files changed

+47
-1
lines changed

test/export/test_experimental.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,44 @@ def make_inputs(b: int):
541541
self.assertIsNotNone(gm.meta["tracing_context"].fake_mode)
542542
self.assertEqual(len(gm.meta["tracing_context"].tensor_to_context), 1)
543543

544+
def test_dynamo_graph_capture_dict_keys_getitem(self):
545+
class Module(torch.nn.Module):
546+
def forward(self, x):
547+
return x * 2
548+
549+
foo = Module()
550+
551+
class BlockMask:
552+
def __init__(self, d):
553+
self.d = d
554+
555+
block_mask = BlockMask(torch.randn(4))
556+
557+
def pre_hook_function(m, input):
558+
block_mask.d = input[0] + 1
559+
return input # Return a tuple of modified inputs
560+
561+
foo.register_forward_pre_hook(pre_hook_function)
562+
563+
def make_inputs():
564+
return (torch.randn(4),)
565+
566+
trace_inputs = make_inputs()
567+
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
568+
test_inputs = make_inputs()
569+
self.assertExpectedInline(
570+
gm.code.strip("\r\n "),
571+
"""\
572+
def forward(self, args_0):
573+
_tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
574+
L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
575+
l_args_0_ = L_args_0_
576+
add = l_args_0_ + 1
577+
mul = l_args_0_ * 2; l_args_0_ = None
578+
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""",
579+
)
580+
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
581+
544582
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
545583
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
546584
class DummyOp(torch.autograd.Function):

torch/_dynamo/convert_frame.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,10 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
10001000
import inspect
10011001

10021002
if isinstance(mod, torch.nn.Module):
1003-
mod = mod.forward
1003+
if len(mod._forward_pre_hooks) == 0 and len(mod._forward_hooks) == 0:
1004+
mod = mod.forward
1005+
else:
1006+
mod = mod.__call__
10041007
if hasattr(mod, "__self__"):
10051008
# pyrefly: ignore [missing-attribute]
10061009
return mod.__func__, mod.__self__

torch/_dynamo/functional_export.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,8 @@ def backend_dummy(*example_inputs):
526526
in_shuffle_graph = make_fx(
527527
InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True
528528
)(*flat_real_args)
529+
in_shuffle_graph.graph.eliminate_dead_code()
530+
in_shuffle_graph.recompile()
529531

530532
output_node = next(iter(reversed(backend_input.graph_module.graph.nodes)))
531533

@@ -573,6 +575,8 @@ def backend_dummy(*example_inputs):
573575
out_shuffle_graph = make_fx(
574576
out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
575577
)(*flat_out_shuffle_args)
578+
out_shuffle_graph.graph.eliminate_dead_code()
579+
out_shuffle_graph.recompile()
576580

577581
assert out_shuffle.out_spec is not None
578582
return PyTreeifyOutput(

torch/_dynamo/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2707,6 +2707,7 @@ def to_subclass(t: Any, cls: type) -> Any:
27072707
dict_getitem = dict.__getitem__
27082708

27092709

2710+
@torch.fx.wrap
27102711
def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any:
27112712
# Call dict(d) to prevent calling overridden __iter__/keys
27122713
dict_class = dict

0 commit comments

Comments
 (0)