Skip to content

Commit 1f9d7ee

Browse files
authored
Fix the added token decoding issue on spm based tokenizer (#908)
* Fix the added-token decoding issue in the SPM-based tokenizer
* Skip special tokens when requested
1 parent bfeb3dd commit 1f9d7ee

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

operators/tokenizer/bpe_streaming.hpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
4747
return {};
4848
}
4949

50-
51-
5250
OrtxStatus Id2Token(extTokenId_t id,
5351
std::string& token,
5452
bool skip_special_tokens,
@@ -95,17 +93,21 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
9593
}
9694

9795
OrtxStatus SpmId2Token(extTokenId_t id, std::string& token, bool& f_special_last) const {
98-
99-
std::string piece = id < arr_vocab_.size() ? arr_vocab_[id] : "";
10096
bool f_special = false;
101-
if (piece.empty() || all_special_ids_.count(id)) {
102-
token = "";
103-
f_special = true;
104-
} else if (IsSpmByteWord(piece)) {
105-
char buf[3] = {piece[3], piece[4], 0}; // something like <0x20>
106-
token = {static_cast<char>(strtol(buf, NULL, 16))};
97+
if (added_tokens_.count(id)) {
98+
f_special = all_special_ids_.count(id) ? true : false;
99+
// special token was skipped
100+
token = f_special ? "" : added_tokens_.at(id);
107101
} else {
108-
token = ReplaceAll(piece, std::string(ort_extensions::spm_escaped_space), " ");
102+
std::string piece = id < arr_vocab_.size() ? arr_vocab_[id] : "";
103+
if (piece.empty()) {
104+
token = unk_token_;
105+
} else if (IsSpmByteWord(piece)) {
106+
char buf[3] = {piece[3], piece[4], 0}; // something like <0x20>
107+
token = {static_cast<char>(strtol(buf, NULL, 16))};
108+
} else {
109+
token = ReplaceAll(piece, std::string(ort_extensions::spm_escaped_space), " ");
110+
}
109111
}
110112

111113
if (!token.empty() && token[0] == ' ' && f_special_last && add_dummy_prefix_) {

0 commit comments

Comments (0)