@@ -27,7 +27,7 @@
 from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import print_delegation_info

-from executorch.devtools.etrecord import generate_etrecord
+from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
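The import is aliased because this same commit introduces a boolean keyword parameter named `generate_etrecord` on the lowering helpers; without the alias, that parameter would shadow the imported function inside the function bodies. A minimal, self-contained sketch of the pattern, with toy names rather than the real export_llama code:

```python
# Toy sketch of the shadowing problem the alias avoids; names are illustrative.
def _write_record() -> str:
    return "record written"

# Alias, mirroring `import generate_etrecord as generate_etrecord_func` above.
write_record_func = _write_record

def lower(write_record: bool = False) -> str:
    # `write_record` here is the bool flag; the alias still reaches the function.
    return write_record_func() if write_record else "skipped"

print(lower(write_record=True))  # -> record written
```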
@@ -749,7 +749,9 @@ def _to_edge_and_lower_llama_xnnpack(
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    xnnpack_extended_ops: bool = False,
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []

@@ -758,7 +760,7 @@ def _to_edge_and_lower_llama_xnnpack(

         modelname = f"xnnpack_dq_{modelname}"

-    if args.xnnpack_extended_ops:
+    if xnnpack_extended_ops:
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
@@ -769,15 +771,15 @@ def _to_edge_and_lower_llama_xnnpack(
         logging.info(f"--> {partitioner.__class__.__name__}")

     # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-    if args.generate_etrecord:
+    if generate_etrecord:
         raise NotImplementedError(
             "export_llama does not support XNNPack and generating ETRecord at the moment."
         )

     builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
         partitioners
     )
-    if args.verbose:
+    if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)

     return builder.to_executorch(passes=additional_passes)
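Both lowering helpers follow the same refactor: the argparse `Namespace` is no longer threaded through; each helper declares exactly the flags it consumes as typed keyword arguments, and the namespace is unpacked once at the `_export_llama` boundary. A hedged miniature of the pattern, with illustrative names rather than the real export_llama signatures:

```python
# Hypothetical miniature of the refactor: the helper takes explicit typed flags,
# and only the CLI entry point touches the argparse namespace.
import argparse


def lower_xnnpack(modelname: str, *, extended_ops: bool = False, verbose: bool = False) -> str:
    mode = "extended" if extended_ops else "default"
    if verbose:
        print(f"lowering {modelname} with {mode} ops")
    return f"xnnpack_{modelname}"


parser = argparse.ArgumentParser()
parser.add_argument("--xnnpack-extended-ops", action="store_true")
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args(["--xnnpack-extended-ops", "--verbose"])

# The namespace is unpacked once, at the boundary; the helper stays unit-testable
# and type-checkable without any argparse dependency.
lower_xnnpack(
    "llama",
    extended_ops=args.xnnpack_extended_ops,
    verbose=args.verbose,
)
```

The trade-off is a long keyword list at the call site in exchange for helpers whose dependencies are explicit in their signatures.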
@@ -790,52 +792,66 @@ def _to_edge_and_lower_llama(  # noqa: C901
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    vulkan: bool = False,
+    mps: bool = False,
+    coreml: bool = False,
+    qnn: bool = False,
+    dtype_override: str = "fp32",
+    enable_dynamic_shape: bool = True,
+    use_kv_cache: bool = False,
+    embedding_quantize: Optional[str] = None,
+    pt2e_quantize: Optional[str] = None,
+    coreml_ios: int = 15,
+    coreml_quantize: Optional[str] = None,
+    coreml_compute_units: str = "cpu_only",
+    use_qnn_sha: bool = False,
+    num_sharding: int = 0,
+    soc_model: str = "SM8650",
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ):
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()

     # to_backend
     partitioners = []
-    if args.vulkan:
+    if vulkan:
         partitioners.append(
             get_vulkan_partitioner(
-                args.dtype_override,
-                args.enable_dynamic_shape,
+                dtype_override,
+                enable_dynamic_shape,
             )
         )
         modelname = f"vulkan_{modelname}"

         # Need to remove asserts from the graph to prevent graph breaks
         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

-    if args.mps:
-        partitioners.append(get_mps_partitioner(args.use_kv_cache))
+    if mps:
+        partitioners.append(get_mps_partitioner(use_kv_cache))
         modelname = f"mps_{modelname}"

-    if args.coreml:
+    if coreml:
         coreml_partitioner = get_coreml_partitioner(
-            args.coreml_ios,
-            args.embedding_quantize,
-            args.pt2e_quantize,
-            args.coreml_quantize,
-            args.coreml_compute_units,
+            coreml_ios,
+            embedding_quantize,
+            pt2e_quantize,
+            coreml_quantize,
+            coreml_compute_units,
         )
         partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"

-    if args.qnn:
+    if qnn:
         logging.warning(
             "The model definition in current repro is not performant, please refer to the instruction"
             " in https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama/README.md for better performance."
         )
         from executorch.extension.llm.custom_ops import model_sharding

         partitioners.append(
-            get_qnn_partitioner(
-                args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
-            )
+            get_qnn_partitioner(use_kv_cache, pt2e_quantize, num_sharding, soc_model)
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes`
         from executorch.backends.qualcomm._passes import (
@@ -864,7 +880,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         )

         atten = builder_exported_to_edge.model.layers[0].attention
-        if args.use_qnn_sha:
+        if use_qnn_sha:
             cache_shape = torch.Size(
                 (atten.max_batch_size, atten.max_context_len, atten.head_dim)
             )
@@ -887,10 +903,10 @@ def _to_edge_and_lower_llama(  # noqa: C901
         passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
             "get_quant_io_dtype_fn"
         ] = partial(get_custom_quant_ios_dtype, cache_shape)
-        if args.num_sharding > 0:
+        if num_sharding > 0:
             SplitGraph, setting = model_sharding.get_split_graph_pass(
                 builder_exported_to_edge.metadata["get_n_layers"],
-                shares=args.num_sharding,
+                shares=num_sharding,
             )
             passes_job[SplitGraph] = setting
             dep_table[SplitGraph] = [FoldQDQ]
@@ -905,17 +921,17 @@ def _to_edge_and_lower_llama(  # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")

-    if args.generate_etrecord:
+    if generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")

         logging.info("Generating etrecord")
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
         edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`.
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
@@ -927,17 +943,17 @@ def _to_edge_and_lower_llama(  # noqa: C901

         # Generate ETRecord
         if edge_manager_copy:
-            generate_etrecord(
+            generate_etrecord_func(
                 et_record="etrecord.bin",
                 edge_dialect_program=edge_manager_copy,
                 executorch_program=builder.export_program,
             )
             logging.info("Generated etrecord.bin")
     else:
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

             canonicalize_program(builder.edge_manager.exported_program())
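The deep copy in the ETRecord branch exists because the ETRecord wants the pre-delegation edge dialect program, while `to_backend()` transforms the edge manager's program (hence the "memory-wise expensive" comment in the source). A toy sketch of that snapshot-before-lowering idea, using stand-in classes rather than the real ExecuTorch API:

```python
# Illustrative stand-in, not the real EdgeProgramManager: demonstrates snapshotting
# the pre-delegation program so a debug record can reference it after lowering.
import copy


class ToyEdgeManager:
    def __init__(self) -> None:
        self.ops = ["aten.add", "aten.mul"]

    def to_backend(self) -> None:
        # Lowering replaces portions of the graph with delegate calls.
        self.ops = ["delegate.blob"]


edge = ToyEdgeManager()
pre_delegation = copy.deepcopy(edge)  # snapshot kept for the debug record
edge.to_backend()

assert pre_delegation.ops == ["aten.add", "aten.mul"]
assert edge.ops == ["delegate.blob"]
```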
@@ -976,7 +992,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            xnnpack_extended_ops=args.xnnpack_extended_ops,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -986,7 +1004,23 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            vulkan=args.vulkan,
+            mps=args.mps,
+            coreml=args.coreml,
+            qnn=args.qnn,
+            dtype_override=args.dtype_override,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+            use_kv_cache=args.use_kv_cache,
+            embedding_quantize=args.embedding_quantize,
+            pt2e_quantize=args.pt2e_quantize,
+            coreml_ios=args.coreml_ios,
+            coreml_quantize=args.coreml_quantize,
+            coreml_compute_units=args.coreml_compute_units,
+            use_qnn_sha=args.use_qnn_sha,
+            num_sharding=args.num_sharding,
+            soc_model=args.soc_model,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )

     if args.profile_memory: