@@ -44,7 +44,7 @@ def forward(self, x, y):
4444 torch ._check (z < 4 )
4545 return x [z : z + y .shape [0 ]]
4646
47- ep = torch .export .export (M (), (torch .randn (10 ), torch .tensor ([3 ])))
47+ ep = torch .export .export (M (), (torch .randn (10 ), torch .tensor ([3 ])), strict = True )
4848
4949 compile_config_with_disable_ir_validity = EdgeCompileConfig (
5050 _check_ir_validity = False
@@ -82,7 +82,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8282
8383 example_input = (torch .zeros ([2 , 2 ]),)
8484
85- export_model = export (m , example_input )
85+ export_model = export (m , example_input , strict = True )
8686
8787 compile_config_without_edge_op = EdgeCompileConfig (
8888 _use_edge_ops = False , _skip_dim_order = False
@@ -131,7 +131,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
131131 ),
132132 )
133133
134- export_model = export (m , example_input )
134+ export_model = export (m , example_input , strict = True )
135135
136136 compile_config_with_dim_order = EdgeCompileConfig (_skip_dim_order = False )
137137 compile_config_with_stride = EdgeCompileConfig (_skip_dim_order = True )
0 commit comments