
Commit 342e7f7

Allow multiple eos ids
Differential Revision: D61420500
Pull Request resolved: #4777
1 parent: d8a00e6

6 files changed: +24 -21 lines
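Background, for readers of this commit: Llama 3 instruct models stop on more than one token id — `<|eot_id|>` (128009) ends an assistant turn and `<|end_of_text|>` (128001) ends the text — so a runner that compares each sampled token against a single fixed `eos_id` can generate past the end of a turn. This commit replaces the single `get_eos_id` metadata entry with a `get_eos_ids` list and switches the runtime stop check to set membership. A self-contained toy of the new check (the token stream here is illustrative, not real model output):

```cpp
#include <cstdint>
#include <cstdio>
#include <unordered_set>
#include <vector>

int main() {
  // Pretend the model sampled this token stream. With the old check
  // (cur_token == 128001), generation would run straight past the
  // <|eot_id|> token 128009 that actually ends a Llama 3 chat turn.
  std::vector<uint64_t> stream{15339, 1917, 128009, 42};
  std::unordered_set<uint64_t> eos_ids{128009, 128001};

  for (uint64_t cur_token : stream) {
    if (eos_ids.find(cur_token) != eos_ids.end()) {
      std::printf("hit EOS id %llu, stop\n", (unsigned long long)cur_token);
      break;  // any id in the set terminates generation
    }
    std::printf("emit %llu\n", (unsigned long long)cur_token);
  }
  return 0;
}
```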

examples/models/llama2/README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -127,7 +127,7 @@ You can export and run the original Llama 3 8B instruct model.
 
 2. Export model and generate `.pte` file
 ```
-python -m examples.models.llama2.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_id":128001}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
+python -m examples.models.llama2.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
 ```
 
 Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.
````

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 10 deletions

```diff
@@ -562,16 +562,8 @@ def _load_llama_model_metadata(
     is_fairseq2 = weight_type == WeightType.FAIRSEQ2
     metadata = {
         "append_eos_to_prompt": is_fairseq2,  # For language llama, tell the runtime to always append EOS token(s) to prompt.
-        "get_bos_id": (
-            model_args.bos_idx
-            if model_args.bos_idx is not None
-            else (3 if is_fairseq2 else 1)
-        ),
-        "get_eos_id": (
-            model_args.eos_idx
-            if model_args.eos_idx is not None
-            else (3 if is_fairseq2 else 2)
-        ),
+        "get_bos_id": 3 if is_fairseq2 else 1,
+        "get_eos_ids": [3] if is_fairseq2 else [2],
         "get_max_seq_len": model_args.max_seq_len,
         "get_n_bos": 1,
         "get_n_eos": 2 if is_fairseq2 else 1,
```

examples/models/llama2/llama_transformer.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -104,8 +104,8 @@ class ModelArgs:
     rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
     use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
-    bos_idx: Optional[int] = None
-    eos_idx: Optional[int] = None
+    bos_idx: int = 1
+    eos_idx: int = 3
     bos_count: int = -1  # i.e., a single EOS is used as BOS
     eos_count: int = 2
```

examples/models/llama2/runner/runner.cpp

Lines changed: 13 additions & 3 deletions

```diff
@@ -27,7 +27,7 @@ namespace {
 static constexpr auto kAppendEosToPrompt = "append_eos_to_prompt";
 static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
 static constexpr auto kBosId = "get_bos_id";
-static constexpr auto kEosId = "get_eos_id";
+static constexpr auto kEosIds = "get_eos_ids";
 static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kNBos = "get_n_bos";
 static constexpr auto kNEos = "get_n_eos";
@@ -85,7 +85,8 @@ Error Runner::load() {
   ET_LOG(Info, "Reading metadata from model");
 
   metadata_[kBosId] = tokenizer_->bos_tok();
-  metadata_[kEosId] = tokenizer_->eos_tok();
+  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
+      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
   metadata_[kVocabSize] = tokenizer_->vocab_size();
 
   const auto method_names =
@@ -106,6 +107,15 @@ Error Runner::load() {
           method_name.c_str(),
           value);
     }
+    ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
+  }
+  if (method_names.count(kEosIds)) {
+    eos_ids->clear();
+    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+      auto value = eos_id.toScalar().to<int64_t>();
+      eos_ids->emplace(value);
+      ET_LOG(Info, "eos_id = %" PRId64, value);
+    }
   }
   text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
       module_.get(),
@@ -122,7 +132,7 @@ Error Runner::load() {
       tokenizer_.get(),
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
-      metadata_.at(kEosId),
+      std::move(eos_ids),
       &stats_);
 
   return Error::Ok;
```
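The loading logic above seeds the EOS set with the tokenizer's single EOS token and then, if the exported program provides a `get_eos_ids` method, clears that default and takes the exported list instead. A simplified, self-contained sketch of this override pattern (the helper and parameter names here are illustrative, not ExecuTorch API):

```cpp
#include <cstdint>
#include <memory>
#include <unordered_set>
#include <vector>

// resolve_eos_ids is a hypothetical stand-in for the logic in
// Runner::load(): default to the tokenizer's single EOS token, then let a
// model-exported "get_eos_ids" list fully replace that default.
std::unique_ptr<std::unordered_set<uint64_t>> resolve_eos_ids(
    uint64_t tokenizer_eos,
    bool model_exports_eos_ids,
    const std::vector<int64_t>& exported_eos_ids) {
  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
      std::unordered_set<uint64_t>{tokenizer_eos});
  if (model_exports_eos_ids) {
    eos_ids->clear();  // exported metadata wins over the tokenizer default
    for (int64_t id : exported_eos_ids) {
      eos_ids->emplace(static_cast<uint64_t>(id));
    }
  }
  return eos_ids;
}

// e.g. resolve_eos_ids(128001, true, {128009, 128001}) yields {128009, 128001}.
```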

examples/models/llava/runner/llava_runner.cpp

Lines changed: 2 additions & 1 deletion

```diff
@@ -63,7 +63,8 @@ Error LlavaRunner::load() {
       tokenizer_.get(),
       text_decoder_runner_.get(),
       /*use_kv_cache=*/true,
-      tokenizer_->eos_tok(),
+      std::make_unique<std::unordered_set<uint64_t>>(
+          std::unordered_set<uint64_t>{tokenizer_->eos_tok()}),
       &stats_);
 
   stats_.model_load_end_ms = util::time_in_ms();
```

extension/llm/runner/text_token_generator.h

Lines changed: 4 additions & 4 deletions

```diff
@@ -22,11 +22,11 @@ class TextTokenGenerator {
       Tokenizer* tokenizer,
       TextDecoderRunner* text_decoder_runner,
       bool use_kv_cache,
-      uint64_t eos_id,
+      std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
       Stats* stats)
       : tokenizer_(tokenizer),
         text_decoder_runner_(text_decoder_runner),
-        eos_id_(eos_id),
+        eos_ids_(std::move(eos_ids)),
         use_kv_cache_(use_kv_cache),
         stats_(stats) {}
 
@@ -108,7 +108,7 @@ class TextTokenGenerator {
     }
 
     // data-dependent terminating condition: we have n_eos_ number of EOS
-    if (cur_token == eos_id_) {
+    if (eos_ids_->find(cur_token) != eos_ids_->end()) {
       printf("\n");
       ET_LOG(Info, "\nReached to the end of generation");
       break;
@@ -127,7 +127,7 @@ class TextTokenGenerator {
  private:
   Tokenizer* tokenizer_;
   TextDecoderRunner* text_decoder_runner_;
-  uint64_t eos_id_;
+  std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
   bool use_kv_cache_;
 
   // state machine
```
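A note on the design, as read from the diff rather than stated in the commit: the generator takes the EOS set as a `std::unique_ptr<std::unordered_set<uint64_t>>&&`, so ownership transfers once at construction and the per-token stop check becomes an average O(1) hash lookup instead of a single integer comparison. Runners with only one stop token, like the Llava runner above, simply pass a one-element set.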
