@@ -65,7 +65,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
6565 return self .relu (a )
6666
6767
68- example_args = (torch .randn (1 , 3 , 256 , 256 ),)
68+ example_args : tuple [ torch . Tensor ] = (torch .randn (1 , 3 , 256 , 256 ),)
6969aten_dialect : ExportedProgram = export (SimpleConv (), example_args , strict = True )
7070print (aten_dialect )
7171
@@ -100,8 +100,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
100100 return x + y
101101
102102
103- example_args = (torch .randn (3 , 3 ), torch .randn (3 , 3 ))
104- aten_dialect : ExportedProgram = export (Basic (), example_args , strict = True )
103+ example_args_2 : tuple [torch .Tensor , torch .Tensor ] = (
104+ torch .randn (3 , 3 ),
105+ torch .randn (3 , 3 ),
106+ )
107+ aten_dialect = export (Basic (), example_args_2 , strict = True )
105108
106109# Works correctly
107110print (aten_dialect .module ()(torch .ones (3 , 3 ), torch .ones (3 , 3 )))
@@ -118,20 +121,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
118121
119122from torch .export import Dim
120123
121-
122- class Basic (torch .nn .Module ):
123- def __init__ (self ):
124- super ().__init__ ()
125-
126- def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
127- return x + y
128-
129-
130- example_args = (torch .randn (3 , 3 ), torch .randn (3 , 3 ))
124+ example_args_2 = (torch .randn (3 , 3 ), torch .randn (3 , 3 ))
131125dim1_x = Dim ("dim1_x" , min = 1 , max = 10 )
132126dynamic_shapes = {"x" : {1 : dim1_x }, "y" : {1 : dim1_x }}
133- aten_dialect : ExportedProgram = export (
134- Basic (), example_args , dynamic_shapes = dynamic_shapes , strict = True
127+ aten_dialect = export (
128+ Basic (), example_args_2 , dynamic_shapes = dynamic_shapes , strict = True
135129)
136130print (aten_dialect )
137131
@@ -207,13 +201,13 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
207201)
208202
209203quantizer = XNNPACKQuantizer ().set_global (get_symmetric_quantization_config ())
210- prepared_graph = prepare_pt2e (pre_autograd_aten_dialect , quantizer )
204+ prepared_graph = prepare_pt2e (pre_autograd_aten_dialect , quantizer ) # type: ignore[arg-type]
211205# calibrate with a sample dataset
212206converted_graph = convert_pt2e (prepared_graph )
213207print ("Quantized Graph" )
214208print (converted_graph )
215209
216- aten_dialect : ExportedProgram = export (converted_graph , example_args , strict = True )
210+ aten_dialect = export (converted_graph , example_args , strict = True )
217211print ("ATen Dialect Graph" )
218212print (aten_dialect )
219213
@@ -243,7 +237,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
243237from executorch .exir import EdgeProgramManager , to_edge
244238
245239example_args = (torch .randn (1 , 3 , 256 , 256 ),)
246- aten_dialect : ExportedProgram = export (SimpleConv (), example_args , strict = True )
240+ aten_dialect = export (SimpleConv (), example_args , strict = True )
247241
248242edge_program : EdgeProgramManager = to_edge (aten_dialect )
249243print ("Edge Dialect Graph" )
@@ -272,9 +266,7 @@ def forward(self, x):
272266decode_args = (torch .randn (1 , 5 ),)
273267aten_decode : ExportedProgram = export (Decode (), decode_args , strict = True )
274268
275- edge_program : EdgeProgramManager = to_edge (
276- {"encode" : aten_encode , "decode" : aten_decode }
277- )
269+ edge_program = to_edge ({"encode" : aten_encode , "decode" : aten_decode })
278270for method in edge_program .methods :
279271 print (f"Edge Dialect graph of { method } " )
280272 print (edge_program .exported_program (method ))
@@ -291,8 +283,8 @@ def forward(self, x):
291283# rather than the ``torch.ops.aten`` namespace.
292284
293285example_args = (torch .randn (1 , 3 , 256 , 256 ),)
294- aten_dialect : ExportedProgram = export (SimpleConv (), example_args , strict = True )
295- edge_program : EdgeProgramManager = to_edge (aten_dialect )
286+ aten_dialect = export (SimpleConv (), example_args , strict = True )
287+ edge_program = to_edge (aten_dialect )
296288print ("Edge Dialect Graph" )
297289print (edge_program .exported_program ())
298290
@@ -357,8 +349,8 @@ def forward(self, x):
357349
358350# Export and lower the module to Edge Dialect
359351example_args = (torch .ones (1 ),)
360- aten_dialect : ExportedProgram = export (LowerableModule (), example_args , strict = True )
361- edge_program : EdgeProgramManager = to_edge (aten_dialect )
352+ aten_dialect = export (LowerableModule (), example_args , strict = True )
353+ edge_program = to_edge (aten_dialect )
362354to_be_lowered_module = edge_program .exported_program ()
363355
364356from executorch .exir .backend .backend_api import LoweredBackendModule , to_backend
@@ -369,7 +361,7 @@ def forward(self, x):
369361)
370362
371363# Lower the module
372- lowered_module : LoweredBackendModule = to_backend (
364+ lowered_module : LoweredBackendModule = to_backend ( # type: ignore[call-arg]
373365 "BackendWithCompilerDemo" , to_be_lowered_module , []
374366)
375367print (lowered_module )
@@ -423,8 +415,8 @@ def forward(self, x):
423415
424416
425417example_args = (torch .ones (1 ),)
426- aten_dialect : ExportedProgram = export (ComposedModule (), example_args , strict = True )
427- edge_program : EdgeProgramManager = to_edge (aten_dialect )
418+ aten_dialect = export (ComposedModule (), example_args , strict = True )
419+ edge_program = to_edge (aten_dialect )
428420exported_program = edge_program .exported_program ()
429421print ("Edge Dialect graph" )
430422print (exported_program )
@@ -460,16 +452,16 @@ def forward(self, a, x, b):
460452 return z
461453
462454
463- example_args = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
464- aten_dialect : ExportedProgram = export (Foo (), example_args , strict = True )
465- edge_program : EdgeProgramManager = to_edge (aten_dialect )
455+ example_args_3 = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
456+ aten_dialect = export (Foo (), example_args_3 , strict = True )
457+ edge_program = to_edge (aten_dialect )
466458exported_program = edge_program .exported_program ()
467459print ("Edge Dialect graph" )
468460print (exported_program )
469461
470462from executorch .exir .backend .test .op_partitioner_demo import AddMulPartitionerDemo
471463
472- delegated_program = to_backend (exported_program , AddMulPartitionerDemo ())
464+ delegated_program = to_backend (exported_program , AddMulPartitionerDemo ()) # type: ignore[call-arg]
473465print ("Delegated program" )
474466print (delegated_program )
475467print (delegated_program .graph_module .lowered_module_0 .original_module )
@@ -484,19 +476,9 @@ def forward(self, a, x, b):
484476# call ``to_backend`` on it:
485477
486478
487- class Foo (torch .nn .Module ):
488- def forward (self , a , x , b ):
489- y = torch .mm (a , x )
490- z = y + b
491- a = z - a
492- y = torch .mm (a , x )
493- z = y + b
494- return z
495-
496-
497- example_args = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
498- aten_dialect : ExportedProgram = export (Foo (), example_args , strict = True )
499- edge_program : EdgeProgramManager = to_edge (aten_dialect )
479+ example_args_3 = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
480+ aten_dialect = export (Foo (), example_args_3 , strict = True )
481+ edge_program = to_edge (aten_dialect )
500482exported_program = edge_program .exported_program ()
501483delegated_program = edge_program .to_backend (AddMulPartitionerDemo ())
502484
@@ -530,7 +512,6 @@ def forward(self, a, x, b):
530512print ("ExecuTorch Dialect" )
531513print (executorch_program .exported_program ())
532514
533- import executorch .exir as exir
534515
535516######################################################################
536517# Notice that in the graph we now see operators like ``torch.ops.aten.sub.out``
@@ -577,13 +558,11 @@ def forward(self, x):
577558pre_autograd_aten_dialect = export_for_training (M (), example_args ).module ()
578559# Optionally do quantization:
579560# pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
580- aten_dialect : ExportedProgram = export (
581- pre_autograd_aten_dialect , example_args , strict = True
582- )
583- edge_program : exir .EdgeProgramManager = exir .to_edge (aten_dialect )
561+ aten_dialect = export (pre_autograd_aten_dialect , example_args , strict = True )
562+ edge_program = to_edge (aten_dialect )
584563# Optionally do delegation:
585564# edge_program = edge_program.to_backend(CustomBackendPartitioner)
586- executorch_program : exir . ExecutorchProgramManager = edge_program .to_executorch (
565+ executorch_program = edge_program .to_executorch (
587566 ExecutorchBackendConfig (
588567 passes = [], # User-defined passes
589568 )
0 commit comments