Commit 85999ee
when chunk exceeds max token count, print truncated chunk
Signed-off-by: Khaled Sulayman <[email protected]>
1 parent 961dbd2 · commit 85999ee

File tree

1 file changed: +9 −3 lines changed


notebooks/instructlab-knowledge/utils/create_seed_dataset.py

Lines changed: 9 additions & 3 deletions
@@ -119,7 +119,7 @@ def safe_concatenate_datasets(datasets: list[Dataset]) -> Dataset:
 def get_token_count(text, tokenizer):
     return len(tokenizer.tokenize(text))
 
-def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset) -> Dataset:
+def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset, max_token_count: int = 1024) -> Dataset:
     """
     Add the ICLS label to the dataset.
     Args:
@@ -147,9 +147,15 @@ def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset) -> Dataset:
         )
     )
     chunked_document_all_icl = safe_concatenate_datasets(chunked_document_all_icl)
+    def truncate_chunk(chunk: str):
+        words = chunk.split()
+        if len(words) > 7:
+            return " ".join(words[:3]) + " ... " + " ".join(words[-3:])
+        return chunk
+
     for c in chunked_document_all_icl:
-        if get_token_count(c["document"], tokenizer) > 1024:
-            raise ValueError("Chunk exceeds token count of 1024")
+        if get_token_count(c["document"], tokenizer) > max_token_count:
+            raise ValueError(f'Chunk "{truncate_chunk(c["document"])}" exceeds token count of {max_token_count}')
 
 
     df = chunked_document_all_icl.to_pandas()
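
For context, here is a minimal, self-contained sketch of the new error path. The check_chunk wrapper and its whitespace-based token count are illustrative stand-ins (the committed code calls tokenizer.tokenize inside add_icls); only truncate_chunk matches the committed helper.

def truncate_chunk(chunk: str):
    # Keep only the first and last three words of an over-long chunk.
    words = chunk.split()
    if len(words) > 7:
        return " ".join(words[:3]) + " ... " + " ".join(words[-3:])
    return chunk

def check_chunk(document: str, max_token_count: int = 1024):
    # Hypothetical stand-in: counts whitespace-separated words instead of
    # calling tokenizer.tokenize() as the committed code does.
    if len(document.split()) > max_token_count:
        raise ValueError(
            f'Chunk "{truncate_chunk(document)}" exceeds token count of {max_token_count}'
        )

check_chunk("alpha beta gamma delta epsilon zeta eta theta", max_token_count=4)
# ValueError: Chunk "alpha beta gamma ... zeta eta theta" exceeds token count of 4

Truncating to the first and last three words keeps the error message readable even when the offending chunk runs to thousands of tokens.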
