Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import logging
import re
import shlex
from enum import Enum
from functools import partial

from importlib import resources as _resources
Expand Down Expand Up @@ -121,11 +120,6 @@
}


class WeightType(Enum):
LLAMA = "LLAMA"
FAIRSEQ2 = "FAIRSEQ2"


def set_pkg_name(name: str) -> None:
global pkg_name
pkg_name = name
Expand Down Expand Up @@ -1247,7 +1241,6 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901


def _load_llama_model_metadata(
weight_type: WeightType,
use_kv_cache: bool,
use_sdpa_with_kv_cache: bool,
enable_dynamic_shape: bool,
Expand All @@ -1257,10 +1250,7 @@ def _load_llama_model_metadata(
vocab_size: int,
metadata_str: Optional[str] = None,
):
is_fairseq2 = weight_type == WeightType.FAIRSEQ2
metadata = {
"get_bos_id": 3 if is_fairseq2 else 1,
"get_eos_ids": [3] if is_fairseq2 else [2],
"get_max_seq_len": max_seq_len,
"get_max_context_len": max_context_len,
"get_n_layers": n_layers,
Expand Down Expand Up @@ -1332,7 +1322,6 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
save_exported_program=llm_config.export.export_only,
verbose=llm_config.debug.verbose,
metadata=_load_llama_model_metadata(
WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA,
llm_config.model.use_kv_cache,
llm_config.model.use_sdpa_with_kv_cache,
llm_config.model.enable_dynamic_shape,
Expand Down
Loading