@@ -674,47 +674,62 @@ def _validate_args(args):
674674 )
675675
676676
677- def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
678-     _validate_args(args)
679-
680-     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
681-
682-     # export_to_edge
683-     builder_exported = _prepare_for_llama_export(args).export()
684-
685-     builder_exported.run_canonical_optimizations()
686-
687-     if args.export_only:
688-         exit()
689-
690-     builder_exported_to_edge = builder_exported.pt2e_quantize(
691-         quantizers
692-     ).export_to_edge()
693-
694-     modelname = builder_exported_to_edge.modelname
695-
696-     # to_backend
677+ def _to_edge_and_lower_llama_xnnpack(
678+     builder_exported,
679+     modelname,
680+     additional_passes,
681+     pt2e_quant_params,
682+     quantizers,
683+     quant_dtype,
684+     args,
685+ ) -> LLMEdgeManager:  # noqa: C901
697686     partitioners = []
698687
699688     # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
700-     if (
701-         pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
702-     ) or (args.xnnpack):
703-         partitioners.append(
704-             get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
705-         )
689+     partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
706690
707-         # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
708-         args.xnnpack = True
709-         modelname = f"xnnpack_dq_{modelname}"
691+     modelname = f"xnnpack_dq_{modelname}"
710692
711693     if args.xnnpack_extended_ops:
712-         assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
713694         partitioners.append(
714695             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
715696         )
716697         modelname = f"xnnpack_{modelname}"
717698
699+     logging.info("Lowering model using following partitioner(s): ")
700+     for partitioner in partitioners:
701+         logging.info(f"--> {partitioner.__class__.__name__}")
702+
703+     # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
704+     if args.generate_etrecord:
705+         raise NotImplementedError(
706+             "export_llama does not support XNNPack and generating ETRecord at the moment."
707+         )
708+
709+     builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
710+         partitioners
711+     )
712+     if args.verbose:
713+         print_delegation_info(builder.edge_manager.exported_program().graph_module)
714+
715+     return builder.to_executorch(passes=additional_passes)
716+
717+
718+ def _to_edge_and_lower_llama(  # noqa: C901
719+     builder_exported,
720+     modelname,
721+     additional_passes,
722+     pt2e_quant_params,
723+     quantizers,
724+     quant_dtype,
725+     args,
726+ ):
727+     builder_exported_to_edge = builder_exported.pt2e_quantize(
728+         quantizers
729+     ).export_to_edge()
730+
731+     # to_backend
732+     partitioners = []
718733     if args.vulkan:
719734         partitioners.append(
720735             get_vulkan_partitioner(
@@ -729,7 +744,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
729744         modelname = f"vulkan_{modelname}"
730745
731746         # Need to remove asserts from the graph to prevent graph breaks
732-         # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
733747         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
734748
735749     if args.mps:
@@ -758,13 +772,11 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
758772         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
759773         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
760774
761-         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
762775         _transform(builder_exported_to_edge.edge_manager.exported_program())
763776
764777         if args.num_sharding > 0:
765778             model_sharding.split_graph(
766779                 builder_exported_to_edge.edge_manager.exported_program(),
767-                 # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
768780                 builder_exported_to_edge.metadata["get_n_layers"],
769781                 shares=args.num_sharding,
770782             )
@@ -790,19 +802,15 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
790802                     atten.head_dim,
791803                 )
792804             )
793-             # pyre-ignore
794805             tag_quant_io(
795806                 builder_exported_to_edge.edge_manager.exported_program().graph_module,
796-                 partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
807+                 partial(get_custom_quant_ios_dtype, cache_shape),
797808             )
798809
799810     logging.info("Lowering model using following partitioner(s): ")
800811     for partitioner in partitioners:
801812         logging.info(f"--> {partitioner.__class__.__name__}")
802813
803-     additional_passes = []
804-     if args.model in TORCHTUNE_DEFINED_MODELS:
805-         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
806814     if args.generate_etrecord:
807815         if not builder_exported_to_edge.edge_manager:
808816             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -816,7 +824,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
816824         if args.num_sharding > 0 and args.qnn:
817825             from executorch.backends.qualcomm.utils.utils import canonicalize_program
818826
819-             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
820827             canonicalize_program(builder.edge_manager.exported_program())
821828
822829         builder = builder.to_executorch(
@@ -838,11 +845,55 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
838845     if args.num_sharding > 0 and args.qnn:
839846         from executorch.backends.qualcomm.utils.utils import canonicalize_program
840847
841-         # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
842848         canonicalize_program(builder.edge_manager.exported_program())
843849
844850         builder = builder.to_executorch(passes=additional_passes)
845851
852+     return builder
853+
854+
855+ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
856+     _validate_args(args)
857+
858+     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
859+
860+     additional_passes = []
861+     if args.model in TORCHTUNE_DEFINED_MODELS:
862+         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
863+
864+     # export_to_edge
865+     builder_exported = _prepare_for_llama_export(args).export()
866+     builder_exported.run_canonical_optimizations()
867+     modelname = builder_exported.modelname
868+
869+     if args.export_only:
870+         exit()
871+
872+     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
873+         # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
874+         args.xnnpack = True
875+
876+     if args.xnnpack:
877+         builder = _to_edge_and_lower_llama_xnnpack(
878+             builder_exported,
879+             modelname,
880+             additional_passes,
881+             pt2e_quant_params,
882+             quantizers,
883+             quant_dtype,
884+             args,
885+         )
886+     else:
887+         builder = _to_edge_and_lower_llama(
888+             builder_exported,
889+             modelname,
890+             additional_passes,
891+             pt2e_quant_params,
892+             quantizers,
893+             quant_dtype,
894+             args,
895+         )
896+
896+
846897     if args.profile_memory:
847898         generate_memory_trace(builder.export_program, "memory_profile.json")
848899
@@ -864,7 +915,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
864915     output_file = f"{builder.output_dir}/{modelname}.pte"
865916
866917     builder.save_to_pte(output_file)
867-
868918     return builder
869919
870920
0 commit comments