@@ -699,6 +699,55 @@ def fn2(x):
699699 self .assertEqual (counters ["inductor" ]["fxgraph_cache_miss" ], 2 )
700700 self .assertEqual (counters ["inductor" ]["fxgraph_cache_hit" ], 0 )
701701
702+ @config .patch ({"fx_graph_cache" : True })
703+ @config .patch ({"fx_graph_remote_cache" : False })
704+ @parametrize ("variant" , ("v1" , "v2" ))
705+ def test_auto_functionalized_caching (self , variant ):
706+ if variant == "v1" :
707+ patch = torch ._inductor .config .patch (enable_auto_functionalized_v2 = False )
708+ else :
709+ assert variant == "v2"
710+ patch = torch ._inductor .config .patch (enable_auto_functionalized_v2 = True )
711+
712+ @torch .library .custom_op ("mylib::sin_inplace" , mutates_args = ["x" ])
713+ def sin_inplace (x : torch .Tensor ) -> None :
714+ x .sin_ ()
715+
716+ @torch .library .custom_op ("mylib::cos_inplace" , mutates_args = ["x" ])
717+ def cos_inplace (x : torch .Tensor ) -> None :
718+ x .cos_ ()
719+
720+ @torch .compile (fullgraph = True )
721+ def fn (x , op ):
722+ y = torch .empty_like (x )
723+ op (y )
724+ return y
725+
726+ x = torch .randn (3 )
727+
728+ with patch :
729+ # A first call should miss in the cache.
730+ fn (x , sin_inplace )
731+ self .reset ()
732+ self .assertEqual (counters ["inductor" ]["fxgraph_cache_miss" ], 1 )
733+ self .assertEqual (counters ["inductor" ]["fxgraph_cache_hit" ], 0 )
734+ self .assertEqual (counters ["inductor" ]["fxgraph_lookup_write_file" ], 0 )
735+
736+ # A second call should hit. (First reset so in-memory guards
737+ # don't prevent compilation).
738+ self .reset ()
739+ fn (x , sin_inplace )
740+ self .assertEqual (counters ["inductor" ]["fxgraph_cache_miss" ], 1 )
741+ self .assertEqual (counters ["inductor" ]["fxgraph_cache_hit" ], 1 )
742+ self .assertEqual (counters ["inductor" ]["fxgraph_lookup_write_file" ], 1 )
743+
744+ # A third call with different operator should have a cache miss
745+ self .reset ()
746+ fn (x , cos_inplace )
747+ self .assertEqual (counters ["inductor" ]["fxgraph_cache_miss" ], 2 )
748+ self .assertEqual (counters ["inductor" ]["fxgraph_cache_hit" ], 1 )
749+ self .assertEqual (counters ["inductor" ]["fxgraph_lookup_write_file" ], 1 )
750+
702751 @requires_cuda
703752 @config .patch ({"fx_graph_cache" : True })
704753 @config .patch ({"fx_graph_remote_cache" : False })
0 commit comments