
Commit 935d356

Merge pull request #29 from khaledsulayman/fix-rechunking
raise error if token count exceeds 1024 instead of attempting to re-chunk
2 parents: 36bece4 + 85999ee

1 file changed: +14 −16 lines

notebooks/instructlab-knowledge/utils/create_seed_dataset.py

Lines changed: 14 additions & 16 deletions
@@ -119,7 +119,7 @@ def safe_concatenate_datasets(datasets: list[Dataset]) -> Dataset:
 def get_token_count(text, tokenizer):
     return len(tokenizer.tokenize(text))

-def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset) -> Dataset:
+def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset, max_token_count: int = 1024) -> Dataset:
     """
     Add the ICLS label to the dataset.
     Args:
@@ -147,24 +147,22 @@ def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset) -> Dataset:
         )
     )
     chunked_document_all_icl = safe_concatenate_datasets(chunked_document_all_icl)
-    chunked_document_all_icl = chunked_document_all_icl.map(
-        lambda x: {
-            "chunks": chunk_document(
-                [x["document"]], server_ctx_size=4096, chunk_word_count=1024
-            )
-            if get_token_count(x["document"], tokenizer) > 1024
-            else [x["document"]]
-        }
-    )
+    def truncate_chunk(chunk: str):
+        words = chunk.split()
+        if len(words) > 7:
+            return " ".join(words[:3]) + " ... " + " ".join(words[-3:])
+        return chunk
+
+    for c in chunked_document_all_icl:
+        if get_token_count(c["document"], tokenizer) > max_token_count:
+            raise ValueError(f"Chunk \"{truncate_chunk(c["document"])}\" exceeds token count of {max_token_count}")
+
+
     df = chunked_document_all_icl.to_pandas()
-    df_exploded = df.explode("chunks").reset_index(drop=True)
-    new_ds = Dataset.from_pandas(df_exploded)
-    new_ds = new_ds.remove_columns("document").rename_columns(
-        {"chunks": "document"}
-    )
+    new_ds = Dataset.from_pandas(df)

     # Only keep document greater than 100 tokens
     new_ds = new_ds.filter(
-        lambda x: get_token_count(x["document"], tokenizer) > 100
+        lambda c: get_token_count(c["document"], tokenizer) > 100
     )
     return new_ds
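
The behavior introduced here is straightforward to exercise in isolation: measure each chunk with the tokenizer and fail fast when it is over the limit, instead of silently re-chunking. Below is a minimal sketch of that check, assuming a Hugging Face tokenizer that exposes a tokenize() method; the model name and the sample document are placeholders, not values taken from the notebook.

    # Sketch of the validation added in this commit; the tokenizer model name
    # and the sample document below are placeholders, not from the notebook.
    from transformers import AutoTokenizer

    MAX_TOKEN_COUNT = 1024  # mirrors the new max_token_count default in add_icls()

    def get_token_count(text, tokenizer):
        return len(tokenizer.tokenize(text))

    def truncate_chunk(chunk: str):
        # Keep the error message short: first three words, an ellipsis, last three words.
        words = chunk.split()
        if len(words) > 7:
            return " ".join(words[:3]) + " ... " + " ".join(words[-3:])
        return chunk

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder tokenizer
    document = "lorem ipsum dolor sit amet " * 300  # deliberately oversized sample text

    if get_token_count(document, tokenizer) > MAX_TOKEN_COUNT:
        raise ValueError(
            f'Chunk "{truncate_chunk(document)}" exceeds token count of {MAX_TOKEN_COUNT}'
        )

With this change, callers are expected to pass documents that already fit within max_token_count; the previous fallback of re-chunking oversized documents inside add_icls via chunk_document([x["document"]], server_ctx_size=4096, chunk_word_count=1024) is removed.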
