6262
6363from executorch .extension .export_util .utils import save_pte_program
6464from tabulate import tabulate
65+ from torch .export import ExportedProgram
66+ from torch .fx import GraphModule
6567from torch .utils .data import DataLoader
6668
6769# Quantize model if required using the standard export quantization flow.
@@ -146,13 +148,13 @@ def get_model_and_inputs_from_name(
146148
147149
148150def quantize (
149- model : torch . nn . Module ,
151+ model : GraphModule ,
150152 model_name : str ,
151153 compile_specs : EthosUCompileSpec | VgfCompileSpec | TosaCompileSpec ,
152154 example_inputs : Tuple [torch .Tensor ],
153155 evaluator_name : str | None ,
154156 evaluator_config : Dict [str , Any ] | None ,
155- ) -> torch . nn . Module :
157+ ) -> GraphModule :
156158 """This is the official recommended flow for quantization in pytorch 2.0
157159 export"""
158160 logging .info ("Quantizing Model..." )
@@ -741,7 +743,12 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s
741743 save_bundled_program (exec_prog , method_test_suites , output_name )
742744
743745
744- def quantize_model (args , model : torch .nn .Module , example_inputs , compile_spec ):
746+ def quantize_model (
747+ args ,
748+ model : GraphModule ,
749+ example_inputs : Tuple [torch .Tensor ],
750+ compile_spec ,
751+ ) -> Tuple [GraphModule , ExportedProgram ]:
745752 model_int8 = quantize (
746753 model ,
747754 args .model_name ,
@@ -759,7 +766,10 @@ def quantize_model(args, model: torch.nn.Module, example_inputs, compile_spec):
759766
760767
761768def to_edge_TOSA_delegate (
762- exported_program , args , model : torch .nn .Module , example_inputs
769+ exported_program : ExportedProgram ,
770+ args ,
771+ model : GraphModule ,
772+ example_inputs : Tuple [torch .Tensor ],
763773):
764774 # As we can target multiple output encodings, one must
765775 # be specified.
@@ -778,7 +788,6 @@ def to_edge_TOSA_delegate(
778788 model_int8 , exported_program = quantize_model (
779789 args , model , example_inputs , compile_spec
780790 )
781- model = model_int8
782791
783792 if isinstance (compile_spec , EthosUCompileSpec ):
784793 partitioner = EthosUPartitioner (compile_spec )
@@ -800,7 +809,12 @@ def to_edge_TOSA_delegate(
800809 return model_int8 , edge
801810
802811
803- def to_edge_no_delegate (exported_program , args , model : torch .nn .Module , example_inputs ):
812+ def to_edge_no_delegate (
813+ exported_program : ExportedProgram ,
814+ args ,
815+ model : GraphModule ,
816+ example_inputs : Tuple [torch .Tensor ],
817+ ):
804818 model_int8 = None
805819 if args .quantize :
806820 # As we can target multiple output encodings, one must
0 commit comments