@@ -26,92 +26,143 @@ def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
     llm_config = LlmConfig()
 
     # BaseConfig
-    llm_config.base.model_class = ModelType(args.model)
-    llm_config.base.params = args.params
-    llm_config.base.checkpoint = args.checkpoint
-    llm_config.base.checkpoint_dir = args.checkpoint_dir
-    llm_config.base.tokenizer_path = args.tokenizer_path
-    llm_config.base.metadata = args.metadata
-    llm_config.base.use_lora = bool(args.use_lora)
-    llm_config.base.fairseq2 = args.fairseq2
+    if hasattr(args, "model"):
+        llm_config.base.model_class = ModelType(args.model)
+    if hasattr(args, "params"):
+        llm_config.base.params = args.params
+    if hasattr(args, "checkpoint"):
+        llm_config.base.checkpoint = args.checkpoint
+    if hasattr(args, "checkpoint_dir"):
+        llm_config.base.checkpoint_dir = args.checkpoint_dir
+    if hasattr(args, "tokenizer_path"):
+        llm_config.base.tokenizer_path = args.tokenizer_path
+    if hasattr(args, "metadata"):
+        llm_config.base.metadata = args.metadata
+    if hasattr(args, "use_lora"):
+        llm_config.base.use_lora = args.use_lora
+    if hasattr(args, "fairseq2"):
+        llm_config.base.fairseq2 = args.fairseq2
 
     # PreqMode settings
-    if args.preq_mode:
+    if hasattr(args, "preq_mode") and args.preq_mode:
         llm_config.base.preq_mode = PreqMode(args.preq_mode)
-        llm_config.base.preq_group_size = args.preq_group_size
-        llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
+    if hasattr(args, "preq_group_size"):
+        llm_config.base.preq_group_size = args.preq_group_size
+    if hasattr(args, "preq_embedding_quantize"):
+        llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
 
     # ModelConfig
-    llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
-    llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
-    llm_config.model.use_shared_embedding = args.use_shared_embedding
-    llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
-    llm_config.model.expand_rope_table = args.expand_rope_table
-    llm_config.model.use_attention_sink = args.use_attention_sink
-    llm_config.model.output_prune_map = args.output_prune_map
-    llm_config.model.input_prune_map = args.input_prune_map
-    llm_config.model.use_kv_cache = args.use_kv_cache
-    llm_config.model.quantize_kv_cache = args.quantize_kv_cache
-    llm_config.model.local_global_attention = args.local_global_attention
+    if hasattr(args, "dtype_override"):
+        llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+    if hasattr(args, "enable_dynamic_shape"):
+        llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
+    if hasattr(args, "use_shared_embedding"):
+        llm_config.model.use_shared_embedding = args.use_shared_embedding
+    if hasattr(args, "use_sdpa_with_kv_cache"):
+        llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
+    if hasattr(args, "expand_rope_table"):
+        llm_config.model.expand_rope_table = args.expand_rope_table
+    if hasattr(args, "use_attention_sink"):
+        llm_config.model.use_attention_sink = args.use_attention_sink
+    if hasattr(args, "output_prune_map"):
+        llm_config.model.output_prune_map = args.output_prune_map
+    if hasattr(args, "input_prune_map"):
+        llm_config.model.input_prune_map = args.input_prune_map
+    if hasattr(args, "use_kv_cache"):
+        llm_config.model.use_kv_cache = args.use_kv_cache
+    if hasattr(args, "quantize_kv_cache"):
+        llm_config.model.quantize_kv_cache = args.quantize_kv_cache
+    if hasattr(args, "local_global_attention"):
+        llm_config.model.local_global_attention = args.local_global_attention
 
     # ExportConfig
-    llm_config.export.max_seq_length = args.max_seq_length
-    llm_config.export.max_context_length = args.max_context_length
-    llm_config.export.output_dir = args.output_dir
-    llm_config.export.output_name = args.output_name
-    llm_config.export.so_library = args.so_library
-    llm_config.export.export_only = args.export_only
+    if hasattr(args, "max_seq_length"):
+        llm_config.export.max_seq_length = args.max_seq_length
+    if hasattr(args, "max_context_length"):
+        llm_config.export.max_context_length = args.max_context_length
+    if hasattr(args, "output_dir"):
+        llm_config.export.output_dir = args.output_dir
+    if hasattr(args, "output_name"):
+        llm_config.export.output_name = args.output_name
+    if hasattr(args, "so_library"):
+        llm_config.export.so_library = args.so_library
+    if hasattr(args, "export_only"):
+        llm_config.export.export_only = args.export_only
 
     # QuantizationConfig
-    llm_config.quantization.qmode = args.quantization_mode
-    llm_config.quantization.embedding_quantize = args.embedding_quantize
-    if args.pt2e_quantize:
+    if hasattr(args, "quantization_mode"):
+        llm_config.quantization.qmode = args.quantization_mode
+    if hasattr(args, "embedding_quantize"):
+        llm_config.quantization.embedding_quantize = args.embedding_quantize
+    if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
         llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
-    llm_config.quantization.group_size = args.group_size
-    if args.use_spin_quant:
+    if hasattr(args, "group_size"):
+        llm_config.quantization.group_size = args.group_size
+    if hasattr(args, "use_spin_quant") and args.use_spin_quant:
         llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
-    llm_config.quantization.use_qat = args.use_qat
-    llm_config.quantization.calibration_tasks = args.calibration_tasks
-    llm_config.quantization.calibration_limit = args.calibration_limit
-    llm_config.quantization.calibration_seq_length = args.calibration_seq_length
-    llm_config.quantization.calibration_data = args.calibration_data
-
-    # BackendConfig
-    # XNNPack
-    llm_config.backend.xnnpack.enabled = args.xnnpack
-    llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
+    if hasattr(args, "use_qat"):
+        llm_config.quantization.use_qat = args.use_qat
+    if hasattr(args, "calibration_tasks"):
+        llm_config.quantization.calibration_tasks = args.calibration_tasks
+    if hasattr(args, "calibration_limit"):
+        llm_config.quantization.calibration_limit = args.calibration_limit
+    if hasattr(args, "calibration_seq_length"):
+        llm_config.quantization.calibration_seq_length = args.calibration_seq_length
+    if hasattr(args, "calibration_data"):
+        llm_config.quantization.calibration_data = args.calibration_data
+
+    # BackendConfig - XNNPack
+    if hasattr(args, "xnnpack"):
+        llm_config.backend.xnnpack.enabled = args.xnnpack
+    if hasattr(args, "xnnpack_extended_ops"):
+        llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
 
     # CoreML
-    llm_config.backend.coreml.enabled = args.coreml
+    if hasattr(args, "coreml"):
+        llm_config.backend.coreml.enabled = args.coreml
     llm_config.backend.coreml.enable_state = getattr(args, "coreml_enable_state", False)
     llm_config.backend.coreml.preserve_sdpa = getattr(
         args, "coreml_preserve_sdpa", False
     )
-    if args.coreml_quantize:
+    if hasattr(args, "coreml_quantize") and args.coreml_quantize:
         llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
-    llm_config.backend.coreml.ios = args.coreml_ios
-    llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
-        args.coreml_compute_units
-    )
+    if hasattr(args, "coreml_ios"):
+        llm_config.backend.coreml.ios = args.coreml_ios
+    if hasattr(args, "coreml_compute_units"):
+        llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+            args.coreml_compute_units
+        )
 
     # Vulkan
-    llm_config.backend.vulkan.enabled = args.vulkan
+    if hasattr(args, "vulkan"):
+        llm_config.backend.vulkan.enabled = args.vulkan
 
     # QNN
-    llm_config.backend.qnn.enabled = args.qnn
-    llm_config.backend.qnn.use_sha = args.use_qnn_sha
-    llm_config.backend.qnn.soc_model = args.soc_model
-    llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
-    llm_config.backend.qnn.num_sharding = args.num_sharding
+    if hasattr(args, "qnn"):
+        llm_config.backend.qnn.enabled = args.qnn
+    if hasattr(args, "use_qnn_sha"):
+        llm_config.backend.qnn.use_sha = args.use_qnn_sha
+    if hasattr(args, "soc_model"):
+        llm_config.backend.qnn.soc_model = args.soc_model
+    if hasattr(args, "optimized_rotation_path"):
+        llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
+    if hasattr(args, "num_sharding"):
+        llm_config.backend.qnn.num_sharding = args.num_sharding
 
     # MPS
-    llm_config.backend.mps.enabled = args.mps
+    if hasattr(args, "mps"):
+        llm_config.backend.mps.enabled = args.mps
 
     # DebugConfig
-    llm_config.debug.profile_memory = args.profile_memory
-    llm_config.debug.profile_path = args.profile_path
-    llm_config.debug.generate_etrecord = args.generate_etrecord
-    llm_config.debug.generate_full_logits = args.generate_full_logits
-    llm_config.debug.verbose = args.verbose
+    if hasattr(args, "profile_memory"):
+        llm_config.debug.profile_memory = args.profile_memory
+    if hasattr(args, "profile_path"):
+        llm_config.debug.profile_path = args.profile_path
+    if hasattr(args, "generate_etrecord"):
+        llm_config.debug.generate_etrecord = args.generate_etrecord
+    if hasattr(args, "generate_full_logits"):
+        llm_config.debug.generate_full_logits = args.generate_full_logits
+    if hasattr(args, "verbose"):
+        llm_config.debug.verbose = args.verbose
 
     return llm_config
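To illustrate what the new `hasattr` guards buy, here is a minimal sketch of a caller that populates only part of the namespace. The specific values (`"llama2"` as a `ModelType`, the `xnnpack` flag) and the assumption that unset `LlmConfig` fields fall back to their defaults are illustrative, not taken from this PR:

```python
import argparse

# Hypothetical partial namespace: only two of the many export flags are set.
args = argparse.Namespace(model="llama2", xnnpack=True)

# Previously, the first missing attribute (e.g. args.params) would raise
# AttributeError. With the hasattr guards, omitted flags are simply skipped
# and the corresponding LlmConfig fields keep their defaults.
llm_config = convert_args_to_llm_config(args)
assert llm_config.backend.xnnpack.enabled
```

Note that the CoreML `enable_state` / `preserve_sdpa` fields already tolerated partial namespaces via `getattr(..., False)`; the guards extend that behavior uniformly to the remaining fields.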