 from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import print_delegation_info

-from executorch.devtools.etrecord import generate_etrecord
+from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
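Note on the alias: the bare name `generate_etrecord` would otherwise be shadowed inside the lowering helpers, which (per the hunks below) now take a boolean parameter of the same name. A minimal sketch of the collision the rename avoids; the `lower` function here is illustrative, not part of the diff:

```python
from executorch.devtools.etrecord import generate_etrecord

def lower(generate_etrecord: bool = False) -> None:
    if generate_etrecord:
        # Inside this function the bool parameter shadows the imported
        # function, so this call raises TypeError: 'bool' object is not
        # callable. Importing under an alias sidesteps the clash.
        generate_etrecord(et_record="etrecord.bin")
```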
@@ -749,7 +749,9 @@ def _to_edge_and_lower_llama_xnnpack(
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    xnnpack_extended_ops: bool = False,
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []

@@ -758,7 +760,7 @@ def _to_edge_and_lower_llama_xnnpack(

     modelname = f"xnnpack_dq_{modelname}"

-    if args.xnnpack_extended_ops:
+    if xnnpack_extended_ops:
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
@@ -769,15 +771,15 @@ def _to_edge_and_lower_llama_xnnpack(
         logging.info(f"--> {partitioner.__class__.__name__}")

     # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-    if args.generate_etrecord:
+    if generate_etrecord:
         raise NotImplementedError(
             "export_llama does not support XNNPack and generating ETRecord at the moment."
         )

     builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
         partitioners
     )
-    if args.verbose:
+    if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)

     return builder.to_executorch(passes=additional_passes)
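With `args` gone from the signature, the XNNPACK helper can be driven directly from Python. A hypothetical call site, assuming the six positional arguments come out of the earlier export and quantizer setup exactly as they do in `_export_llama`:

```python
# Hypothetical call; keyword values are illustrative, not defaults.
builder = _to_edge_and_lower_llama_xnnpack(
    builder_exported,
    modelname,
    additional_passes,
    pt2e_quant_params,
    quantizers,
    quant_dtype,
    xnnpack_extended_ops=True,  # also partition non-dynamic-quant ops
    generate_etrecord=False,    # still unsupported on this path (see above)
    verbose=True,               # print per-partitioner delegation info
)
```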
@@ -790,52 +792,66 @@ def _to_edge_and_lower_llama( # noqa: C901
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    vulkan: bool = False,
+    mps: bool = False,
+    coreml: bool = False,
+    qnn: bool = False,
+    dtype_override: str = "fp32",
+    enable_dynamic_shape: bool = True,
+    use_kv_cache: bool = False,
+    embedding_quantize: Optional[str] = None,
+    pt2e_quantize: Optional[str] = None,
+    coreml_ios: int = 15,
+    coreml_quantize: Optional[str] = None,
+    coreml_compute_units: str = "cpu_only",
+    use_qnn_sha: bool = False,
+    num_sharding: int = 0,
+    soc_model: str = "SM8650",
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ):
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()

     # to_backend
     partitioners = []
-    if args.vulkan:
+    if vulkan:
         partitioners.append(
             get_vulkan_partitioner(
-                args.dtype_override,
-                args.enable_dynamic_shape,
+                dtype_override,
+                enable_dynamic_shape,
             )
         )
         modelname = f"vulkan_{modelname}"

         # Need to remove asserts from the graph to prevent graph breaks
         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

-    if args.mps:
-        partitioners.append(get_mps_partitioner(args.use_kv_cache))
+    if mps:
+        partitioners.append(get_mps_partitioner(use_kv_cache))
         modelname = f"mps_{modelname}"

-    if args.coreml:
+    if coreml:
         coreml_partitioner = get_coreml_partitioner(
-            args.coreml_ios,
-            args.embedding_quantize,
-            args.pt2e_quantize,
-            args.coreml_quantize,
-            args.coreml_compute_units,
+            coreml_ios,
+            embedding_quantize,
+            pt2e_quantize,
+            coreml_quantize,
+            coreml_compute_units,
         )
         partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"

-    if args.qnn:
+    if qnn:
         logging.warning(
             "The model definition in current repro is not performant, please refer to the instruction"
             " in https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama/README.md for better performance."
         )
         from executorch.extension.llm.custom_ops import model_sharding

         partitioners.append(
-            get_qnn_partitioner(
-                args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
-            )
+            get_qnn_partitioner(use_kv_cache, pt2e_quantize, num_sharding, soc_model)
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes`
         from executorch.backends.qualcomm._passes import (
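The same de-`args`-ification applies here, with one explicit, defaulted keyword per backend option. One practical consequence, sketched below with illustrative values, is that programmatic callers no longer have to fabricate an argparse-style namespace just to reach a single backend:

```python
# Before the change, reusing the helper outside the CLI meant building a
# stand-in for argparse's Namespace carrying ~17 attributes:
from types import SimpleNamespace
args = SimpleNamespace(vulkan=False, mps=False, coreml=True, qnn=False)  # ...and so on

# After the change, unused options fall back to their declared defaults
# (keyword values here are illustrative):
builder = _to_edge_and_lower_llama(
    builder_exported,
    modelname,
    additional_passes,
    pt2e_quant_params,
    quantizers,
    quant_dtype,
    coreml=True,
    coreml_ios=17,
)
```

The explicit keyword list also gives type checkers and grep something to see, which the opaque `args` object never did.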
@@ -864,7 +880,7 @@ def _to_edge_and_lower_llama( # noqa: C901
         )

         atten = builder_exported_to_edge.model.layers[0].attention
-        if args.use_qnn_sha:
+        if use_qnn_sha:
             cache_shape = torch.Size(
                 (atten.max_batch_size, atten.max_context_len, atten.head_dim)
             )
@@ -887,10 +903,10 @@ def _to_edge_and_lower_llama( # noqa: C901
         passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
             "get_quant_io_dtype_fn"
         ] = partial(get_custom_quant_ios_dtype, cache_shape)
-        if args.num_sharding > 0:
+        if num_sharding > 0:
             SplitGraph, setting = model_sharding.get_split_graph_pass(
                 builder_exported_to_edge.metadata["get_n_layers"],
-                shares=args.num_sharding,
+                shares=num_sharding,
             )
             passes_job[SplitGraph] = setting
             dep_table[SplitGraph] = [FoldQDQ]
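A side note on the `partial(get_custom_quant_ios_dtype, cache_shape)` line kept as context above: it pre-binds the KV-cache shape so the QNN pass machinery can later invoke the callback with only its remaining arguments. A self-contained sketch of that binding pattern, with a toy stand-in for the real dtype function:

```python
from functools import partial

def get_quant_io_dtype(cache_shape, node_name):
    # Toy stand-in for get_custom_quant_ios_dtype: pick an IO dtype
    # based on the KV-cache shape and the node being tagged.
    return "uint8" if node_name.startswith("kv_") else None

# Bind the shape now; the pass calls the result with one argument later.
get_quant_io_dtype_fn = partial(get_quant_io_dtype, (1, 2048, 64))
print(get_quant_io_dtype_fn("kv_cache_k"))  # -> uint8
```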
@@ -905,17 +921,17 @@ def _to_edge_and_lower_llama( # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")

-    if args.generate_etrecord:
+    if generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")

         logging.info("Generating etrecord")
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
         edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`.
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

@@ -927,17 +943,17 @@ def _to_edge_and_lower_llama( # noqa: C901

         # Generate ETRecord
         if edge_manager_copy:
-            generate_etrecord(
+            generate_etrecord_func(
                 et_record="etrecord.bin",
                 edge_dialect_program=edge_manager_copy,
                 executorch_program=builder.export_program,
             )
             logging.info("Generated etrecord.bin")
     else:
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

             canonicalize_program(builder.edge_manager.exported_program())
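The ordering in the ETRecord branch matters: the edge program is deep-copied before `to_backend()` because delegation rewrites the manager in place, while ETRecord needs the pre-delegation edge dialect program to correlate against. Stripped to its essentials (names follow the diff; the surrounding setup is assumed):

```python
import copy

# Snapshot the graph before delegation (memory-expensive, per the comment above).
edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
builder = builder_exported_to_edge.to_backend(partitioners)  # lowers in place
builder = builder.to_executorch(passes=additional_passes)

generate_etrecord_func(
    et_record="etrecord.bin",
    edge_dialect_program=edge_manager_copy,     # pre-delegation graph
    executorch_program=builder.export_program,  # final lowered program
)
```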
@@ -976,7 +992,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            xnnpack_extended_ops=args.xnnpack_extended_ops,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -986,7 +1004,23 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            vulkan=args.vulkan,
+            mps=args.mps,
+            coreml=args.coreml,
+            qnn=args.qnn,
+            dtype_override=args.dtype_override,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+            use_kv_cache=args.use_kv_cache,
+            embedding_quantize=args.embedding_quantize,
+            pt2e_quantize=args.pt2e_quantize,
+            coreml_ios=args.coreml_ios,
+            coreml_quantize=args.coreml_quantize,
+            coreml_compute_units=args.coreml_compute_units,
+            use_qnn_sha=args.use_qnn_sha,
+            num_sharding=args.num_sharding,
+            soc_model=args.soc_model,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )

     if args.profile_memory:
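After this refactor, `_export_llama` is the only layer that still touches the argparse Namespace; everything beneath it takes plain keywords. If the explicit forwarding lists above ever become a maintenance burden, one hypothetical alternative (not part of this commit) is a small allowlist adapter:

```python
# Hypothetical helper, not in this commit: forward a fixed allowlist of
# argparse attributes as keyword arguments.
_LOWER_ARG_NAMES = (
    "vulkan", "mps", "coreml", "qnn", "use_kv_cache",
    "num_sharding", "generate_etrecord", "verbose",
)

def _lowering_kwargs(args):
    return {name: getattr(args, name) for name in _LOWER_ARG_NAMES}

# builder = _to_edge_and_lower_llama(..., **_lowering_kwargs(args))
```

The commit's keyword-by-keyword forwarding trades that brevity for greppability and type-checker visibility, which is arguably the better default.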