@@ -239,12 +239,7 @@ def wrapper(self: "TestMemoryPlanning") -> None:
239239 # torch._tensor.Tensor]` is not a function.
240240 inputs = eager_module .get_random_inputs ()
241241 graph_module = (
242- to_edge (
243- export (
244- eager_module ,
245- inputs ,
246- )
247- )
242+ to_edge (export (eager_module , inputs , strict = True ))
248243 .exported_program ()
249244 .graph_module
250245 )
@@ -491,10 +486,7 @@ def test_multiple_pools(
491486 expected_bufsizes : List [int ],
492487 ) -> None :
493488 edge_program = to_edge (
494- export (
495- MultiplePoolsToyModel (),
496- (torch .ones (1 ),),
497- )
489+ export (MultiplePoolsToyModel (), (torch .ones (1 ),), strict = True )
498490 )
499491
500492 edge_program .to_executorch (
@@ -538,7 +530,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
538530 return torch .nn .functional .sigmoid (self .linear (x ) + self .constant + 1 )
539531
540532 def count_planned_inputs (
541- nodes : List [Node ], graph_signature : Any # pyre-ignore
533+ nodes : List [Node ],
534+ graph_signature : Any , # pyre-ignore
542535 ) -> Tuple [int , int ]:
543536 num_mem_planned_placeholders = 0
544537 num_placeholders = 0
@@ -555,7 +548,9 @@ def count_planned_inputs(
555548 model = Simple ()
556549 inputs = (torch .randn (5 , 5 ),)
557550
558- ep_no_input_planning = to_edge (export (model , inputs )).to_executorch (
551+ ep_no_input_planning = to_edge (
552+ export (model , inputs , strict = True )
553+ ).to_executorch (
559554 config = ExecutorchBackendConfig (
560555 memory_planning_pass = MemoryPlanningPass (alloc_graph_input = False ),
561556 sym_shape_eval_pass = ConstraintBasedSymShapeEvalPass (),
@@ -575,7 +570,7 @@ def count_planned_inputs(
575570 5 , # x, self.constant, linear weight, linear bias, '1' scalar promoted to tensor
576571 )
577572
578- ep_input_planning = to_edge (export (model , inputs )).to_executorch (
573+ ep_input_planning = to_edge (export (model , inputs , strict = True )).to_executorch (
579574 config = ExecutorchBackendConfig (
580575 memory_planning_pass = MemoryPlanningPass (alloc_graph_input = True ),
581576 sym_shape_eval_pass = ConstraintBasedSymShapeEvalPass (),
@@ -609,7 +604,7 @@ def forward(self, a, b, x):
609604
610605 model = TestModel ()
611606 example_inputs = (torch .rand (1 , 6 , 2 ), torch .rand (1 , 6 , 2 ), torch .randn (5 , 5 ))
612- exported_model = torch .export .export (model , example_inputs )
607+ exported_model = torch .export .export (model , example_inputs , strict = True )
613608 edge = to_edge (exported_model )
614609
615610 class TestPass (ExportPass ):
0 commit comments