@@ -739,7 +739,9 @@ def _to_edge_and_lower_llama_xnnpack(
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    xnnpack_extended_ops: bool = False,
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []

@@ -748,7 +750,7 @@ def _to_edge_and_lower_llama_xnnpack(

         modelname = f"xnnpack_dq_{modelname}"

-    if args.xnnpack_extended_ops:
+    if xnnpack_extended_ops:
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
@@ -759,15 +761,15 @@ def _to_edge_and_lower_llama_xnnpack(
         logging.info(f"--> {partitioner.__class__.__name__}")

     # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-    if args.generate_etrecord:
+    if generate_etrecord:
         raise NotImplementedError(
             "export_llama does not support XNNPack and generating ETRecord at the moment."
         )

     builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
         partitioners
     )
-    if args.verbose:
+    if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)

     return builder.to_executorch(passes=additional_passes)
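
With `args` gone from the signature, the XNNPACK path can be driven programmatically. Below is a minimal sketch of such a call (not part of the diff), assuming the leading positional parameters keep the names used in the function body (`builder_exported`, `modelname`, `additional_passes`) and that the surrounding export pipeline has already produced them:

```python
# Sketch only: invoking the refactored helper without an argparse.Namespace.
# All positional values are assumed to come from the existing export flow.
builder = _to_edge_and_lower_llama_xnnpack(
    builder_exported,            # LLMEdgeManager produced by the export step
    modelname,
    additional_passes,
    pt2e_quant_params,
    quantizers,
    quant_dtype,
    xnnpack_extended_ops=True,   # also partition non-dynamic-quant ops
    generate_etrecord=False,     # XNNPACK + ETRecord still raises (see above)
    verbose=True,                # print delegation info after lowering
)
```

The boolean defaults mirror typical argparse `store_true` flags, so omitting every keyword should reproduce the previous `args`-driven behavior.
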
@@ -780,42 +782,58 @@ def _to_edge_and_lower_llama( # noqa: C901
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    vulkan: bool = False,
+    mps: bool = False,
+    coreml: bool = False,
+    qnn: bool = False,
+    dtype_override: str = "fp32",
+    enable_dynamic_shape: bool = True,
+    use_kv_cache: bool = False,
+    embedding_quantize: Optional[str] = None,
+    pt2e_quantize: Optional[str] = None,
+    coreml_ios: int = 15,
+    coreml_quantize: Optional[str] = None,
+    coreml_compute_units: str = "cpu_only",
+    use_qnn_sha: bool = False,
+    num_sharding: int = 0,
+    soc_model: str = "SM8650",
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ):
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()

     # to_backend
     partitioners = []
-    if args.vulkan:
+    if vulkan:
         partitioners.append(
             get_vulkan_partitioner(
-                args.dtype_override,
-                args.enable_dynamic_shape,
+                dtype_override,
+                enable_dynamic_shape,
             )
         )
         modelname = f"vulkan_{modelname}"

         # Need to remove asserts from the graph to prevent graph breaks
         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

-    if args.mps:
-        partitioners.append(get_mps_partitioner(args.use_kv_cache))
+    if mps:
+        partitioners.append(get_mps_partitioner(use_kv_cache))
         modelname = f"mps_{modelname}"

-    if args.coreml:
+    if coreml:
         coreml_partitioner = get_coreml_partitioner(
-            args.coreml_ios,
-            args.embedding_quantize,
-            args.pt2e_quantize,
-            args.coreml_quantize,
-            args.coreml_compute_units,
+            coreml_ios,
+            embedding_quantize,
+            pt2e_quantize,
+            coreml_quantize,
+            coreml_compute_units,
         )
         partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"

-    if args.qnn:
+    if qnn:
         logging.warning(
             "The model definition in current repro is not performant, please refer to the instruction"
             " in https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama/README.md for better performance."
@@ -824,7 +842,7 @@ def _to_edge_and_lower_llama( # noqa: C901

         partitioners.append(
             get_qnn_partitioner(
-                args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
+                use_kv_cache, pt2e_quantize, num_sharding, soc_model
             )
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes`
@@ -854,7 +872,7 @@ def _to_edge_and_lower_llama( # noqa: C901
         )

         atten = builder_exported_to_edge.model.layers[0].attention
-        if args.use_qnn_sha:
+        if use_qnn_sha:
             cache_shape = torch.Size(
                 (atten.max_batch_size, atten.max_context_len, atten.head_dim)
             )
@@ -877,10 +895,10 @@ def _to_edge_and_lower_llama( # noqa: C901
         passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
             "get_quant_io_dtype_fn"
         ] = partial(get_custom_quant_ios_dtype, cache_shape)
-        if args.num_sharding > 0:
+        if num_sharding > 0:
             SplitGraph, setting = model_sharding.get_split_graph_pass(
                 builder_exported_to_edge.metadata["get_n_layers"],
-                shares=args.num_sharding,
+                shares=num_sharding,
             )
             passes_job[SplitGraph] = setting
             dep_table[SplitGraph] = [FoldQDQ]
@@ -895,17 +913,17 @@ def _to_edge_and_lower_llama( # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")

-    if args.generate_etrecord:
+    if generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")

         logging.info("Generating etrecord")
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
         edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`.
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

@@ -925,9 +943,9 @@ def _to_edge_and_lower_llama( # noqa: C901
         logging.info("Generated etrecord.bin")
     else:
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

             canonicalize_program(builder.edge_manager.exported_program())
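
The same refactor applies to `_to_edge_and_lower_llama`: every backend toggle is now an explicit keyword whose default mirrors the old argparse value. For callers still holding an `argparse.Namespace`, a tolerant adapter is straightforward; the sketch below is illustrative only, and the `getattr` fallbacks simply restate the new signature defaults:

```python
# Sketch: bridging an older argparse.Namespace to the new keyword signature.
# getattr() falls back to the signature defaults, so a namespace that
# predates a flag (e.g. one without a `use_qnn_sha` attribute) still works.
backend_kwargs = {
    "qnn": getattr(args, "qnn", False),
    "use_kv_cache": getattr(args, "use_kv_cache", False),
    "pt2e_quantize": getattr(args, "pt2e_quantize", None),
    "use_qnn_sha": getattr(args, "use_qnn_sha", False),
    "num_sharding": getattr(args, "num_sharding", 0),
    "soc_model": getattr(args, "soc_model", "SM8650"),
}
builder = _to_edge_and_lower_llama(
    builder_exported,
    modelname,
    additional_passes,
    pt2e_quant_params,
    quantizers,
    quant_dtype,
    verbose=getattr(args, "verbose", False),
    **backend_kwargs,
)
```
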
@@ -966,7 +984,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            xnnpack_extended_ops=args.xnnpack_extended_ops,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -976,7 +996,23 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            vulkan=args.vulkan,
+            mps=args.mps,
+            coreml=args.coreml,
+            qnn=args.qnn,
+            dtype_override=args.dtype_override,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+            use_kv_cache=args.use_kv_cache,
+            embedding_quantize=args.embedding_quantize,
+            pt2e_quantize=args.pt2e_quantize,
+            coreml_ios=args.coreml_ios,
+            coreml_quantize=args.coreml_quantize,
+            coreml_compute_units=args.coreml_compute_units,
+            use_qnn_sha=args.use_qnn_sha,
+            num_sharding=args.num_sharding,
+            soc_model=args.soc_model,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )

     if args.profile_memory:
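
A practical payoff of threading explicit parameters through instead of `args` is that error paths become unit-testable without argparse. A hypothetical test sketch, assuming the positional order shown in this diff and relying on the XNNPACK path checking `generate_etrecord` before it touches the builder:

```python
import pytest

def test_xnnpack_generate_etrecord_unsupported():
    # generate_etrecord=True raises before pt2e_quantize() is ever called,
    # so no real LLMEdgeManager is needed. Positional order assumed from
    # the diff: builder_exported, modelname, additional_passes,
    # pt2e_quant_params, quantizers, quant_dtype.
    with pytest.raises(NotImplementedError):
        _to_edge_and_lower_llama_xnnpack(
            object(), "llama", [], None, [], None,
            generate_etrecord=True,
        )
```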