diff --git a/notebooks/instructlab-knowledge/utils/create_seed_dataset.py b/notebooks/instructlab-knowledge/utils/create_seed_dataset.py
index 808d6d6..f049c86 100644
--- a/notebooks/instructlab-knowledge/utils/create_seed_dataset.py
+++ b/notebooks/instructlab-knowledge/utils/create_seed_dataset.py
@@ -119,7 +119,7 @@ def safe_concatenate_datasets(datasets: list[Dataset]) -> Dataset:
 def get_token_count(text, tokenizer):
     return len(tokenizer.tokenize(text))
 
-def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset) -> Dataset:
+def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset, max_token_count: int = 1024) -> Dataset:
     """
     Add the ICLS label to the dataset.
     Args:
@@ -147,24 +147,22 @@ def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset) -> Dataset:
             )
         )
     chunked_document_all_icl = safe_concatenate_datasets(chunked_document_all_icl)
-    chunked_document_all_icl = chunked_document_all_icl.map(
-        lambda x: {
-            "chunks": chunk_document(
-                [x["document"]], server_ctx_size=4096, chunk_word_count=1024
-            )
-            if get_token_count(x["document"], tokenizer) > 1024
-            else [x["document"]]
-        }
-    )
+    def truncate_chunk(chunk: str):
+        words = chunk.split()
+        if len(words) > 7:
+            return " ".join(words[:3]) + " ... " + " ".join(words[-3:])
+        return chunk
+
+    for c in chunked_document_all_icl:
+        if get_token_count(c["document"], tokenizer) > max_token_count:
+            raise ValueError(f"Chunk \"{truncate_chunk(c['document'])}\" exceeds token count of {max_token_count}")
+
+
     df = chunked_document_all_icl.to_pandas()
-    df_exploded = df.explode("chunks").reset_index(drop=True)
-    new_ds = Dataset.from_pandas(df_exploded)
-    new_ds = new_ds.remove_columns("document").rename_columns(
-        {"chunks": "document"}
-    )
+    new_ds = Dataset.from_pandas(df)
     # Only keep document greater than 100 tokens
     new_ds = new_ds.filter(
-        lambda x: get_token_count(x["document"], tokenizer) > 100
+        lambda c: get_token_count(c["document"], tokenizer) > 100
     )
     return new_ds
 
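
For readers skimming the patch: the automatic re-chunking step (chunk_document with server_ctx_size=4096) is removed, and add_icls now fails fast when any chunk exceeds max_token_count tokens. The snippet below is a minimal, self-contained sketch of that check, not the repository's code: the whitespace tokenizer, the sample chunks, and the max_token_count=8 budget are stand-ins chosen so the example runs without model downloads.

# Illustrative sketch only; mirrors the validation introduced by this patch.
from datasets import Dataset

class WhitespaceTokenizer:
    # Stand-in for the real tokenizer used by the notebook utilities.
    def tokenize(self, text):
        return text.split()

def get_token_count(text, tokenizer):
    return len(tokenizer.tokenize(text))

def truncate_chunk(chunk: str):
    # Shorten long chunks to "first three words ... last three words" for the error message.
    words = chunk.split()
    if len(words) > 7:
        return " ".join(words[:3]) + " ... " + " ".join(words[-3:])
    return chunk

max_token_count = 8  # assumed small budget so the example trips the check
tokenizer = WhitespaceTokenizer()
chunks = Dataset.from_dict({"document": [
    "short enough chunk",
    "this chunk has far too many words to fit inside the configured token budget",
]})

for c in chunks:
    if get_token_count(c["document"], tokenizer) > max_token_count:
        # Instead of silently re-chunking, the caller is told the chunk is too large
        # and must split the source document before building the seed dataset.
        raise ValueError(
            f'Chunk "{truncate_chunk(c["document"])}" exceeds token count of {max_token_count}'
        )

Compared with the previous behavior, which silently split oversized documents at a hard-coded 1024-token limit, raising makes the limit explicit and configurable per call.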