
Commit e5e94ec

jamesjwu authored and pytorchmergebot committed
Introduce HOP for inductor compiled regions to allow torch dispatch (pytorch#167844)
This is a cleaned-up version of the POC at https://github.com/pytorch/pytorch/pull/167752/files.

This PR adds an inductor option, passed through torch.compile, that wraps all inductor-generated code in a HOP so it can be seen by torch dispatches. The HOP is created in output_code.post_compile, so it is cache safe. The option is part of `inductor_config` and therefore already part of the cache key; a test is included showing the HOP is cache safe. Because the wrapping happens at compile time, it should add little to no CPU overhead beyond the cost of processing the torch dispatches themselves.

The context here is that we want to support compiled regions such as flex attention in eager mode while interoperating with other torch dispatch tracers like SAC. More tests for SAC- and flex-attention-specific behavior will follow.

Pull Request resolved: pytorch#167844
Approved by: https://github.com/ezyang
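A minimal usage sketch (Python), mirroring the tests added in this PR and assuming a CUDA build with Triton: the option is passed through torch.compile's inductor options, and the resulting HOP is observable with DebugMode.

import torch
from torch.utils._debug_mode import DebugMode

# Enable the new inductor option so the compiled region is wrapped in the
# inductor_compiled_code HOP.
@torch.compile(
    backend="inductor",
    options={"wrap_inductor_compiled_regions": True},
    fullgraph=True,
)
def f(x, y):
    return torch.matmul(x, y)

x = torch.randn(4, 4, device="cuda")
y = torch.randn(4, 4, device="cuda")

# DebugMode is a torch dispatch tracer; with wrapping enabled, the compiled
# region shows up as an inductor_compiled_code call in its debug string.
with DebugMode() as debug_mode:
    out = f(x, y)

assert "inductor_compiled_code" in debug_mode.debug_string()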
1 parent ef7fa96 commit e5e94ec

File tree

8 files changed: +1228 -1 lines changed

test/dynamo/test_higher_order_ops.py

Lines changed: 85 additions & 0 deletions
@@ -3395,6 +3395,91 @@ def outer_body_fn(x):
         with self.assertRaisesRegex(RuntimeError, msg):
             fn_with_hints(x, y)
 
+    @requires_cuda_and_triton
+    def test_wrap_inductor_compiled_regions_option(self):
+        """
+        Test that wrap_inductor_compiled_regions option wraps compiled regions
+        in inductor_compiled_code HOP, making them visible to DebugMode.
+        """
+        from torch.utils._debug_mode import DebugMode
+
+        # Test with wrapping enabled
+        @torch.compile(
+            backend="inductor",
+            options={"wrap_inductor_compiled_regions": True},
+            fullgraph=True,
+        )
+        def fn_wrapped(x, y):
+            return torch.matmul(x, y)
+
+        # Test with wrapping disabled (default)
+        @torch.compile(backend="inductor", fullgraph=True)
+        def fn_not_wrapped(x, y):
+            return torch.matmul(x, y)
+
+        x = torch.randn(4, 4, device="cuda")
+        y = torch.randn(4, 4, device="cuda")
+
+        # Test wrapped version - HOP should be visible in DebugMode
+        with DebugMode() as debug_mode_wrapped:
+            result_wrapped = fn_wrapped(x, y)
+
+        debug_string_wrapped = debug_mode_wrapped.debug_string()
+        self.assertIn("inductor_compiled_code", debug_string_wrapped)
+
+        # Test non-wrapped version - HOP should NOT be visible
+        with DebugMode() as debug_mode_not_wrapped:
+            result_not_wrapped = fn_not_wrapped(x, y)
+
+        debug_string_not_wrapped = debug_mode_not_wrapped.debug_string()
+        self.assertNotIn("inductor_compiled_code", debug_string_not_wrapped)
+
+        # Both should produce correct results
+        expected = torch.matmul(x, y)
+        self.assertEqual(result_wrapped, expected)
+        self.assertEqual(result_not_wrapped, expected)
+
+    @requires_cuda_and_triton
+    def test_wrap_inductor_compiled_regions_with_backward(self):
+        """
+        Test that wrap_inductor_compiled_regions works correctly with autograd.
+        """
+        from torch.utils._debug_mode import DebugMode
+
+        @torch.compile(
+            backend="inductor",
+            options={"wrap_inductor_compiled_regions": True},
+            fullgraph=True,
+        )
+        def fn(x, y):
+            return torch.matmul(x, y)
+
+        x = torch.randn(4, 4, device="cuda", requires_grad=True)
+        y = torch.randn(4, 4, device="cuda", requires_grad=True)
+
+        # Clone for eager comparison
+        x_eager = x.detach().clone().requires_grad_(True)
+        y_eager = y.detach().clone().requires_grad_(True)
+
+        # Compiled forward and backward
+        with DebugMode() as debug_mode:
+            result = fn(x, y)
+            loss = result.sum()
+            loss.backward()
+
+        # HOP should be visible in forward pass
+        self.assertIn("inductor_compiled_code", debug_mode.debug_string())
+
+        # Eager forward and backward for comparison
+        expected = torch.matmul(x_eager, y_eager)
+        expected_loss = expected.sum()
+        expected_loss.backward()
+
+        # Check correctness
+        self.assertEqual(result, expected)
+        self.assertEqual(x.grad, x_eager.grad)
+        self.assertEqual(y.grad, y_eager.grad)
+
 
 class HigherOrderOpVmapGuardTests(
     torch._dynamo.test_case.TestCaseWithNestedGraphBreaks, LoggingTestCase
