Commit bacac27

craymichael authored and facebook-github-bot committed

LLM tokenizer pretty decoding fix for emojis/unicode (#1360)
Summary: Pull Request resolved: #1360

Emojis are not well handled by the current decoding logic (see the example in the test plan). What unfortunately happens is that emojis/unicode characters are tokenized as two symbols: I believe one indicates extended unicode (or perhaps a class of unicode, e.g., emoji), and the second is the specific character (e.g., smiley face, omega, ...). This fix assumes that such symbols always come in pairs, and that the tokenizer returns the replacement character "�" when a symbol is unknown (we verify that "�" was not the intended symbol by encoding it back through the tokenizer). This logic will break down if a symbol is split across 3 or more tokens.

Example:
Input string: 😂
Output token IDs: list of length 2
Pretty decoded tokens: ['😂[1/2]', '😂[2/2]']

Note that we cannot simply output a single token here, as we provide attributions for each of the token IDs. In attribution, all such cases should really be grouped together so that inputs are valid and attributions make sense.

Reviewed By: cyrjano

Differential Revision: D63435671

fbshipit-source-id: c029ab17b7c7e6ef1a3fff429da2ecb902d42595
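The "�" behavior described above can be illustrated from first principles. This is an assumption about why the split happens (byte-level tokenization of UTF-8), not something stated in the commit: decoding only part of a multi-byte UTF-8 sequence yields the replacement character U+FFFD, which is exactly the marker the fix keys on.

```python
# Illustration (an assumption, not from the commit itself): an emoji is a
# multi-byte UTF-8 sequence, so decoding each half of its bytes separately
# produces the replacement character "�" (U+FFFD) instead of the emoji.
emoji = "😂"
data = emoji.encode("utf-8")  # 4 bytes: b'\xf0\x9f\x98\x82'

# Each partial chunk decodes to replacement characters:
first_half = data[:2].decode("utf-8", errors="replace")
second_half = data[2:].decode("utf-8", errors="replace")

# Only the full byte sequence recovers the original character:
full = data.decode("utf-8")
```

This is why the fix pairs adjacent "�" tokens and re-decodes them together.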
1 parent 15738d0 commit bacac27

File tree

1 file changed: +26 −1 lines changed

captum/attr/_core/llm_attr.py

Lines changed: 26 additions & 1 deletion
@@ -230,7 +230,32 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
         BPE splitting mostly to avoid digesting spaces since the standard BPE algorithm
         used spaces in its process
     """
-    return [tokenizer.decode(id_) for id_ in ids]
+    pretty_tokens = []
+    idx = 0
+    while idx < len(ids):
+        decoded = tokenizer.decode(ids[idx])
+        # Handle case where single token (e.g. unicode) is split into multiple IDs
+        # NOTE: This logic will fail if a tokenizer splits a token into 3+ IDs
+        if decoded.strip() == "�" and tokenizer.encode(decoded) != [ids[idx]]:
+            # ID at idx is split, ensure next token is also from a split
+            decoded_next = tokenizer.decode(ids[idx + 1])
+            if decoded_next.strip() == "�" and tokenizer.encode(decoded_next) != [
+                ids[idx + 1]
+            ]:
+                # Both tokens are from a split, combine them
+                decoded = tokenizer.decode(ids[idx : idx + 2])
+                pretty_tokens.append(decoded + "[1/2]")
+                pretty_tokens.append(decoded + "[2/2]")
+            else:
+                # Treat tokens as separate
+                pretty_tokens.append(decoded)
+                pretty_tokens.append(decoded_next)
+            idx += 2
+        else:
+            # Just a normal token
+            idx += 1
+            pretty_tokens.append(decoded)
+    return pretty_tokens
 
 
 class LLMAttribution(Attribution):
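The pairing logic in the diff can be exercised as a standalone sketch. `FakeTokenizer` below is a hypothetical stand-in invented for illustration (the real code operates on Captum's `TokenizerLike` protocol, e.g. a HuggingFace tokenizer); it hard-codes one emoji split into the two IDs 100 and 101, each of which decodes alone to "�".

```python
# Standalone sketch of the commit's pairing logic. FakeTokenizer is a
# hypothetical minimal stand-in, NOT a real tokenizer: it maps "hi" -> [1]
# and "😂" -> [100, 101], with each split ID decoding alone to "�".
class FakeTokenizer:
    def decode(self, ids):
        if isinstance(ids, int):
            ids = [ids]
        if list(ids) == [100, 101]:
            return "😂"
        return {100: "�", 101: "�", 1: "hi"}.get(ids[0], "?")

    def encode(self, text):
        # Unknown text (including "�") maps to [0], so the round-trip
        # check encode(decode(id)) != [id] detects split tokens.
        return {"hi": [1], "😂": [100, 101]}.get(text, [0])


def convert_ids_to_pretty_tokens(ids, tokenizer):
    # Same logic as the diff above, minus type annotations.
    pretty_tokens = []
    idx = 0
    while idx < len(ids):
        decoded = tokenizer.decode(ids[idx])
        if decoded.strip() == "�" and tokenizer.encode(decoded) != [ids[idx]]:
            decoded_next = tokenizer.decode(ids[idx + 1])
            if decoded_next.strip() == "�" and tokenizer.encode(
                decoded_next
            ) != [ids[idx + 1]]:
                # Both IDs come from one split symbol: re-decode together
                decoded = tokenizer.decode(ids[idx : idx + 2])
                pretty_tokens.append(decoded + "[1/2]")
                pretty_tokens.append(decoded + "[2/2]")
            else:
                pretty_tokens.append(decoded)
                pretty_tokens.append(decoded_next)
            idx += 2
        else:
            idx += 1
            pretty_tokens.append(decoded)
    return pretty_tokens
```

With this stand-in, `convert_ids_to_pretty_tokens([1, 100, 101], FakeTokenizer())` produces one entry per token ID, with the split emoji labeled `[1/2]` and `[2/2]` as in the test plan above.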
