@@ -553,27 +553,29 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
 
 def _load_llama_model_metadata(
     weight_type: WeightType,
-    dtype: DType,
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
-    modelArgs: ModelArgs,
+    model_args: ModelArgs,
     metadata_str: Optional[str] = None,
 ):
     is_fairseq2 = weight_type == WeightType.FAIRSEQ2
     metadata = {
         "append_eos_to_prompt": is_fairseq2,  # For language llama, tell the runtime to always append EOS token(s) to prompt.
-        "get_bos_id": 3 if is_fairseq2 else 1,
-        "get_dtype": 5 if dtype == DType.fp16 else 6,
-        "get_eos_id": 3 if is_fairseq2 else 2,
-        "get_head_dim": modelArgs.dim // modelArgs.n_heads,
-        "get_max_batch_size": modelArgs.max_batch_size,
-        "get_max_seq_len": modelArgs.max_seq_len,
+        "get_bos_id": (
+            model_args.bos_idx
+            if model_args.bos_idx is not None
+            else (3 if is_fairseq2 else 1)
+        ),
+        "get_eos_id": (
+            model_args.eos_idx
+            if model_args.eos_idx is not None
+            else (3 if is_fairseq2 else 2)
+        ),
+        "get_max_seq_len": model_args.max_seq_len,
         "get_n_bos": 1,
         "get_n_eos": 2 if is_fairseq2 else 1,
-        "get_n_kv_heads": modelArgs.n_kv_heads,
-        "get_n_layers": modelArgs.n_layers,
-        "get_vocab_size": modelArgs.vocab_size,
+        "get_vocab_size": model_args.vocab_size,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,
@@ -655,7 +657,6 @@ def _load_llama_model(
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
-            dtype,
             use_kv_cache,
             use_sdpa_with_kv_cache,
             enable_dynamic_shape,
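
For context, a minimal standalone sketch (not part of the commit) of the BOS/EOS fallback that the new metadata code implements, assuming ModelArgs.bos_idx and ModelArgs.eos_idx default to None when the model config does not override them; SimpleNamespace stands in for ModelArgs here:

from types import SimpleNamespace

def resolve_special_token_ids(model_args, is_fairseq2: bool):
    # Prefer explicit overrides from the model config; otherwise fall back
    # to the previous defaults (fairseq2 checkpoints use 3 for both BOS and EOS).
    bos_id = (
        model_args.bos_idx
        if model_args.bos_idx is not None
        else (3 if is_fairseq2 else 1)
    )
    eos_id = (
        model_args.eos_idx
        if model_args.eos_idx is not None
        else (3 if is_fairseq2 else 2)
    )
    return bos_id, eos_id

# Hypothetical config that overrides EOS but leaves BOS unset.
args = SimpleNamespace(bos_idx=None, eos_idx=128001)
print(resolve_special_token_ids(args, is_fairseq2=False))  # -> (1, 128001)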