
Commit 4006cd2

Refactor _to_edge_and_lower_llama to remove args
Differential Revision: D73785343
Pull Request resolved: #10520
1 parent 4a738bd
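In short, the commit replaces the catch-all `args` namespace parameter of the two lowering helpers with explicit, typed keyword arguments, and unpacks the argparse namespace once at the `_export_llama` boundary. A minimal sketch of the pattern (hypothetical `lower_before`/`lower_after` stand-ins, not code from this file):

    import argparse

    # Before: helpers took the whole argparse namespace, hiding which flags they read.
    def lower_before(model, args):
        if args.verbose:
            print("delegation info ...")
        return model

    # After: each flag the helper reads is an explicit keyword argument with a default.
    def lower_after(model, verbose: bool = False):
        if verbose:
            print("delegation info ...")
        return model

    # The CLI entry point unpacks the namespace exactly once, at the boundary:
    args = argparse.Namespace(verbose=True)
    lower_after("llama", verbose=args.verbose)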

File tree

1 file changed: +66 -32 lines changed

examples/models/llama/export_llama_lib.py

@@ -27,7 +27,7 @@
 from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import print_delegation_info
 
-from executorch.devtools.etrecord import generate_etrecord
+from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
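The alias appears to exist because `generate_etrecord` also becomes the name of a boolean parameter below, which would shadow the imported function inside the lowering helper (this rationale is inferred from the diff, not stated in the commit message). A minimal reproduction of the collision the alias avoids, with hypothetical names:

    def generate_etrecord(et_record, edge_dialect_program, executorch_program):
        ...  # stand-in for the imported devtools function

    def lower(generate_etrecord: bool = False):
        if generate_etrecord:
            # The bool parameter shadows the module-level function here, so this
            # call would raise: TypeError: 'bool' object is not callable.
            generate_etrecord("etrecord.bin", None, None)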
@@ -749,7 +749,9 @@ def _to_edge_and_lower_llama_xnnpack(
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    xnnpack_extended_ops: bool = False,
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
@@ -758,7 +760,7 @@ def _to_edge_and_lower_llama_xnnpack(
 
     modelname = f"xnnpack_dq_{modelname}"
 
-    if args.xnnpack_extended_ops:
+    if xnnpack_extended_ops:
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
@@ -769,15 +771,15 @@ def _to_edge_and_lower_llama_xnnpack(
         logging.info(f"--> {partitioner.__class__.__name__}")
 
     # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-    if args.generate_etrecord:
+    if generate_etrecord:
         raise NotImplementedError(
             "export_llama does not support XNNPack and generating ETRecord at the moment."
         )
 
     builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
         partitioners
     )
-    if args.verbose:
+    if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
     return builder.to_executorch(passes=additional_passes)
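With the namespace dependency removed, flag combinations for the XNNPACK path can be exercised directly, without constructing an `argparse.Namespace`. A schematic test, using a hypothetical stub that mirrors the control flow above (the real function also needs the exported builder and quantization inputs):

    def lower_xnnpack_stub(xnnpack_extended_ops=False, generate_etrecord=False, verbose=False):
        # Mirrors the refactored flow: XNNPACK + ETRecord remains unsupported.
        if generate_etrecord:
            raise NotImplementedError(
                "export_llama does not support XNNPack and generating ETRecord at the moment."
            )
        return {"extended_ops": xnnpack_extended_ops, "verbose": verbose}

    assert lower_xnnpack_stub(xnnpack_extended_ops=True)["extended_ops"]
    try:
        lower_xnnpack_stub(generate_etrecord=True)
    except NotImplementedError:
        pass  # expected: the ETRecord path is explicitly blocked for XNNPACK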
@@ -790,52 +792,66 @@ def _to_edge_and_lower_llama( # noqa: C901
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    vulkan: bool = False,
+    mps: bool = False,
+    coreml: bool = False,
+    qnn: bool = False,
+    dtype_override: str = "fp32",
+    enable_dynamic_shape: bool = True,
+    use_kv_cache: bool = False,
+    embedding_quantize: Optional[str] = None,
+    pt2e_quantize: Optional[str] = None,
+    coreml_ios: int = 15,
+    coreml_quantize: Optional[str] = None,
+    coreml_compute_units: str = "cpu_only",
+    use_qnn_sha: bool = False,
+    num_sharding: int = 0,
+    soc_model: str = "SM8650",
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ):
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()
 
     # to_backend
     partitioners = []
-    if args.vulkan:
+    if vulkan:
         partitioners.append(
             get_vulkan_partitioner(
-                args.dtype_override,
-                args.enable_dynamic_shape,
+                dtype_override,
+                enable_dynamic_shape,
             )
         )
         modelname = f"vulkan_{modelname}"
 
         # Need to remove asserts from the graph to prevent graph breaks
         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
-    if args.mps:
-        partitioners.append(get_mps_partitioner(args.use_kv_cache))
+    if mps:
+        partitioners.append(get_mps_partitioner(use_kv_cache))
         modelname = f"mps_{modelname}"
 
-    if args.coreml:
+    if coreml:
         coreml_partitioner = get_coreml_partitioner(
-            args.coreml_ios,
-            args.embedding_quantize,
-            args.pt2e_quantize,
-            args.coreml_quantize,
-            args.coreml_compute_units,
+            coreml_ios,
+            embedding_quantize,
+            pt2e_quantize,
+            coreml_quantize,
+            coreml_compute_units,
         )
         partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"
 
-    if args.qnn:
+    if qnn:
         logging.warning(
             "The model definition in current repro is not performant, please refer to the instruction"
             " in https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama/README.md for better performance."
         )
         from executorch.extension.llm.custom_ops import model_sharding
 
         partitioners.append(
-            get_qnn_partitioner(
-                args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
-            )
+            get_qnn_partitioner(use_kv_cache, pt2e_quantize, num_sharding, soc_model)
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes`
         from executorch.backends.qualcomm._passes import (
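Note that each backend flag now independently appends its partitioner, and the defaults (`False` for every backend, `soc_model="SM8650"`, `coreml_ios=15`, and so on) presumably mirror the CLI defaults, so omitting a keyword behaves like omitting the flag on the command line. A condensed, runnable model of that selection logic (string stand-ins instead of real partitioners):

    def select_partitioners(vulkan=False, mps=False, coreml=False, qnn=False):
        partitioners = []  # each flag contributes independently, as above
        if vulkan:
            partitioners.append("VulkanPartitioner")
        if mps:
            partitioners.append("MPSPartitioner")
        if coreml:
            partitioners.append("CoreMLPartitioner")
        if qnn:
            partitioners.append("QnnPartitioner")
        return partitioners

    assert select_partitioners(coreml=True, qnn=True) == ["CoreMLPartitioner", "QnnPartitioner"]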
@@ -864,7 +880,7 @@ def _to_edge_and_lower_llama( # noqa: C901
         )
 
         atten = builder_exported_to_edge.model.layers[0].attention
-        if args.use_qnn_sha:
+        if use_qnn_sha:
             cache_shape = torch.Size(
                 (atten.max_batch_size, atten.max_context_len, atten.head_dim)
             )
@@ -887,10 +903,10 @@ def _to_edge_and_lower_llama( # noqa: C901
         passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
             "get_quant_io_dtype_fn"
         ] = partial(get_custom_quant_ios_dtype, cache_shape)
-        if args.num_sharding > 0:
+        if num_sharding > 0:
             SplitGraph, setting = model_sharding.get_split_graph_pass(
                 builder_exported_to_edge.metadata["get_n_layers"],
-                shares=args.num_sharding,
+                shares=num_sharding,
             )
             passes_job[SplitGraph] = setting
             dep_table[SplitGraph] = [FoldQDQ]
@@ -905,17 +921,17 @@ def _to_edge_and_lower_llama( # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    if args.generate_etrecord:
+    if generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
 
         logging.info("Generating etrecord")
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
         edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`.
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
@@ -927,17 +943,17 @@ def _to_edge_and_lower_llama( # noqa: C901
 
         # Generate ETRecord
         if edge_manager_copy:
-            generate_etrecord(
+            generate_etrecord_func(
                 et_record="etrecord.bin",
                 edge_dialect_program=edge_manager_copy,
                 executorch_program=builder.export_program,
             )
             logging.info("Generated etrecord.bin")
     else:
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
             canonicalize_program(builder.edge_manager.exported_program())
@@ -976,7 +992,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            xnnpack_extended_ops=args.xnnpack_extended_ops,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -986,7 +1004,23 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            vulkan=args.vulkan,
+            mps=args.mps,
+            coreml=args.coreml,
+            qnn=args.qnn,
+            dtype_override=args.dtype_override,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+            use_kv_cache=args.use_kv_cache,
+            embedding_quantize=args.embedding_quantize,
+            pt2e_quantize=args.pt2e_quantize,
+            coreml_ios=args.coreml_ios,
+            coreml_quantize=args.coreml_quantize,
+            coreml_compute_units=args.coreml_compute_units,
+            use_qnn_sha=args.use_qnn_sha,
+            num_sharding=args.num_sharding,
+            soc_model=args.soc_model,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )
 
     if args.profile_memory:
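A note on the forwarding style: the call sites spell out every keyword, which keeps the CLI-flag-to-parameter mapping grep-able at the cost of repetition. Because the parameter names match the `args` attribute names one-to-one, the same forwarding could be written generically; a runnable sketch of that alternative (not what this commit does, shown only for contrast):

    import argparse

    def lower(generate_etrecord=False, verbose=False):  # stand-in for the real helper
        return (generate_etrecord, verbose)

    args = argparse.Namespace(generate_etrecord=True, verbose=False)
    flags = ("generate_etrecord", "verbose")  # the real call forwards all seventeen names
    assert lower(**{name: getattr(args, name) for name in flags}) == (True, False)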
