@@ -8647,64 +8647,15 @@ def global_context_capture_fn(frame_summary):
8647
8647
self .assertEqual (seen_frames [1 ].name , "uwu_inline_me" )
8648
8648
self .assertEqual (seen_frames [2 ].line , "r2 = uwu_inline_me_deep(y, z)" )
8649
8649
8650
- def test_recompile_on_disable_1 (self ):
8651
- # fix https://github.com/pytorch/pytorch/issues/157399
8650
+ def test_error_on_recompile (self ):
8652
8651
@torch .compile (backend = "eager" )
8653
- def fn (x ):
8654
- @torch ._dynamo .disable
8655
- def inner (x ):
8656
- return x + 10
8657
-
8658
- return inner (x ) + 1
8659
-
8660
- with unittest .mock .patch ("torch._dynamo.config.error_on_recompile" , True ):
8661
- try :
8662
- for i in range (5 ):
8663
- fn (torch .rand (2 , 3 ))
8664
- except torch ._dynamo .exc .RecompileError as e :
8665
- self .fail ("RecompileError raised unexpectedly: " + str (e ))
8666
-
8667
- def test_recompile_on_disable_2 (self ):
8668
- def outer (x , cond ):
8669
- @torch ._dynamo .disable ()
8670
- def fn0 (y ):
8671
- return y + 1
8672
-
8673
- @torch ._dynamo .disable ()
8674
- def fn1 (y ):
8675
- return y + 2
8676
-
8677
- if cond :
8678
- f = fn0
8679
- else :
8680
- f = fn1
8681
-
8682
- torch ._dynamo .graph_break ()
8683
- # there will be a resume function here
8684
- return f (x )
8685
-
8686
- with unittest .mock .patch ("torch._dynamo.config.error_on_recompile" , True ):
8687
- with self .assertRaises (torch ._dynamo .exc .RecompileError ):
8688
- x = torch .rand (2 , 3 )
8689
- self .assertEqual (outer (x , True ), torch .compile (outer )(x , True ))
8690
- self .assertEqual (outer (x , False ), torch .compile (outer )(x , False ))
8691
-
8692
- def test_create_nested_fn_cache_clear (self ):
8693
- def outer (x ):
8694
- @torch ._dynamo .disable ()
8695
- def f (y ):
8696
- return y + 2
8697
-
8698
- return f (x ) + 1
8652
+ def fn (a , b ):
8653
+ return a + b
8699
8654
8700
- outer = torch .compile (outer )
8701
8655
with unittest .mock .patch ("torch._dynamo.config.error_on_recompile" , True ):
8702
8656
with self .assertRaises (torch ._dynamo .exc .RecompileError ):
8703
- outer (torch .randn (3 , 3 ))
8704
- from torch ._dynamo .utils import create_nested_fn_cache
8705
-
8706
- create_nested_fn_cache .clear ()
8707
- outer (torch .randn (3 , 3 ))
8657
+ fn (torch .rand (2 , 3 ), torch .rand (2 , 3 ))
8658
+ fn (torch .rand (2 , 3 ), (1 , 2 , 3 ))
8708
8659
8709
8660
def test_guards_strip_function_call (self ):
8710
8661
from torch ._dynamo .guards import strip_function_call
0 commit comments