Merged
Changes from 17 commits

Commits (23)
f69655c  Use llm_config instead of args in export_llama functions (jackzhxng, May 27, 2025)
d9c70c2  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 27, 2025)
209fd7f  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 27, 2025)
45571eb  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 28, 2025)
97ec69c  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 28, 2025)
b928cc7  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 28, 2025)
b08f22b  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 28, 2025)
00aa0e8  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, May 29, 2025)
20bdaa6  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 2, 2025)
900bbdf  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 3, 2025)
a14f548  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 3, 2025)
d7d33d7  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 3, 2025)
4a875d8  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 4, 2025)
6f6bf53  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 5, 2025)
4760311  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 5, 2025)
1a85097  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 5, 2025)
792022d  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 5, 2025)
54477dc  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 6, 2025)
6f3e0a5  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 6, 2025)
c447cbd  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 6, 2025)
679fe9e  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 9, 2025)
52455bc  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 9, 2025)
9a15088  Update on "Use llm_config instead of args in export_llama functions" (jackzhxng, Jun 9, 2025)
6 changes: 5 additions & 1 deletion backends/arm/test/models/test_llama.py
@@ -22,6 +22,9 @@
TosaPipelineMI,
)

from executorch.examples.models.llama.config.llm_config_utils import (
convert_args_to_llm_config,
)
from executorch.examples.models.llama.export_llama_lib import (
build_args_parser,
get_llama_model,
@@ -89,8 +92,9 @@ def prepare_model(self):
]
parser = build_args_parser()
args = parser.parse_args(args)
llm_config = convert_args_to_llm_config(args)

Reviewer comment (Contributor): Like my comment in the previous PR, please do LlmConfig.from_args(args).

llama_model, llama_inputs, llama_meta = get_llama_model(args)
llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)

return llama_model, llama_inputs, llama_meta

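The reviewer's suggested LlmConfig.from_args(args) is not defined anywhere in this diff; below is a minimal sketch of what such a constructor could delegate to, assuming it simply wraps the convert_args_to_llm_config helper used above. Both the from_args name and the standalone stand-in function are hypothetical.

```python
# Hypothetical sketch only: a stand-in for the reviewer-suggested
# LlmConfig.from_args(args) classmethod, which does not exist in this diff.
import argparse

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.config.llm_config_utils import (
    convert_args_to_llm_config,
)


def from_args(args: argparse.Namespace) -> LlmConfig:
    """What a LlmConfig.from_args classmethod could forward to."""
    return convert_args_to_llm_config(args)


# The call site in prepare_model() would then read:
#   llm_config = LlmConfig.from_args(args)
#   llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
```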
29 changes: 13 additions & 16 deletions examples/apple/mps/scripts/mps_example.py
@@ -20,6 +20,7 @@
serialize_from_bundled_program_to_flatbuffer,
)

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.exir import (
EdgeCompileConfig,
EdgeProgramManager,
@@ -131,28 +132,24 @@ def parse_args():
return args


def get_model_config(args):
model_config = {}
model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0]
model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1]

if args.model_name == "llama2":
if args.checkpoint:
model_config["checkpoint"] = args.checkpoint
if args.params:
model_config["params"] = args.params
model_config["use_kv_cache"] = True
return model_config


if __name__ == "__main__":
args = parse_args()

if args.model_name not in MODEL_NAME_TO_MODEL:
raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")

model_config = get_model_config(args)
model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)
llm_config = LlmConfig()
if args.model_name == "llama2":
if args.checkpoint:
llm_config.base.checkpoint = args.checkpoint
if args.params:
llm_config.base.params = args.params
llm_config.model.use_kv_cache = True
model, example_inputs, _, _ = EagerModelFactory.create_model(
module_name=MODEL_NAME_TO_MODEL[args.model_name][0],
model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1],
llm_config=llm_config,
)

model = model.eval()

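Pulled together outside the diff, the rewritten mps_example.py flow amounts to: build a default LlmConfig, override the llama2 checkpoint/params fields, and hand the config to EagerModelFactory.create_model. A sketch with placeholder paths; the import locations are assumed from the surrounding file rather than shown in this hunk.

```python
# Sketch of the LlmConfig-driven flow introduced above. Checkpoint/params
# paths are placeholders, and the import paths are assumptions.
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.model_factory import EagerModelFactory

model_name = "llama2"
module_name, model_class_name = MODEL_NAME_TO_MODEL[model_name]

llm_config = LlmConfig()
llm_config.base.checkpoint = "/path/to/consolidated.00.pth"  # placeholder
llm_config.base.params = "/path/to/params.json"  # placeholder
llm_config.model.use_kv_cache = True

model, example_inputs, _, _ = EagerModelFactory.create_model(
    module_name=module_name,
    model_class_name=model_class_name,
    llm_config=llm_config,
)
model = model.eval()
```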
2 changes: 2 additions & 0 deletions examples/models/llama/TARGETS
@@ -67,6 +67,7 @@ runtime.python_library(
"//caffe2:torch",
"//executorch/examples/models:model_base",
"//executorch/examples/models/llama:llama_transformer",
"//executorch/examples/models/llama/config:llm_config",
"//executorch/examples/models:checkpoint",
],
)
@@ -266,6 +267,7 @@ runtime.python_library(
":export_library",
"//executorch/examples/models/llama/config:llm_config",
"fbsource//third-party/pypi/hydra-core:hydra-core",
"fbsource//third-party/pypi/omegaconf:omegaconf",
],
)

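The new hydra-core and omegaconf dependencies suggest the llm_config library is meant to be consumed as a structured config. A hedged sketch of one way omegaconf could layer overrides onto an LlmConfig, assuming LlmConfig is a dataclass-style config; the actual hydra/omegaconf wiring lives outside this diff.

```python
# Hedged sketch: applying dotlist overrides on top of a structured LlmConfig
# with omegaconf. Assumes LlmConfig is a dataclass-compatible config class;
# the real hydra/omegaconf wiring is not part of this diff.
from omegaconf import OmegaConf

from executorch.examples.models.llama.config.llm_config import LlmConfig

base_cfg = OmegaConf.structured(LlmConfig())
overrides = OmegaConf.from_dotlist(
    ["model.use_kv_cache=True", "export.max_seq_length=256"]
)
merged = OmegaConf.merge(base_cfg, overrides)

# Back to a typed LlmConfig instance.
llm_config: LlmConfig = OmegaConf.to_object(merged)
```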
173 changes: 112 additions & 61 deletions examples/models/llama/config/llm_config_utils.py
@@ -26,92 +26,143 @@ def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
llm_config = LlmConfig()

# BaseConfig
llm_config.base.model_class = ModelType(args.model)
llm_config.base.params = args.params
llm_config.base.checkpoint = args.checkpoint
llm_config.base.checkpoint_dir = args.checkpoint_dir
llm_config.base.tokenizer_path = args.tokenizer_path
llm_config.base.metadata = args.metadata
llm_config.base.use_lora = bool(args.use_lora)
llm_config.base.fairseq2 = args.fairseq2
if hasattr(args, "model"):
llm_config.base.model_class = ModelType(args.model)
if hasattr(args, "params"):
llm_config.base.params = args.params
if hasattr(args, "checkpoint"):
llm_config.base.checkpoint = args.checkpoint
if hasattr(args, "checkpoint_dir"):
llm_config.base.checkpoint_dir = args.checkpoint_dir
if hasattr(args, "tokenizer_path"):
llm_config.base.tokenizer_path = args.tokenizer_path
if hasattr(args, "metadata"):
llm_config.base.metadata = args.metadata
if hasattr(args, "use_lora"):
llm_config.base.use_lora = args.use_lora
if hasattr(args, "fairseq2"):
llm_config.base.fairseq2 = args.fairseq2

# PreqMode settings
if args.preq_mode:
if hasattr(args, "preq_mode") and args.preq_mode:
llm_config.base.preq_mode = PreqMode(args.preq_mode)
llm_config.base.preq_group_size = args.preq_group_size
llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
if hasattr(args, "preq_group_size"):
llm_config.base.preq_group_size = args.preq_group_size
if hasattr(args, "preq_embedding_quantize"):
llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize

# ModelConfig
llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
llm_config.model.use_shared_embedding = args.use_shared_embedding
llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
llm_config.model.expand_rope_table = args.expand_rope_table
llm_config.model.use_attention_sink = args.use_attention_sink
llm_config.model.output_prune_map = args.output_prune_map
llm_config.model.input_prune_map = args.input_prune_map
llm_config.model.use_kv_cache = args.use_kv_cache
llm_config.model.quantize_kv_cache = args.quantize_kv_cache
llm_config.model.local_global_attention = args.local_global_attention
if hasattr(args, "dtype_override"):
llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
if hasattr(args, "enable_dynamic_shape"):
llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
if hasattr(args, "use_shared_embedding"):
llm_config.model.use_shared_embedding = args.use_shared_embedding
if hasattr(args, "use_sdpa_with_kv_cache"):
llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
if hasattr(args, "expand_rope_table"):
llm_config.model.expand_rope_table = args.expand_rope_table
if hasattr(args, "use_attention_sink"):
llm_config.model.use_attention_sink = args.use_attention_sink
if hasattr(args, "output_prune_map"):
llm_config.model.output_prune_map = args.output_prune_map
if hasattr(args, "input_prune_map"):
llm_config.model.input_prune_map = args.input_prune_map
if hasattr(args, "use_kv_cache"):
llm_config.model.use_kv_cache = args.use_kv_cache
if hasattr(args, "quantize_kv_cache"):
llm_config.model.quantize_kv_cache = args.quantize_kv_cache
if hasattr(args, "local_global_attention"):
llm_config.model.local_global_attention = args.local_global_attention

# ExportConfig
llm_config.export.max_seq_length = args.max_seq_length
llm_config.export.max_context_length = args.max_context_length
llm_config.export.output_dir = args.output_dir
llm_config.export.output_name = args.output_name
llm_config.export.so_library = args.so_library
llm_config.export.export_only = args.export_only
if hasattr(args, "max_seq_length"):
llm_config.export.max_seq_length = args.max_seq_length
if hasattr(args, "max_context_length"):
llm_config.export.max_context_length = args.max_context_length
if hasattr(args, "output_dir"):
llm_config.export.output_dir = args.output_dir
if hasattr(args, "output_name"):
llm_config.export.output_name = args.output_name
if hasattr(args, "so_library"):
llm_config.export.so_library = args.so_library
if hasattr(args, "export_only"):
llm_config.export.export_only = args.export_only

# QuantizationConfig
llm_config.quantization.qmode = args.quantization_mode
llm_config.quantization.embedding_quantize = args.embedding_quantize
if args.pt2e_quantize:
if hasattr(args, "quantization_mode"):
llm_config.quantization.qmode = args.quantization_mode
if hasattr(args, "embedding_quantize"):
llm_config.quantization.embedding_quantize = args.embedding_quantize
if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
llm_config.quantization.group_size = args.group_size
if args.use_spin_quant:
if hasattr(args, "group_size"):
llm_config.quantization.group_size = args.group_size
if hasattr(args, "use_spin_quant") and args.use_spin_quant:
llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
llm_config.quantization.use_qat = args.use_qat
llm_config.quantization.calibration_tasks = args.calibration_tasks
llm_config.quantization.calibration_limit = args.calibration_limit
llm_config.quantization.calibration_seq_length = args.calibration_seq_length
llm_config.quantization.calibration_data = args.calibration_data

# BackendConfig
# XNNPack
llm_config.backend.xnnpack.enabled = args.xnnpack
llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
if hasattr(args, "use_qat"):
llm_config.quantization.use_qat = args.use_qat
if hasattr(args, "calibration_tasks"):
llm_config.quantization.calibration_tasks = args.calibration_tasks
if hasattr(args, "calibration_limit"):
llm_config.quantization.calibration_limit = args.calibration_limit
if hasattr(args, "calibration_seq_length"):
llm_config.quantization.calibration_seq_length = args.calibration_seq_length
if hasattr(args, "calibration_data"):
llm_config.quantization.calibration_data = args.calibration_data

# BackendConfig - XNNPack
if hasattr(args, "xnnpack"):
llm_config.backend.xnnpack.enabled = args.xnnpack
if hasattr(args, "xnnpack_extended_ops"):
llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops

# CoreML
llm_config.backend.coreml.enabled = args.coreml
if hasattr(args, "coreml"):
llm_config.backend.coreml.enabled = args.coreml
llm_config.backend.coreml.enable_state = getattr(args, "coreml_enable_state", False)
llm_config.backend.coreml.preserve_sdpa = getattr(
args, "coreml_preserve_sdpa", False
)
if args.coreml_quantize:
if hasattr(args, "coreml_quantize") and args.coreml_quantize:
llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
llm_config.backend.coreml.ios = args.coreml_ios
llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
args.coreml_compute_units
)
if hasattr(args, "coreml_ios"):
llm_config.backend.coreml.ios = args.coreml_ios
if hasattr(args, "coreml_compute_units"):
llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
args.coreml_compute_units
)

# Vulkan
llm_config.backend.vulkan.enabled = args.vulkan
if hasattr(args, "vulkan"):
llm_config.backend.vulkan.enabled = args.vulkan

# QNN
llm_config.backend.qnn.enabled = args.qnn
llm_config.backend.qnn.use_sha = args.use_qnn_sha
llm_config.backend.qnn.soc_model = args.soc_model
llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
llm_config.backend.qnn.num_sharding = args.num_sharding
if hasattr(args, "qnn"):
llm_config.backend.qnn.enabled = args.qnn
if hasattr(args, "use_qnn_sha"):
llm_config.backend.qnn.use_sha = args.use_qnn_sha
if hasattr(args, "soc_model"):
llm_config.backend.qnn.soc_model = args.soc_model
if hasattr(args, "optimized_rotation_path"):
llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
if hasattr(args, "num_sharding"):
llm_config.backend.qnn.num_sharding = args.num_sharding

# MPS
llm_config.backend.mps.enabled = args.mps
if hasattr(args, "mps"):
llm_config.backend.mps.enabled = args.mps

# DebugConfig
llm_config.debug.profile_memory = args.profile_memory
llm_config.debug.profile_path = args.profile_path
llm_config.debug.generate_etrecord = args.generate_etrecord
llm_config.debug.generate_full_logits = args.generate_full_logits
llm_config.debug.verbose = args.verbose
if hasattr(args, "profile_memory"):
llm_config.debug.profile_memory = args.profile_memory
if hasattr(args, "profile_path"):
llm_config.debug.profile_path = args.profile_path
if hasattr(args, "generate_etrecord"):
llm_config.debug.generate_etrecord = args.generate_etrecord
if hasattr(args, "generate_full_logits"):
llm_config.debug.generate_full_logits = args.generate_full_logits
if hasattr(args, "verbose"):
llm_config.debug.verbose = args.verbose

return llm_config
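The hasattr guards above make the converter tolerant of partial argparse namespaces, such as the MPS example's parser, which defines only a handful of the export_llama flags. A small usage sketch; the two flags and the checkpoint path below are illustrative, and the asserted behavior follows directly from the guards.

```python
# Usage sketch: convert_args_to_llm_config with a partial namespace. Only the
# attributes present on args are copied; everything else keeps its LlmConfig
# default because of the hasattr() guards.
import argparse

from executorch.examples.models.llama.config.llm_config_utils import (
    convert_args_to_llm_config,
)

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, default=None)
parser.add_argument("--use_kv_cache", action="store_true")

# "/path/to/ckpt.pth" is a placeholder, not a real checkpoint.
args = parser.parse_args(["--checkpoint", "/path/to/ckpt.pth", "--use_kv_cache"])

llm_config = convert_args_to_llm_config(args)
assert llm_config.base.checkpoint == "/path/to/ckpt.pth"
assert llm_config.model.use_kv_cache is True
# Fields whose flags are absent (e.g. args.model) are simply left at defaults.
```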