
Commit 3c6d6e0

refactor: Replace args parameter with individual function parameters in _to_edge_and_lower_llama_xnnpack and _to_edge_and_lower_llama

1 parent: 9e59c19

File tree

1 file changed: +64 / -28 lines

examples/models/llama/export_llama_lib.py

Lines changed: 64 additions & 28 deletions
@@ -739,7 +739,9 @@ def _to_edge_and_lower_llama_xnnpack(
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    xnnpack_extended_ops: bool = False,
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
@@ -748,7 +750,7 @@ def _to_edge_and_lower_llama_xnnpack(
 
     modelname = f"xnnpack_dq_{modelname}"
 
-    if args.xnnpack_extended_ops:
+    if xnnpack_extended_ops:
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
@@ -759,15 +761,15 @@ def _to_edge_and_lower_llama_xnnpack(
         logging.info(f"--> {partitioner.__class__.__name__}")
 
     # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-    if args.generate_etrecord:
+    if generate_etrecord:
         raise NotImplementedError(
             "export_llama does not support XNNPack and generating ETRecord at the moment."
         )
 
     builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
         partitioners
     )
-    if args.verbose:
+    if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
     return builder.to_executorch(passes=additional_passes)
@@ -780,42 +782,58 @@ def _to_edge_and_lower_llama(  # noqa: C901
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    vulkan: bool = False,
+    mps: bool = False,
+    coreml: bool = False,
+    qnn: bool = False,
+    dtype_override: str = "fp32",
+    enable_dynamic_shape: bool = True,
+    use_kv_cache: bool = False,
+    embedding_quantize: Optional[str] = None,
+    pt2e_quantize: Optional[str] = None,
+    coreml_ios: int = 15,
+    coreml_quantize: Optional[str] = None,
+    coreml_compute_units: str = "cpu_only",
+    use_qnn_sha: bool = False,
+    num_sharding: int = 0,
+    soc_model: str = "SM8650",
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ):
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()
 
     # to_backend
     partitioners = []
-    if args.vulkan:
+    if vulkan:
         partitioners.append(
             get_vulkan_partitioner(
-                args.dtype_override,
-                args.enable_dynamic_shape,
+                dtype_override,
+                enable_dynamic_shape,
             )
         )
         modelname = f"vulkan_{modelname}"
 
         # Need to remove asserts from the graph to prevent graph breaks
         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
-    if args.mps:
-        partitioners.append(get_mps_partitioner(args.use_kv_cache))
+    if mps:
+        partitioners.append(get_mps_partitioner(use_kv_cache))
         modelname = f"mps_{modelname}"
 
-    if args.coreml:
+    if coreml:
         coreml_partitioner = get_coreml_partitioner(
-            args.coreml_ios,
-            args.embedding_quantize,
-            args.pt2e_quantize,
-            args.coreml_quantize,
-            args.coreml_compute_units,
+            coreml_ios,
+            embedding_quantize,
+            pt2e_quantize,
+            coreml_quantize,
+            coreml_compute_units,
         )
         partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"
 
-    if args.qnn:
+    if qnn:
         logging.warning(
             "The model definition in current repro is not performant, please refer to the instruction"
             " in https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama/README.md for better performance."
@@ -824,7 +842,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
 
         partitioners.append(
             get_qnn_partitioner(
-                args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
+                use_kv_cache, pt2e_quantize, num_sharding, soc_model
             )
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes`
@@ -854,7 +872,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         )
 
         atten = builder_exported_to_edge.model.layers[0].attention
-        if args.use_qnn_sha:
+        if use_qnn_sha:
             cache_shape = torch.Size(
                 (atten.max_batch_size, atten.max_context_len, atten.head_dim)
             )
@@ -877,10 +895,10 @@ def _to_edge_and_lower_llama(  # noqa: C901
             passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
                 "get_quant_io_dtype_fn"
             ] = partial(get_custom_quant_ios_dtype, cache_shape)
-        if args.num_sharding > 0:
+        if num_sharding > 0:
             SplitGraph, setting = model_sharding.get_split_graph_pass(
                 builder_exported_to_edge.metadata["get_n_layers"],
-                shares=args.num_sharding,
+                shares=num_sharding,
             )
             passes_job[SplitGraph] = setting
             dep_table[SplitGraph] = [FoldQDQ]
@@ -895,17 +913,17 @@ def _to_edge_and_lower_llama(  # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    if args.generate_etrecord:
+    if generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
 
         logging.info("Generating etrecord")
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
         edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`.
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
@@ -925,9 +943,9 @@ def _to_edge_and_lower_llama(  # noqa: C901
         logging.info("Generated etrecord.bin")
     else:
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
             canonicalize_program(builder.edge_manager.exported_program())
@@ -966,7 +984,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            xnnpack_extended_ops=args.xnnpack_extended_ops,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -976,7 +996,23 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            vulkan=args.vulkan,
+            mps=args.mps,
+            coreml=args.coreml,
+            qnn=args.qnn,
+            dtype_override=args.dtype_override,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+            use_kv_cache=args.use_kv_cache,
+            embedding_quantize=args.embedding_quantize,
+            pt2e_quantize=args.pt2e_quantize,
+            coreml_ios=args.coreml_ios,
+            coreml_quantize=args.coreml_quantize,
+            coreml_compute_units=args.coreml_compute_units,
+            use_qnn_sha=args.use_qnn_sha,
+            num_sharding=args.num_sharding,
+            soc_model=args.soc_model,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )
 
     if args.profile_memory:
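Taken together, the lowering helpers no longer reach into an argparse Namespace; _export_llama unpacks args once at the call boundary, and everything below it receives plain keyword parameters with defaults. The following is a minimal, self-contained sketch of that refactoring pattern, not ExecuTorch code: the function lower() and its options are hypothetical stand-ins that mirror the shape of the change.

import argparse
from typing import Optional


# Before the refactor, a helper like this would take the whole Namespace:
#     def lower(model_name, args): ...
# After, every option is an explicit, defaulted keyword parameter, so the
# helper can be called and unit-tested without constructing argparse args.
def lower(
    model_name: str,
    vulkan: bool = False,
    use_kv_cache: bool = False,
    soc_model: str = "SM8650",
    embedding_quantize: Optional[str] = None,
    verbose: bool = False,
) -> str:
    backend = "vulkan" if vulkan else "cpu"
    if verbose:
        print(f"lowering {model_name} for {backend} (soc={soc_model}, kv_cache={use_kv_cache})")
    return f"{backend}_{model_name}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--vulkan", action="store_true")
    parser.add_argument("--use_kv_cache", action="store_true")
    parser.add_argument("--soc_model", default="SM8650")
    parser.add_argument("--embedding_quantize", default=None)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # The Namespace is unpacked exactly once, at the CLI boundary,
    # mirroring what _export_llama does after this commit.
    print(
        lower(
            "llama",
            vulkan=args.vulkan,
            use_kv_cache=args.use_kv_cache,
            soc_model=args.soc_model,
            embedding_quantize=args.embedding_quantize,
            verbose=args.verbose,
        )
    )

Confining the Namespace to the entry point keeps the helper signatures explicit and lets the individual lowering paths be exercised from tests or other Python callers without synthesizing a fake args object.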
