6161
6262from executorch .extension .export_util .utils import save_pte_program
6363from tabulate import tabulate
64+ from torch .export import ExportedProgram
65+ from torch .fx import GraphModule
6466from torch .utils .data import DataLoader
6567
6668# Quantize model if required using the standard export quantizaion flow.
@@ -145,13 +147,13 @@ def get_model_and_inputs_from_name(
145147
146148
147149def quantize (
148- model : torch . nn . Module ,
150+ model : GraphModule ,
149151 model_name : str ,
150152 compile_specs : EthosUCompileSpec | VgfCompileSpec | TosaCompileSpec ,
151153 example_inputs : Tuple [torch .Tensor ],
152154 evaluator_name : str | None ,
153155 evaluator_config : Dict [str , Any ] | None ,
154- ) -> torch . nn . Module :
156+ ) -> GraphModule :
155157 """This is the official recommended flow for quantization in pytorch 2.0
156158 export"""
157159 logging .info ("Quantizing Model..." )
@@ -601,7 +603,12 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s
601603 save_bundled_program (exec_prog , method_test_suites , output_name )
602604
603605
604- def quantize_model (args , model : torch .nn .Module , example_inputs , compile_spec ):
606+ def quantize_model (
607+ args ,
608+ model : GraphModule ,
609+ example_inputs : Tuple [torch .Tensor ],
610+ compile_spec ,
611+ ) -> Tuple [GraphModule , ExportedProgram ]:
605612 model_int8 = quantize (
606613 model ,
607614 args .model_name ,
@@ -619,7 +626,10 @@ def quantize_model(args, model: torch.nn.Module, example_inputs, compile_spec):
619626
620627
621628def to_edge_TOSA_delegate (
622- exported_program , args , model : torch .nn .Module , example_inputs
629+ exported_program : ExportedProgram ,
630+ args ,
631+ model : GraphModule ,
632+ example_inputs : Tuple [torch .Tensor ],
623633):
624634 # As we can target multiple output encodings, one must
625635 # be specified.
@@ -638,7 +648,6 @@ def to_edge_TOSA_delegate(
638648 model_int8 , exported_program = quantize_model (
639649 args , model , example_inputs , compile_spec
640650 )
641- model = model_int8
642651
643652 if isinstance (compile_spec , EthosUCompileSpec ):
644653 partitioner = EthosUPartitioner (compile_spec )
@@ -660,7 +669,12 @@ def to_edge_TOSA_delegate(
660669 return model_int8 , edge
661670
662671
663- def to_edge_no_delegate (exported_program , args , model : torch .nn .Module , example_inputs ):
672+ def to_edge_no_delegate (
673+ exported_program : ExportedProgram ,
674+ args ,
675+ model : GraphModule ,
676+ example_inputs : Tuple [torch .Tensor ],
677+ ):
664678 model_int8 = None
665679 if args .quantize :
666680 # As we can target multiple output encodings, one must
0 commit comments