@@ -307,76 +307,6 @@ def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
307307
308308 self .assertEqual (prog (* inp ), f (* inp ))
309309
310- def test_aot_buffer_mutation (self ) -> None :
311- class Module (torch .nn .Module ):
312- def __init__ (self ):
313- super ().__init__ ()
314- self .register_buffer (
315- "_bin_num_examples" ,
316- torch .empty ([42 ]).fill_ (
317- 0.0 ,
318- ),
319- )
320-
321- def forward (self , x , y , z ):
322- self ._bin_num_examples .index_copy_ (
323- dim = 0 ,
324- index = y ,
325- source = z ,
326- )
327- self ._bin_num_examples .index_add_ (
328- dim = 0 , index = torch .arange (4 ), source = x
329- )
330- return self ._bin_num_examples - 1 , x * z
331-
332- model = Module ()
333- example_inputs = (
334- torch .randn (4 , requires_grad = True ),
335- torch .tensor (0 ),
336- torch .tensor (3.14 ),
337- )
338-
339- with self .assertRaisesRegex (
340- RuntimeError ,
341- "Found a graph input that requires gradients, and received a mutation." ,
342- ):
343- _ = exir .capture (
344- model ,
345- example_inputs ,
346- exir .CaptureConfig (
347- enable_aot = True ,
348- ),
349- )
350-
351- # Note that model._bin_num_examples is mutated during exir.capture
352- # We need to create a new_model
353- new_model = Module ()
354- example_inputs = (
355- torch .randn (4 ),
356- torch .tensor (0 ),
357- torch .tensor (3.14 ),
358- )
359-
360- ep = exir .capture (
361- new_model ,
362- example_inputs ,
363- exir .CaptureConfig (
364- enable_aot = True ,
365- ),
366- )
367-
368- test_inputs = (
369- torch .randn (4 ),
370- torch .tensor (0 ),
371- torch .tensor (2.1 ),
372- )
373- graph_outputs = ep (* test_inputs )
374- eager_outputs = Module ()(* test_inputs )
375- self .assertEqual (len (graph_outputs ), 2 )
376- self .assertEqual (len (eager_outputs ), 2 )
377- self .assertTrue (torch .allclose (graph_outputs [0 ], eager_outputs [0 ]))
378- self .assertTrue (torch .allclose (graph_outputs [1 ], eager_outputs [1 ]))
379-
380310 def test_assume_constant_by_default_prop (self ) -> None :
381311 def foo (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
382312 if x .shape [0 ] > 3 :
0 commit comments