@@ -66,7 +66,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
6666
6767
6868example_args = (torch .randn (1 , 3 , 256 , 256 ),)
69- aten_dialect : ExportedProgram = export (SimpleConv (), example_args )
69+ aten_dialect : ExportedProgram = export (SimpleConv (), example_args , strict = True )
7070print (aten_dialect )
7171
7272######################################################################
@@ -101,7 +101,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
101101
102102
103103example_args = (torch .randn (3 , 3 ), torch .randn (3 , 3 ))
104- aten_dialect : ExportedProgram = export (Basic (), example_args )
104+ aten_dialect : ExportedProgram = export (Basic (), example_args , strict = True )
105105
106106# Works correctly
107107print (aten_dialect .module ()(torch .ones (3 , 3 ), torch .ones (3 , 3 )))
@@ -131,7 +131,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
131131dim1_x = Dim ("dim1_x" , min = 1 , max = 10 )
132132dynamic_shapes = {"x" : {1 : dim1_x }, "y" : {1 : dim1_x }}
133133aten_dialect : ExportedProgram = export (
134- Basic (), example_args , dynamic_shapes = dynamic_shapes
134+ Basic (), example_args , dynamic_shapes = dynamic_shapes , strict = True
135135)
136136print (aten_dialect )
137137
@@ -213,7 +213,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
213213print ("Quantized Graph" )
214214print (converted_graph )
215215
216- aten_dialect : ExportedProgram = export (converted_graph , example_args )
216+ aten_dialect : ExportedProgram = export (converted_graph , example_args , strict = True )
217217print ("ATen Dialect Graph" )
218218print (aten_dialect )
219219
@@ -243,7 +243,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
243243from executorch .exir import EdgeProgramManager , to_edge
244244
245245example_args = (torch .randn (1 , 3 , 256 , 256 ),)
246- aten_dialect : ExportedProgram = export (SimpleConv (), example_args )
246+ aten_dialect : ExportedProgram = export (SimpleConv (), example_args , strict = True )
247247
248248edge_program : EdgeProgramManager = to_edge (aten_dialect )
249249print ("Edge Dialect Graph" )
@@ -267,10 +267,10 @@ def forward(self, x):
267267
268268
269269encode_args = (torch .randn (1 , 10 ),)
270- aten_encode : ExportedProgram = export (Encode (), encode_args )
270+ aten_encode : ExportedProgram = export (Encode (), encode_args , strict = True )
271271
272272decode_args = (torch .randn (1 , 5 ),)
273- aten_decode : ExportedProgram = export (Decode (), decode_args )
273+ aten_decode : ExportedProgram = export (Decode (), decode_args , strict = True )
274274
275275edge_program : EdgeProgramManager = to_edge (
276276 {"encode" : aten_encode , "decode" : aten_decode }
@@ -291,7 +291,7 @@ def forward(self, x):
291291# rather than the ``torch.ops.aten`` namespace.
292292
293293example_args = (torch .randn (1 , 3 , 256 , 256 ),)
294- aten_dialect : ExportedProgram = export (SimpleConv (), example_args )
294+ aten_dialect : ExportedProgram = export (SimpleConv (), example_args , strict = True )
295295edge_program : EdgeProgramManager = to_edge (aten_dialect )
296296print ("Edge Dialect Graph" )
297297print (edge_program .exported_program ())
@@ -357,7 +357,7 @@ def forward(self, x):
357357
358358# Export and lower the module to Edge Dialect
359359example_args = (torch .ones (1 ),)
360- aten_dialect : ExportedProgram = export (LowerableModule (), example_args )
360+ aten_dialect : ExportedProgram = export (LowerableModule (), example_args , strict = True )
361361edge_program : EdgeProgramManager = to_edge (aten_dialect )
362362to_be_lowered_module = edge_program .exported_program ()
363363
@@ -423,7 +423,7 @@ def forward(self, x):
423423
424424
425425example_args = (torch .ones (1 ),)
426- aten_dialect : ExportedProgram = export (ComposedModule (), example_args )
426+ aten_dialect : ExportedProgram = export (ComposedModule (), example_args , strict = True )
427427edge_program : EdgeProgramManager = to_edge (aten_dialect )
428428exported_program = edge_program .exported_program ()
429429print ("Edge Dialect graph" )
@@ -461,7 +461,7 @@ def forward(self, a, x, b):
461461
462462
463463example_args = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
464- aten_dialect : ExportedProgram = export (Foo (), example_args )
464+ aten_dialect : ExportedProgram = export (Foo (), example_args , strict = True )
465465edge_program : EdgeProgramManager = to_edge (aten_dialect )
466466exported_program = edge_program .exported_program ()
467467print ("Edge Dialect graph" )
@@ -495,7 +495,7 @@ def forward(self, a, x, b):
495495
496496
497497example_args = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
498- aten_dialect : ExportedProgram = export (Foo (), example_args )
498+ aten_dialect : ExportedProgram = export (Foo (), example_args , strict = True )
499499edge_program : EdgeProgramManager = to_edge (aten_dialect )
500500exported_program = edge_program .exported_program ()
501501delegated_program = edge_program .to_backend (AddMulPartitionerDemo ())
@@ -577,7 +577,9 @@ def forward(self, x):
577577pre_autograd_aten_dialect = export_for_training (M (), example_args ).module ()
578578# Optionally do quantization:
579579# pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
580- aten_dialect : ExportedProgram = export (pre_autograd_aten_dialect , example_args )
580+ aten_dialect : ExportedProgram = export (
581+ pre_autograd_aten_dialect , example_args , strict = True
582+ )
581583edge_program : exir .EdgeProgramManager = exir .to_edge (aten_dialect )
582584# Optionally do delegation:
583585# edge_program = edge_program.to_backend(CustomBackendPartitioner)
0 commit comments