Commit 84e1fda

Convert args to LlmConfig
Pull Request resolved: #11081
@imported-using-ghimport
Differential Revision: [D75263990](https://our.internmc.facebook.com/intern/diff/D75263990/)
ghstack-source-id: 288486591
1 parent 11ed407 commit 84e1fda

1 file changed: 98 additions, 3 deletions
@@ -1,13 +1,21 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 import argparse

-from executorch.examples.models.llama.config.llm_config import LlmConfig
+from executorch.examples.models.llama.config.llm_config import (
+    CoreMLComputeUnit,
+    CoreMLQuantize,
+    DtypeOverride,
+    LlmConfig,
+    ModelType,
+    PreqMode,
+    Pt2eQuantize,
+    SpinQuant,
+)


 def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
@@ -17,6 +25,93 @@ def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
     """
     llm_config = LlmConfig()

-    # TODO: conversion code.
+    # BaseConfig
+    llm_config.base.model_class = ModelType(args.model)
+    llm_config.base.params = args.params
+    llm_config.base.checkpoint = args.checkpoint
+    llm_config.base.checkpoint_dir = args.checkpoint_dir
+    llm_config.base.tokenizer_path = args.tokenizer_path
+    llm_config.base.metadata = args.metadata
+    llm_config.base.use_lora = bool(args.use_lora)
+    llm_config.base.fairseq2 = args.fairseq2
+
+    # PreqMode settings
+    if args.preq_mode:
+        llm_config.base.preq_mode = PreqMode(args.preq_mode)
+        llm_config.base.preq_group_size = args.preq_group_size
+        llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
+
+    # ModelConfig
+    llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+    llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
+    llm_config.model.use_shared_embedding = args.use_shared_embedding
+    llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
+    llm_config.model.expand_rope_table = args.expand_rope_table
+    llm_config.model.use_attention_sink = args.use_attention_sink
+    llm_config.model.output_prune_map = args.output_prune_map
+    llm_config.model.input_prune_map = args.input_prune_map
+    llm_config.model.use_kv_cache = args.use_kv_cache
+    llm_config.model.quantize_kv_cache = args.quantize_kv_cache
+    llm_config.model.local_global_attention = args.local_global_attention
+
+    # ExportConfig
+    llm_config.export.max_seq_length = args.max_seq_length
+    llm_config.export.max_context_length = args.max_context_length
+    llm_config.export.output_dir = args.output_dir
+    llm_config.export.output_name = args.output_name
+    llm_config.export.so_library = args.so_library
+    llm_config.export.export_only = args.export_only
+
+    # QuantizationConfig
+    llm_config.quantization.qmode = args.quantization_mode
+    llm_config.quantization.embedding_quantize = args.embedding_quantize
+    if args.pt2e_quantize:
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
+    llm_config.quantization.group_size = args.group_size
+    if args.use_spin_quant:
+        llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
+    llm_config.quantization.use_qat = args.use_qat
+    llm_config.quantization.calibration_tasks = args.calibration_tasks
+    llm_config.quantization.calibration_limit = args.calibration_limit
+    llm_config.quantization.calibration_seq_length = args.calibration_seq_length
+    llm_config.quantization.calibration_data = args.calibration_data
+
+    # BackendConfig
+    # XNNPack
+    llm_config.backend.xnnpack.enabled = args.xnnpack
+    llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
+
+    # CoreML
+    llm_config.backend.coreml.enabled = args.coreml
+    llm_config.backend.coreml.enable_state = getattr(args, "coreml_enable_state", False)
+    llm_config.backend.coreml.preserve_sdpa = getattr(
+        args, "coreml_preserve_sdpa", False
+    )
+    if args.coreml_quantize:
+        llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
+    llm_config.backend.coreml.ios = args.coreml_ios
+    llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+        args.coreml_compute_units
+    )
+
+    # Vulkan
+    llm_config.backend.vulkan.enabled = args.vulkan
+
+    # QNN
+    llm_config.backend.qnn.enabled = args.qnn
+    llm_config.backend.qnn.use_sha = args.use_qnn_sha
+    llm_config.backend.qnn.soc_model = args.soc_model
+    llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
+    llm_config.backend.qnn.num_sharding = args.num_sharding
+
+    # MPS
+    llm_config.backend.mps.enabled = args.mps
+
+    # DebugConfig
+    llm_config.debug.profile_memory = args.profile_memory
+    llm_config.debug.profile_path = args.profile_path
+    llm_config.debug.generate_etrecord = args.generate_etrecord
+    llm_config.debug.generate_full_logits = args.generate_full_logits
+    llm_config.debug.verbose = args.verbose

     return llm_config
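Two small idioms in the new code are worth calling out: optional string flags are validated by round-tripping them through the matching enum constructor, and Core ML flags that may be missing from an older argparse namespace are read with getattr and a default. A minimal, self-contained sketch of both idioms, using a stand-in enum (the real PreqMode and its member values live in executorch.examples.models.llama.config.llm_config and may differ):

import argparse
from enum import Enum


class PreqModeStandIn(str, Enum):
    # Stand-in for the real PreqMode enum; member values here are illustrative.
    PREQ_8DA4W = "8da4w"


# A namespace as argparse would produce it; the flag value is illustrative.
args = argparse.Namespace(preq_mode="8da4w")

# Idiom 1: convert only when the flag was given; the enum constructor
# raises ValueError for values the config schema does not recognize.
if args.preq_mode:
    mode = PreqModeStandIn(args.preq_mode)

# Idiom 2: tolerate namespaces that never defined an attribute, as the
# commit does for coreml_enable_state and coreml_preserve_sdpa.
enable_state = getattr(args, "coreml_enable_state", False)

print(mode.value, enable_state)  # -> 8da4w False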

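To see what the converter consumes end to end, here is a toy slice of the CLI surface; the real parser defines many more flags, and the flag spellings below are assumed from the attribute names in the diff (args.model, args.max_seq_length), not taken from this commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="llama2")  # feeds ModelType(args.model)
parser.add_argument("--max_seq_length", type=int, default=128)

args = parser.parse_args(["--model", "llama2", "--max_seq_length", "256"])
# This flat namespace is what convert_args_to_llm_config(args) fans out
# into the nested base/model/export/quantization/backend/debug sections.
print(args.model, args.max_seq_length)  # -> llama2 256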