@@ -620,12 +620,28 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s
620620 save_bundled_program (exec_prog , method_test_suites , output_name )
621621
622622
623+ def quantize_model (
624+ exported_program , args , model : torch .nn .Module , example_inputs , compile_spec
625+ ):
626+ model_int8 = quantize (
627+ model ,
628+ args .model_name ,
629+ compile_spec ,
630+ example_inputs ,
631+ args .evaluate ,
632+ args .evaluate_config ,
633+ )
634+ # Wrap quantized model back into an exported_program
635+ exported_program = torch .export .export_for_training (
636+ model_int8 , example_inputs , strict = True
637+ )
638+
639+ return model_int8 , exported_program
640+
641+
623642def to_edge_TOSA_delegate (
624- exported_program ,
625- args ,
626- model : torch .nn .Module ,
643+ exported_program , args , model : torch .nn .Module , example_inputs
627644):
628- model_int8 = None
629645 # As we can target multiple output encodings, one must
630646 # be specified.
631647 compile_spec = get_compile_spec (
@@ -634,23 +650,13 @@ def to_edge_TOSA_delegate(
634650 args .system_config ,
635651 args .memory_mode ,
636652 )
653+
654+ model_int8 = None
637655 if args .quantize :
638- model = quantize (
639- model ,
640- args .model_name ,
641- compile_spec ,
642- example_inputs ,
643- args .evaluate ,
644- args .evaluate_config ,
656+ model_int8 , exported_program = quantize_model (
657+ exported_program , args , model , example_inputs , compile_spec
645658 )
646- model_int8 = model
647- # Wrap quantized model back into an exported_program
648- exported_program = torch .export .export_for_training (
649- model , example_inputs , strict = True
650- )
651-
652- if args .intermediates :
653- os .makedirs (args .intermediates , exist_ok = True )
659+ model = model_int8
654660
655661 if is_ethosu (compile_spec ):
656662 partitioner = EthosUPartitioner (compile_spec )
@@ -669,6 +675,31 @@ def to_edge_TOSA_delegate(
669675 return model_int8 , edge
670676
671677
678+ def to_edge_no_delegate (exported_program , args , model : torch .nn .Module , example_inputs ):
679+ model_int8 = None
680+ if args .quantize :
681+ # As we can target multiple output encodings, one must
682+ # be specified.
683+ compile_spec = get_compile_spec (
684+ args .target ,
685+ args .intermediates ,
686+ args .system_config ,
687+ args .memory_mode ,
688+ )
689+ model , exported_program = quantize_model (
690+ exported_program , args , model , example_inputs , compile_spec
691+ )
692+ model_int8 = model
693+
694+ edge = to_edge_transform_and_lower (
695+ exported_program ,
696+ compile_config = EdgeCompileConfig (
697+ _check_ir_validity = False ,
698+ ),
699+ )
700+ return model_int8 , edge
701+
702+
672703if __name__ == "__main__" : # noqa: C901
673704 args = get_args ()
674705
@@ -686,16 +717,18 @@ def to_edge_TOSA_delegate(
686717 model = exported_program .module ()
687718 model_fp32 = model
688719
720+ if args .intermediates :
721+ os .makedirs (args .intermediates , exist_ok = True )
722+
689723 # Quantize if required
690724 model_int8 = None
691725 if args .delegate :
692- model_int8 , edge = to_edge_TOSA_delegate (exported_program , args , model )
726+ model_int8 , edge = to_edge_TOSA_delegate (
727+ exported_program , args , model , example_inputs
728+ )
693729 else :
694- edge = to_edge_transform_and_lower (
695- exported_program ,
696- compile_config = EdgeCompileConfig (
697- _check_ir_validity = False ,
698- ),
730+ model_int8 , edge = to_edge_no_delegate (
731+ exported_program , args , model , example_inputs
699732 )
700733
701734 dump_delegation_info (edge , args .intermediates )
0 commit comments