
Commit 0adf71e

lucylq authored and facebook-github-bot committed
Remove default bos/eos from metadata
Summary: See pytorch#15215.

Currently:
- Default bos/eos tokens are embedded into the PTE.
- Llama 3 Instruct uses a different set of bos/eos tokens.
- Users must manually specify the Llama 3 Instruct bos/eos tokens at export time, because the runner overrides the tokenizer's bos/eos with the values stored in the PTE.

This diff:
- Removes the defaults.
- Relies on the tokenizer for bos/eos, unless the user explicitly specifies them in the metadata, in which case the bos/eos saved in the PTE are used.

Differential Revision: D84942718
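A minimal sketch of the resulting resolution order, assuming a hypothetical runner-side helper (the names below are illustrative, not the actual ExecuTorch runner API): bos/eos explicitly written into the PTE metadata win; otherwise the tokenizer's own special tokens are used.

# Illustrative sketch only; resolve_special_tokens is a hypothetical helper,
# not part of the ExecuTorch runner.
from typing import List, Tuple


def resolve_special_tokens(
    pte_metadata: dict,
    tokenizer_bos_id: int,
    tokenizer_eos_ids: List[int],
) -> Tuple[int, List[int]]:
    # Use bos/eos from the PTE only if the exporter explicitly wrote them;
    # otherwise fall back to the tokenizer's special tokens.
    bos_id = pte_metadata.get("get_bos_id", tokenizer_bos_id)
    eos_ids = pte_metadata.get("get_eos_ids", tokenizer_eos_ids)
    return bos_id, eos_ids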
1 parent 4cff294


examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 10 deletions
@@ -121,11 +121,6 @@
 }
 
 
-class WeightType(Enum):
-    LLAMA = "LLAMA"
-    FAIRSEQ2 = "FAIRSEQ2"
-
-
 def set_pkg_name(name: str) -> None:
     global pkg_name
     pkg_name = name
@@ -1247,7 +1242,6 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
 
 
 def _load_llama_model_metadata(
-    weight_type: WeightType,
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
@@ -1257,10 +1251,7 @@ def _load_llama_model_metadata(
     vocab_size: int,
     metadata_str: Optional[str] = None,
 ):
-    is_fairseq2 = weight_type == WeightType.FAIRSEQ2
     metadata = {
-        "get_bos_id": 3 if is_fairseq2 else 1,
-        "get_eos_ids": [3] if is_fairseq2 else [2],
         "get_max_seq_len": max_seq_len,
         "get_max_context_len": max_context_len,
         "get_n_layers": n_layers,
@@ -1332,7 +1323,6 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         save_exported_program=llm_config.export.export_only,
         verbose=llm_config.debug.verbose,
         metadata=_load_llama_model_metadata(
-            WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA,
             llm_config.model.use_kv_cache,
             llm_config.model.use_sdpa_with_kv_cache,
             llm_config.model.enable_dynamic_shape,
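After this change, an explicit user override still takes precedence. A minimal sketch of the merge, assuming metadata_str continues to be parsed as JSON and merged over the defaults inside _load_llama_model_metadata (the Llama 3 Instruct token IDs below are shown for illustration):

import json

# Defaults no longer include bos/eos; only model-shape entries remain.
metadata = {
    "get_max_seq_len": 2048,
    "get_max_context_len": 2048,
}

# What a user might pass at export time for Llama 3 Instruct
# (token IDs shown for illustration).
metadata_str = '{"get_bos_id": 128000, "get_eos_ids": [128009, 128001]}'
if metadata_str:
    metadata.update(json.loads(metadata_str))

# The explicit values end up in the PTE; absent an override, the runner
# now takes bos/eos from the tokenizer instead.
assert metadata["get_bos_id"] == 128000
assert metadata["get_eos_ids"] == [128009, 128001]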
