
Commit 0d98f9d

Mark auto_functionalized HOPs as cacheable (pytorch#151194) (pytorch#153304)
Fixes pytorch#151188

Test Plan:
- new tests
1 parent b8d9208 commit 0d98f9d
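
For context, a minimal sketch (not part of the commit; the op name mylib::relu_inplace, the function names, and the tensor shape are illustrative) of the kind of program this change affects: a custom op that mutates its argument is wrapped by functionalization in the auto_functionalized (or auto_functionalized_v2) higher-order op, so whether that HOP is cacheable determines whether the compiled graph is eligible for the FX graph cache.

import torch
import torch._inductor.config as inductor_config


# Illustrative custom op that mutates its input in place; under torch.compile,
# functionalization rewrites the call into the auto_functionalized HOP.
@torch.library.custom_op("mylib::relu_inplace", mutates_args=["x"])
def relu_inplace(x: torch.Tensor) -> None:
    x.relu_()


@torch.compile(fullgraph=True)
def fn(x):
    y = torch.empty_like(x)
    relu_inplace(y)
    return y


# With the local FX graph cache enabled, graphs containing auto_functionalized
# can now be served from the cache instead of being recompiled.
with inductor_config.patch(fx_graph_cache=True, fx_graph_remote_cache=False):
    fn(torch.randn(3))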

File tree

2 files changed: +51 −2 lines

- test/inductor/test_codecache.py (+49)
- torch/_higher_order_ops/auto_functionalize.py (+2 −2)

test/inductor/test_codecache.py

Lines changed: 49 additions & 0 deletions
@@ -699,6 +699,55 @@ def fn2(x):
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @parametrize("variant", ("v1", "v2"))
+    def test_auto_functionalized_caching(self, variant):
+        if variant == "v1":
+            patch = torch._inductor.config.patch(enable_auto_functionalized_v2=False)
+        else:
+            assert variant == "v2"
+            patch = torch._inductor.config.patch(enable_auto_functionalized_v2=True)
+
+        @torch.library.custom_op("mylib::sin_inplace", mutates_args=["x"])
+        def sin_inplace(x: torch.Tensor) -> None:
+            x.sin_()
+
+        @torch.library.custom_op("mylib::cos_inplace", mutates_args=["x"])
+        def cos_inplace(x: torch.Tensor) -> None:
+            x.cos_()
+
+        @torch.compile(fullgraph=True)
+        def fn(x, op):
+            y = torch.empty_like(x)
+            op(y)
+            return y
+
+        x = torch.randn(3)
+
+        with patch:
+            # A first call should miss in the cache.
+            fn(x, sin_inplace)
+            self.reset()
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
+
+            # A second call should hit. (First reset so in-memory guards
+            # don't prevent compilation).
+            self.reset()
+            fn(x, sin_inplace)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
+
+            # A third call with different operator should have a cache miss
+            self.reset()
+            fn(x, cos_inplace)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
+
     @requires_cuda
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
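
The counters consulted in the test come from Dynamo's bookkeeping. Outside the test harness a similar warm-versus-cold check can be sketched as below (a rough sketch only: square is an illustrative function, the on-disk cache is assumed to start cold, and self.reset() from the test is approximated by torch._dynamo.reset()).

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters


@torch.compile(fullgraph=True)
def square(x):
    return x * x


with inductor_config.patch(fx_graph_cache=True, fx_graph_remote_cache=False):
    # Cold cache: the first compilation should record a cache miss.
    square(torch.randn(3))
    print(counters["inductor"]["fxgraph_cache_miss"])

    # Resetting Dynamo forces a recompile; the FX graph cache should now hit.
    torch._dynamo.reset()
    square(torch.randn(3))
    print(counters["inductor"]["fxgraph_cache_hit"])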

torch/_higher_order_ops/auto_functionalize.py

Lines changed: 2 additions & 2 deletions
@@ -316,7 +316,7 @@ class AutoFunctionalized(HigherOrderOperator):
     """

     def __init__(self) -> None:
-        super().__init__("auto_functionalized")
+        super().__init__("auto_functionalized", cacheable=True)

     def __call__(
         self,
@@ -345,7 +345,7 @@ class AutoFunctionalizedV2(HigherOrderOperator):
     """

     def __init__(self) -> None:
-        super().__init__("auto_functionalized_v2")
+        super().__init__("auto_functionalized_v2", cacheable=True)

     def __call__(
         self,
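
The functional change is just the cacheable=True flag passed to the HigherOrderOperator base constructor. As a generic sketch of that pattern (MyMutatingHOP and my_mutating_hop are hypothetical names, not part of PyTorch, and the comment reflects the behavior this commit presumably relies on), a HOP opts into FX graph caching like this:

from torch._ops import HigherOrderOperator


class MyMutatingHOP(HigherOrderOperator):
    def __init__(self) -> None:
        # cacheable=True marks graphs containing this HOP as safe for the
        # FX graph cache; HOPs not marked cacheable cause the cache to be
        # bypassed, which is what pytorch#151188 reports for auto_functionalized.
        super().__init__("my_mutating_hop", cacheable=True)

    def __call__(self, fn, *args, **kwargs):
        return super().__call__(fn, *args, **kwargs)


# HOPs are module-level singletons, mirroring auto_functionalized above.
my_mutating_hop = MyMutatingHOP()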
