@@ -660,83 +660,72 @@ def fn():
 
         self.check_output_and_recompiles(fn)
 
-    def test_torch_compile_api_inductor(self):
-        def fn():
-            torch.manual_seed(123)
-            model = torch.nn.Sequential(
-                torch.nn.Linear(4, 4),
-                torch.nn.Sigmoid(),
-            )
-
+    @parametrize("api", ("compile", "optimize"))
+    @parametrize("backend", ("eager", "aot_eager", "inductor"))
+    def test_compile_api(self, api, backend):
+        def wrap(fn, backend):
+            if api == "compile":
+                return torch.compile(fn, backend=backend)
+            elif api == "optimize":
+                return torch._dynamo.optimize(backend)(fn)
+
+        def fn(model, inputs):
             res = []
-            for _ in range(3):
-                x = torch.randn([1, 4])
-
-                result = model(x).sum()
+            for inp in inputs:
+                result = model(inp).sum()
                 result.backward()
                 res.append(model[0].weight.grad)
                 res.append(model[0].bias.grad)
                 model.zero_grad()
             return res
 
-        expected = fn()
-        with config.patch(compiled_autograd=True):
-            compiled_fn = torch.compile(fn)
-            actual = compiled_fn()
-        self.assertEqual(expected, actual)
-        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
-
-    def test_torch_compile_api_aot_eager(self):
-        def fn():
-            torch.manual_seed(123)
-            model = torch.nn.Sequential(
-                torch.nn.Linear(4, 4),
-                torch.nn.Sigmoid(),
-            )
-
-            res = []
-            for _ in range(3):
-                x = torch.randn([1, 4])
-
-                result = model(x).sum()
-                result.backward()
-                res.append(model[0].weight.grad)
-                res.append(model[0].bias.grad)
-                model.zero_grad()
-            return res
+        torch.manual_seed(123)
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.Sigmoid(),
+        )
+        inputs = [
+            torch.randn([1, 4]),
+            torch.randn([2, 4]),
+            torch.randn([3, 4]),
+        ]
 
-        expected = fn()
+        expected = fn(model, inputs)
         with config.patch(compiled_autograd=True):
-            compiled_fn = torch.compile(fn, backend="aot_eager")
-            actual = compiled_fn()
+            compiled_fn = wrap(fn, backend)
+            actual = compiled_fn(model, inputs)
         self.assertEqual(expected, actual)
-        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+        self.assertEqual(counters["compiled_autograd"]["captures"], 2)
 
-    def test_torch_compile_api_eager(self):
-        def fn():
-            torch.manual_seed(123)
-            model = torch.nn.Sequential(
-                torch.nn.Linear(4, 4),
-                torch.nn.Sigmoid(),
-            )
+    @parametrize("backend", ("eager", "aot_eager", "inductor"))
+    def test_optimize_assert(self, backend):
+        # can be merged into the test above once we support
+        # no graph break on .backward
 
-            res = []
-            for _ in range(3):
-                x = torch.randn([1, 4])
+        def fn(model, inp):
+            # NOTE: not calling .backward in the compiled fn
+            return model(inp).sum()
 
-                result = model(x).sum()
-                result.backward()
-                res.append(model[0].weight.grad)
-                res.append(model[0].bias.grad)
-                model.zero_grad()
-            return res
+        torch.manual_seed(123)
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.Sigmoid(),
+        )
+        inp = torch.randn([1, 4])
 
-        expected = fn()
+        out = fn(model, inp)
+        out.backward()
+        expected = [p.grad for p in model.parameters()]
+        model.zero_grad()
         with config.patch(compiled_autograd=True):
-            compiled_fn = torch.compile(fn, backend="eager")
-            actual = compiled_fn()
+            compiled_fn = torch._dynamo.optimize_assert(backend)(fn)
+
+            # should not error due to undefined `rebuild_ctx`
+            out = compiled_fn(model, inp)
+            out.backward()
+            actual = [p.grad for p in model.parameters()]
         self.assertEqual(expected, actual)
-        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+        self.assertEqual(counters["compiled_autograd"]["captures"], 0)
 
     def test_multiple_torch_compile(self):
         model = torch.nn.Sequential(
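
Side note on the `@parametrize` decorator introduced above: it is presumably the helper from `torch.testing._internal.common_utils`, and the decorated methods only expand into concrete per-backend test cases once the test class is passed through `instantiate_parametrized_tests`. A minimal, self-contained sketch of that mechanism (the class name and test body here are illustrative, not the actual contents of this file):

    from torch.testing._internal.common_utils import (
        TestCase,
        instantiate_parametrized_tests,
        parametrize,
        run_tests,
    )


    class ExampleTests(TestCase):
        @parametrize("backend", ("eager", "aot_eager", "inductor"))
        def test_compile_api(self, backend):
            # each value of `backend` becomes its own generated test method
            self.assertIn(backend, ("eager", "aot_eager", "inductor"))


    # expands the parametrized methods into concrete test_compile_api_* variants
    instantiate_parametrized_tests(ExampleTests)

    if __name__ == "__main__":
        run_tests()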