Skip to content

Commit a0a1069

Browse files
authored
Save bits_per_token for each sample during preprocessing
1 parent aaca529 commit a0a1069

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

flame/utils/preprocess.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
def tokenize(
    examples: Dict[str, List[Any]],
    tokenizer: PreTrainedTokenizer,
) -> Dict[str, List[Any]]:
    """Tokenize a batch of raw samples and compute per-sample bits-per-token.

    Args:
        examples: Batch mapping holding the raw strings under either a
            ``'text'`` or a ``'content'`` key (checked in that order).
        tokenizer: Tokenizer; ``tokenizer(samples)`` must return a mapping
            with an ``'input_ids'`` entry containing one id list per sample.

    Returns:
        Dict with two parallel lists:
            - ``'input_ids'``: token-id list for each sample.
            - ``'bits_per_token'``: UTF-8 size of the raw sample in bits
              divided by its token count; ``0.0`` for samples that produce
              no tokens.

    Raises:
        ValueError: If neither ``'text'`` nor ``'content'`` is present.
    """
    if 'text' in examples:
        samples = examples['text']
    elif 'content' in examples:
        samples = examples['content']
    else:
        raise ValueError(f'No "text" or "content" field found in examples:\n{examples}')
    input_ids = tokenizer(samples)['input_ids']
    # Pair each raw string with its token ids directly; guard the division so
    # a sample tokenizing to zero tokens (e.g. an empty string) cannot raise
    # ZeroDivisionError.
    bits_per_token = [
        len(sample.encode(encoding='utf-8')) * 8 / len(ids) if ids else 0.0
        for sample, ids in zip(samples, input_ids)
    ]
    return {'input_ids': input_ids, 'bits_per_token': bits_per_token}
2426

2527

2628
if __name__ == '__main__':

0 commit comments

Comments
 (0)