@@ -119,7 +119,7 @@ def safe_concatenate_datasets(datasets: list[Dataset]) -> Dataset:
 def get_token_count(text, tokenizer):
     return len(tokenizer.tokenize(text))
 
-def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset) -> Dataset:
+def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset, max_token_count: int = 1024) -> Dataset:
     """
     Add the ICLS label to the dataset.
     Args:
@@ -147,24 +147,22 @@ def add_icls(qna_yaml: Dict[str, str], chunked_document: Dataset) -> Dataset:
             )
         )
     chunked_document_all_icl = safe_concatenate_datasets(chunked_document_all_icl)
-    chunked_document_all_icl = chunked_document_all_icl.map(
-        lambda x: {
-            "chunks": chunk_document(
-                [x["document"]], server_ctx_size=4096, chunk_word_count=1024
-            )
-            if get_token_count(x["document"], tokenizer) > 1024
-            else [x["document"]]
-        }
-    )
+    def truncate_chunk(chunk: str):
+        words = chunk.split()
+        if len(words) > 7:
+            return " ".join(words[:3]) + " ... " + " ".join(words[-3:])
+        return chunk
+
+    for c in chunked_document_all_icl:
+        if get_token_count(c["document"], tokenizer) > max_token_count:
+            raise ValueError(f'Chunk "{truncate_chunk(c["document"])}" exceeds the maximum token count of {max_token_count}')
+
+
     df = chunked_document_all_icl.to_pandas()
-    df_exploded = df.explode("chunks").reset_index(drop=True)
-    new_ds = Dataset.from_pandas(df_exploded)
-    new_ds = new_ds.remove_columns("document").rename_columns(
-        {"chunks": "document"}
-    )
+    new_ds = Dataset.from_pandas(df)

     # Only keep document greater than 100 tokens
     new_ds = new_ds.filter(
-        lambda x: get_token_count(x["document"], tokenizer) > 100
+        lambda c: get_token_count(c["document"], tokenizer) > 100
     )
     return new_ds
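
For reference, a minimal standalone sketch of the new validation path: instead of re-chunking oversized documents, the loop now fails fast with a `ValueError`. The whitespace tokenizer and the in-memory `Dataset` below are assumptions for illustration only; the real code uses the module's own tokenizer and the caller's chunked dataset.

```python
from datasets import Dataset

# Toy stand-in for the real tokenizer; only a .tokenize() method is needed here.
class WhitespaceTokenizer:
    def tokenize(self, text):
        return text.split()

def get_token_count(text, tokenizer):
    return len(tokenizer.tokenize(text))

def truncate_chunk(chunk: str):
    # Shorten a long chunk to its first and last few words for the error message.
    words = chunk.split()
    if len(words) > 7:
        return " ".join(words[:3]) + " ... " + " ".join(words[-3:])
    return chunk

tokenizer = WhitespaceTokenizer()
max_token_count = 5  # small limit so the example trips the check
chunked = Dataset.from_list(
    [{"document": "one two three four five six seven eight nine"}]
)

try:
    for c in chunked:
        if get_token_count(c["document"], tokenizer) > max_token_count:
            raise ValueError(
                f'Chunk "{truncate_chunk(c["document"])}" exceeds the maximum token count of {max_token_count}'
            )
except ValueError as err:
    print(err)  # Chunk "one two three ... seven eight nine" exceeds the maximum token count of 5
```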