Skip to content

Commit 398fdd3

Browse files
chenmillie authored and pytorchmergebot committed
[Inductor] Lower fallback nodes annotated with "should_fallback" (pytorch#166339)
Summary: This PR introduces an inductor-level fallback mechanism that gives users control over which operations or subgraphs Inductor should lower and which should fall back to preexisting kernels. This has similar motivation as pytorch#164776 in providing flexibility to selectively disable Inductor lowering for specific nodes. The implementation simply adds a check for the `"should_fallback"` metadata annotation on FX graph nodes. If this is set to `True`, the lowerer falls back before attempting the normal lowering path. Note that since these are user-directed fallbacks dependent upon specific, customized conditions, use `add_to_fallback_set=False` to avoid permanent overwrites of inductor's lowering/fallback rules. Simple example marking nodes for fallback based on custom predicates: ``` def should_fallback_predicate(node: torch.fx.Node, pred: Callable[torch.fx.Node, bool]): # Apply predicate and mark for fallback if needed if self.predicate(node): node.meta["should_fallback"] = True ``` Test Plan: added a CI test Differential Revision: D85347587 Pull Request resolved: pytorch#166339 Approved by: https://github.com/blaine-rister, https://github.com/eellison
1 parent 5fd1d41 commit 398fdd3

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Owner(s): ["module: inductor"]
2+
"""
3+
Test selective lowering control via node metadata annotations.
4+
"""
5+
6+
from collections.abc import Callable
7+
8+
import torch
9+
from torch._inductor.test_case import TestCase as InductorTestCase
10+
from torch.testing._internal.common_utils import instantiate_parametrized_tests
11+
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
12+
13+
14+
@instantiate_parametrized_tests
15+
class SelectiveLoweringTest(InductorTestCase):
16+
"""
17+
Tests for user-controllable selective lowering using node.meta annotations.
18+
"""
19+
20+
device = GPU_TYPE
21+
22+
def _mark_nodes_for_fallback(
23+
self, gm: torch.fx.GraphModule, predicate: Callable[[torch.fx.Node], bool]
24+
) -> torch.fx.GraphModule:
25+
"""
26+
Helper method to mark nodes with should_fallback metadata based on a predicate.
27+
"""
28+
for node in gm.graph.nodes:
29+
if node.op == "call_function" and predicate(node):
30+
node.meta["should_fallback"] = True
31+
return gm
32+
33+
def test_basic_selective_lowering(self):
34+
"""
35+
Test that nodes marked for fallback use fallback handlers instead of lowerings.
36+
"""
37+
38+
def foo(x, y):
39+
a = x + y # This will be marked for fallback
40+
b = a * 2 # This will use normal lowering
41+
return b
42+
43+
x = torch.randn(10, device=self.device)
44+
y = torch.randn(10, device=self.device)
45+
46+
def custom_backend(gm: torch.fx.GraphModule, example_inputs):
47+
# Mark all add operations for fallback
48+
def should_fallback_add(node: torch.fx.Node) -> bool:
49+
return node.target == torch.ops.aten.add.Tensor
50+
51+
self._mark_nodes_for_fallback(gm, should_fallback_add)
52+
53+
from torch._inductor.compile_fx import compile_fx
54+
55+
return compile_fx(gm, example_inputs)
56+
57+
compiled_fn = torch.compile(foo, backend=custom_backend)
58+
result = compiled_fn(x, y)
59+
expected = foo(x, y)
60+
61+
self.assertTrue(torch.allclose(result, expected))
62+
63+
def test_no_fallback_when_unmarked(self):
64+
"""
65+
Test that operations without fallback annotation use normal lowering.
66+
"""
67+
68+
def foo(x, y):
69+
return x + y
70+
71+
x = torch.randn(10, device=self.device)
72+
y = torch.randn(10, device=self.device)
73+
74+
def custom_backend(gm: torch.fx.GraphModule, example_inputs):
75+
# Don't mark anything - all operations should use normal lowering
76+
from torch._inductor.compile_fx import compile_fx
77+
78+
return compile_fx(gm, example_inputs)
79+
80+
compiled_fn = torch.compile(foo, backend=custom_backend)
81+
result = compiled_fn(x, y)
82+
expected = foo(x, y)
83+
84+
self.assertTrue(torch.allclose(result, expected))
85+
86+
87+
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # These tests compile through the GPU inductor backend, so skip the
    # whole run on CPU-only hosts.
    if HAS_GPU:
        run_tests(needs="filelock")

torch/_inductor/graph.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,12 @@ def normalize(args: Any, kwargs: Any) -> tuple[Any, Any]:
13221322
else:
13231323
args, kwargs = layout_constraints(n, *args, **kwargs)
13241324

1325-
out = lowerings[target](*args, **kwargs) # type: ignore[index]
1325+
if "should_fallback" in n.meta:
1326+
out = fallback_handler(target, add_to_fallback_set=False)(
1327+
*args, **kwargs
1328+
)
1329+
else:
1330+
out = lowerings[target](*args, **kwargs) # type: ignore[index]
13261331

13271332
if layout_constraints:
13281333
# layout_constraints are allowed to make new copies of the inputs.

0 commit comments

Comments
 (0)