77
88import argparse
99
10- from executorch .examples .models .llama .config .llm_config import LlmConfig
10+ from executorch .examples .models .llama .config .llm_config import (
11+ CoreMLComputeUnit ,
12+ CoreMLQuantize ,
13+ DtypeOverride ,
14+ LlmConfig ,
15+ ModelType ,
16+ PreqMode ,
17+ Pt2eQuantize ,
18+ SpinQuant ,
19+ )
1120
1221
1322def convert_args_to_llm_config (args : argparse .Namespace ) -> LlmConfig :
@@ -17,6 +26,93 @@ def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
1726 """
1827 llm_config = LlmConfig ()
1928
20- # TODO: conversion code.
29+ # BaseConfig
30+ llm_config .base .model_class = ModelType (args .model )
31+ llm_config .base .params = args .params
32+ llm_config .base .checkpoint = args .checkpoint
33+ llm_config .base .checkpoint_dir = args .checkpoint_dir
34+ llm_config .base .tokenizer_path = args .tokenizer_path
35+ llm_config .base .metadata = args .metadata
36+ llm_config .base .use_lora = bool (args .use_lora )
37+ llm_config .base .fairseq2 = args .fairseq2
38+
39+ # PreqMode settings
40+ if args .preq_mode :
41+ llm_config .base .preq_mode = PreqMode (args .preq_mode )
42+ llm_config .base .preq_group_size = args .preq_group_size
43+ llm_config .base .preq_embedding_quantize = args .preq_embedding_quantize
44+
45+ # ModelConfig
46+ llm_config .model .dtype_override = DtypeOverride (args .dtype_override )
47+ llm_config .model .enable_dynamic_shape = args .enable_dynamic_shape
48+ llm_config .model .use_shared_embedding = args .use_shared_embedding
49+ llm_config .model .use_sdpa_with_kv_cache = args .use_sdpa_with_kv_cache
50+ llm_config .model .expand_rope_table = args .expand_rope_table
51+ llm_config .model .use_attention_sink = args .use_attention_sink
52+ llm_config .model .output_prune_map = args .output_prune_map
53+ llm_config .model .input_prune_map = args .input_prune_map
54+ llm_config .model .use_kv_cache = args .use_kv_cache
55+ llm_config .model .quantize_kv_cache = args .quantize_kv_cache
56+ llm_config .model .local_global_attention = args .local_global_attention
57+
58+ # ExportConfig
59+ llm_config .export .max_seq_length = args .max_seq_length
60+ llm_config .export .max_context_length = args .max_context_length
61+ llm_config .export .output_dir = args .output_dir
62+ llm_config .export .output_name = args .output_name
63+ llm_config .export .so_library = args .so_library
64+ llm_config .export .export_only = args .export_only
65+
66+ # QuantizationConfig
67+ llm_config .quantization .qmode = args .quantization_mode
68+ llm_config .quantization .embedding_quantize = args .embedding_quantize
69+ if args .pt2e_quantize :
70+ llm_config .quantization .pt2e_quantize = Pt2eQuantize (args .pt2e_quantize )
71+ llm_config .quantization .group_size = args .group_size
72+ if args .use_spin_quant :
73+ llm_config .quantization .use_spin_quant = SpinQuant (args .use_spin_quant )
74+ llm_config .quantization .use_qat = args .use_qat
75+ llm_config .quantization .calibration_tasks = args .calibration_tasks
76+ llm_config .quantization .calibration_limit = args .calibration_limit
77+ llm_config .quantization .calibration_seq_length = args .calibration_seq_length
78+ llm_config .quantization .calibration_data = args .calibration_data
79+
80+ # BackendConfig
81+ # XNNPack
82+ llm_config .backend .xnnpack .enabled = args .xnnpack
83+ llm_config .backend .xnnpack .extended_ops = args .xnnpack_extended_ops
84+
85+ # CoreML
86+ llm_config .backend .coreml .enabled = args .coreml
87+ llm_config .backend .coreml .enable_state = getattr (args , "coreml_enable_state" , False )
88+ llm_config .backend .coreml .preserve_sdpa = getattr (
89+ args , "coreml_preserve_sdpa" , False
90+ )
91+ if args .coreml_quantize :
92+ llm_config .backend .coreml .quantize = CoreMLQuantize (args .coreml_quantize )
93+ llm_config .backend .coreml .ios = args .coreml_ios
94+ llm_config .backend .coreml .compute_units = CoreMLComputeUnit (
95+ args .coreml_compute_units
96+ )
97+
98+ # Vulkan
99+ llm_config .backend .vulkan .enabled = args .vulkan
100+
101+ # QNN
102+ llm_config .backend .qnn .enabled = args .qnn
103+ llm_config .backend .qnn .use_sha = args .use_qnn_sha
104+ llm_config .backend .qnn .soc_model = args .soc_model
105+ llm_config .backend .qnn .optimized_rotation_path = args .optimized_rotation_path
106+ llm_config .backend .qnn .num_sharding = args .num_sharding
107+
108+ # MPS
109+ llm_config .backend .mps .enabled = args .mps
110+
111+ # DebugConfig
112+ llm_config .debug .profile_memory = args .profile_memory
113+ llm_config .debug .profile_path = args .profile_path
114+ llm_config .debug .generate_etrecord = args .generate_etrecord
115+ llm_config .debug .generate_full_logits = args .generate_full_logits
116+ llm_config .debug .verbose = args .verbose
21117
22118 return llm_config
0 commit comments