@@ -27,7 +27,7 @@ namespace {
2727static constexpr auto kAppendEosToPrompt = " append_eos_to_prompt" ;
2828static constexpr auto kEnableDynamicShape = " enable_dynamic_shape" ;
2929static constexpr auto kBosId = " get_bos_id" ;
30- static constexpr auto kEosId = " get_eos_id " ;
30+ static constexpr auto kEosIds = " get_eos_ids " ;
3131static constexpr auto kMaxSeqLen = " get_max_seq_len" ;
3232static constexpr auto kNBos = " get_n_bos" ;
3333static constexpr auto kNEos = " get_n_eos" ;
@@ -85,7 +85,8 @@ Error Runner::load() {
8585 ET_LOG (Info, " Reading metadata from model" );
8686
8787 metadata_[kBosId ] = tokenizer_->bos_tok ();
88- metadata_[kEosId ] = tokenizer_->eos_tok ();
88+ auto eos_ids = std::make_unique<std::unordered_set<uint64_t >>(
89+ std::unordered_set<uint64_t >{tokenizer_->eos_tok ()});
8990 metadata_[kVocabSize ] = tokenizer_->vocab_size ();
9091
9192 const auto method_names =
@@ -106,6 +107,15 @@ Error Runner::load() {
106107 method_name.c_str (),
107108 value);
108109 }
110+ ET_LOG (Info, " Metadata: %s = %" PRId64, method_name.c_str (), value);
111+ }
112+ if (method_names.count (kEosIds )) {
113+ eos_ids->clear ();
114+ for (const auto & eos_id : ET_UNWRAP (module_->execute (kEosIds ))) {
115+ auto value = eos_id.toScalar ().to <int64_t >();
116+ eos_ids->emplace (value);
117+ ET_LOG (Info, " eos_id = %" PRId64, value);
118+ }
109119 }
110120 text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
111121 module_.get (),
@@ -122,7 +132,7 @@ Error Runner::load() {
122132 tokenizer_.get (),
123133 text_decoder_runner_.get (),
124134 metadata_.at (kUseKVCache ),
125- metadata_. at ( kEosId ),
135+ std::move (eos_ids ),
126136 &stats_);
127137
128138 return Error::Ok;
0 commit comments