
Commit 0e547c1

lucylq authored and facebook-github-bot committed
Remove default bos/eos from metadata (pytorch#15231)
Summary:
See: pytorch#15215

Currently:
- Default eos/bos tokens are embedded into the PTE.
- Llama 3 Instruct has a different set of eos/bos tokens.
- Users must manually specify the Llama 3 Instruct eos/bos tokens at export time, because the runner overrides the tokenizer's eos/bos with the values in the PTE.

This diff:
- Removes the defaults.
- Relies on the tokenizer for eos/bos UNLESS the user explicitly specifies them in the metadata, in which case the eos/bos saved in the PTE are used.

Reviewed By: jackzhxng

Differential Revision: D84942718
1 parent aeee757 commit 0e547c1
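For reference, a minimal sketch of what the "explicitly specifies in the metadata" path could look like under this scheme. The Llama 3 Instruct token IDs below are the published ones (128000 = <|begin_of_text|>, 128009 = <|eot_id|>, 128001 = <|end_of_text|>); wiring the JSON string through the export command's metadata option is an assumption based on the metadata_str parameter visible in the diff.

import json

# Hedged sketch: explicitly pin the Llama 3 Instruct special tokens in the
# exported metadata. After this change, omitting these keys means the runner
# falls back to the tokenizer's bos/eos instead of baked-in defaults.
llama3_instruct_overrides = json.dumps(
    {
        "get_bos_id": 128000,             # <|begin_of_text|>
        "get_eos_ids": [128009, 128001],  # <|eot_id|>, <|end_of_text|>
    }
)
# This JSON string would be supplied as metadata_str to
# _load_llama_model_metadata (assumed: via the export CLI's metadata option).
print(llama3_instruct_overrides)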

File tree

1 file changed: +0 −11 lines

examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 11 deletions
@@ -15,7 +15,6 @@
 import logging
 import re
 import shlex
-from enum import Enum
 from functools import partial
 
 from importlib import resources as _resources
@@ -121,11 +120,6 @@
 }
 
 
-class WeightType(Enum):
-    LLAMA = "LLAMA"
-    FAIRSEQ2 = "FAIRSEQ2"
-
-
 def set_pkg_name(name: str) -> None:
     global pkg_name
     pkg_name = name
@@ -1247,7 +1241,6 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
 
 
 def _load_llama_model_metadata(
-    weight_type: WeightType,
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
@@ -1257,10 +1250,7 @@ def _load_llama_model_metadata(
     vocab_size: int,
     metadata_str: Optional[str] = None,
 ):
-    is_fairseq2 = weight_type == WeightType.FAIRSEQ2
     metadata = {
-        "get_bos_id": 3 if is_fairseq2 else 1,
-        "get_eos_ids": [3] if is_fairseq2 else [2],
         "get_max_seq_len": max_seq_len,
         "get_max_context_len": max_context_len,
         "get_n_layers": n_layers,
@@ -1332,7 +1322,6 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         save_exported_program=llm_config.export.export_only,
         verbose=llm_config.debug.verbose,
         metadata=_load_llama_model_metadata(
-            WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA,
             llm_config.model.use_kv_cache,
             llm_config.model.use_sdpa_with_kv_cache,
             llm_config.model.enable_dynamic_shape,
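
To make the new fallback behavior concrete, here is a minimal Python sketch of the resolution order this diff implies: explicit PTE metadata wins, otherwise the tokenizer's own special tokens are used. This is not the actual runner code; resolve_eos_ids, its parameters, and the sample values are hypothetical.

from typing import Optional

def resolve_eos_ids(
    pte_metadata: dict,
    tokenizer_eos_ids: list[int],
) -> list[int]:
    # Prefer eos ids the user explicitly exported into the PTE metadata;
    # otherwise fall back to the tokenizer's own eos ids.
    explicit: Optional[list[int]] = pte_metadata.get("get_eos_ids")
    return explicit if explicit is not None else tokenizer_eos_ids

# No explicit override in the PTE: the tokenizer's eos wins.
assert resolve_eos_ids({}, [128009]) == [128009]
# An explicit override in the PTE takes precedence.
assert resolve_eos_ids({"get_eos_ids": [2]}, [128009]) == [2]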
