Commit 8f46971

allow models to use customized token ids during export
Differential Revision: D61044259
Pull Request resolved: #4649
1 parent 440048c

2 files changed: +15 −14 lines

examples/models/llama2/export_llama_lib.py

Lines changed: 13 additions & 12 deletions
@@ -553,27 +553,29 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
 
 def _load_llama_model_metadata(
     weight_type: WeightType,
-    dtype: DType,
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
-    modelArgs: ModelArgs,
+    model_args: ModelArgs,
     metadata_str: Optional[str] = None,
 ):
     is_fairseq2 = weight_type == WeightType.FAIRSEQ2
     metadata = {
         "append_eos_to_prompt": is_fairseq2,  # For language llama, tell the runtime to always append EOS token(s) to prompt.
-        "get_bos_id": 3 if is_fairseq2 else 1,
-        "get_dtype": 5 if dtype == DType.fp16 else 6,
-        "get_eos_id": 3 if is_fairseq2 else 2,
-        "get_head_dim": modelArgs.dim // modelArgs.n_heads,
-        "get_max_batch_size": modelArgs.max_batch_size,
-        "get_max_seq_len": modelArgs.max_seq_len,
+        "get_bos_id": (
+            model_args.bos_idx
+            if model_args.bos_idx is not None
+            else (3 if is_fairseq2 else 1)
+        ),
+        "get_eos_id": (
+            model_args.eos_idx
+            if model_args.eos_idx is not None
+            else (3 if is_fairseq2 else 2)
+        ),
+        "get_max_seq_len": model_args.max_seq_len,
         "get_n_bos": 1,
         "get_n_eos": 2 if is_fairseq2 else 1,
-        "get_n_kv_heads": modelArgs.n_kv_heads,
-        "get_n_layers": modelArgs.n_layers,
-        "get_vocab_size": modelArgs.vocab_size,
+        "get_vocab_size": model_args.vocab_size,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,

@@ -655,7 +657,6 @@ def _load_llama_model(
             verbose=verbose,
             metadata=_load_llama_model_metadata(
                 weight_type,
-                dtype,
                 use_kv_cache,
                 use_sdpa_with_kv_cache,
                 enable_dynamic_shape,
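
Taken together, these hunks drop the dtype, head-dim, batch-size, kv-heads, and layer-count entries from the exported metadata and make the BOS/EOS token ids overridable: an id set explicitly on ModelArgs wins, and the historical defaults are used only as a fallback. Below is a minimal sketch of that resolution logic; the _Args dataclass and _resolve_ids helper are hypothetical stand-ins, with only the bos_idx/eos_idx fields and the fallback values taken from the diff.

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass
    class _Args:
        # Hypothetical stand-in for ModelArgs; only the two fields this commit touches.
        bos_idx: Optional[int] = None
        eos_idx: Optional[int] = None

    def _resolve_ids(args: _Args, is_fairseq2: bool) -> Tuple[int, int]:
        # Same precedence as the new metadata code: a customized id wins,
        # otherwise fall back to the old hard-coded defaults.
        bos = args.bos_idx if args.bos_idx is not None else (3 if is_fairseq2 else 1)
        eos = args.eos_idx if args.eos_idx is not None else (3 if is_fairseq2 else 2)
        return bos, eos

    assert _resolve_ids(_Args(), is_fairseq2=False) == (1, 2)  # old default behavior
    assert _resolve_ids(_Args(), is_fairseq2=True) == (3, 3)   # fairseq2 defaults
    assert _resolve_ids(_Args(bos_idx=0, eos_idx=7), is_fairseq2=False) == (0, 7)  # customized ids win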

examples/models/llama2/llama_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -104,8 +104,8 @@ class ModelArgs:
     rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
     use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
-    bos_idx: int = 1
-    eos_idx: int = 3
+    bos_idx: Optional[int] = None
+    eos_idx: Optional[int] = None
     bos_count: int = -1  # i.e., a single EOS is used as BOS
     eos_count: int = 2
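
Switching the defaults from concrete ids (1 and 3) to None is what makes the fallback in export_llama_lib.py above possible: a model that never sets these fields keeps the pre-existing runtime defaults, while a model with a customized tokenizer can pin its own ids at export time. A hedged usage sketch, assuming ModelArgs stays constructible from its dataclass defaults (the import path mirrors the file location and is illustrative):

    # Illustrative import; mirrors examples/models/llama2/llama_transformer.py.
    from examples.models.llama2.llama_transformer import ModelArgs

    # With the new defaults nothing is pinned, so _load_llama_model_metadata
    # falls back to the historical ids (1/2, or 3/3 for fairseq2 weights).
    default_args = ModelArgs()
    assert default_args.bos_idx is None and default_args.eos_idx is None

    # A model with a non-standard vocabulary can pin its own ids; these are
    # then embedded in the exported metadata instead of the defaults.
    custom_args = ModelArgs(bos_idx=0, eos_idx=7)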
