@@ -355,14 +355,18 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
-        assert index < reader.num_record_batches, f"Illegal index {index} in set of {reader.num_record_batches} documents"
+        assert (
+            index < reader.num_record_batches
+        ), f"Illegal index {index} in set of {reader.num_record_batches} documents"
         frame = reader.get_batch(index)
         doc = None
         for name in self.col_names:
             if name in frame.column_names:
                 doc = frame[name]
                 break
-        assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}"
+        assert (
+            doc is not None
+        ), f"None of column names {self.col_names} found in file headers {frame.column_names}"
         if len(doc) > 0 and doc[0].as_py() in drop_tokens:
             doc = doc.slice(1, len(doc) - 1)
         # Recheck len for edge case where doc=[eos]
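Note: the two assert rewraps above are purely cosmetic (black-style line length); the surrounding logic is what matters: `get` picks the first matching column name and strips a leading drop token before returning the document. A minimal sketch of that drop-token slice, assuming pyarrow is installed; the token ids and drop set below are made up for illustration:

    import pyarrow as pa

    doc = pa.array([0, 17, 42])  # hypothetical document whose first id is a drop token
    drop_tokens = {0}

    # Same check-and-slice as in ArrowHandler.get above
    if len(doc) > 0 and doc[0].as_py() in drop_tokens:
        doc = doc.slice(1, len(doc) - 1)

    print(doc.to_pylist())  # [17, 42]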
@@ -382,7 +386,9 @@ class ParquetHandler(_ShardFileHandler):
     before getting/slicing. However, this is a standard and widely-used data format.
     """
 
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
+    def __init__(
+        self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]
+    ):
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.col_names = col_names
 
@@ -396,14 +402,18 @@ def open(self, path: str):
             if name in names:
                 match = name
                 break
-        assert match is not None, f"None of column names {self.col_names} found in file headers {names}"
+        assert (
+            match is not None
+        ), f"None of column names {self.col_names} found in file headers {names}"
         return pq.read_pandas(path, columns=[match], partitioning=None)[match]
 
     def length(self, path: str):
         return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
-        assert index < reader.length(), f"Illegal index {index} in set of {reader.length()} documents"
+        assert (
+            index < reader.length()
+        ), f"Illegal index {index} in set of {reader.length()} documents"
         doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"]
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
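Note: unlike the Arrow path, which reads pre-tokenized columns, `ParquetHandler.get` tokenizes raw text at read time and caps each document at 1,000,000 characters before tokenization. A sketch of that call, assuming transformers is installed and using "gpt2" as a stand-in for the user-supplied tokenizer_path:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for tokenizer_path
    text = "An example document."  # hypothetical row value from the parquet column
    # Cap raw characters at 1M before tokenizing, as in get() above
    doc = tokenizer(str(text)[:1_000_000])["input_ids"]
    print(doc[:5])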
@@ -417,7 +427,9 @@ def slice(self, doc: List, index: int, n_pull: int) -> List:
 
 
 class AutoHandler(_ShardFileHandler):
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
+    def __init__(
+        self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]
+    ):
         self.PHandler = ParquetHandler(tokenizer_path, col_names)
         self.AHandler = ArrowHandler(col_names)
         self.current = _ShardFileHandler()
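Note: `AutoHandler` composes both handlers and forwards work through `self.current`. This hunk only shows `__init__`; the dispatch sketch below is an assumption for illustration (extension-based routing), not the repository's actual selection logic:

    import os

    class AutoHandlerSketch:
        """Hypothetical delegator showing the pattern, not the repo's code."""

        def __init__(self, parquet_handler, arrow_handler):
            self.PHandler = parquet_handler
            self.AHandler = arrow_handler
            self.current = None

        def open(self, path: str):
            # Assumed rule: route .parquet files to the parquet handler,
            # everything else to the arrow handler.
            ext = os.path.splitext(path)[1]
            self.current = self.PHandler if ext == ".parquet" else self.AHandler
            return self.current.open(path)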
@@ -1132,7 +1144,9 @@ def setup(self):
             ndocs = doc_counts[shard]
             if ndocs > 0:
                 doc_start = int(ndocs * shardset[shard][0])
-                doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
+                doc_end = max(
+                    doc_start, int(ndocs * shardset[shard][1]) - 1
+                )  # inclusive upper bound
                 self.docset.append([shard, doc_start, doc_end])
                 doccount += doc_end - doc_start + 1
         self._len = doccount
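Note: the rewrapped line computes an inclusive upper bound for this worker's slice of a shard, where `shardset[shard]` holds fractional (start, end) ownership. A worked example with made-up numbers, showing why the `max()` guard matters for tiny slices:

    ndocs = 10
    start_frac, end_frac = 0.25, 0.5  # hypothetical shardset[shard] entry

    doc_start = int(ndocs * start_frac)                  # 2
    doc_end = max(doc_start, int(ndocs * end_frac) - 1)  # 4, inclusive
    print(doc_end - doc_start + 1)                       # 3 documents owned

    # A sliver of a shard would otherwise produce an inverted range:
    doc_start = int(10 * 0.31)                    # 3
    doc_end = max(doc_start, int(10 * 0.32) - 1)  # max(3, 2) -> 3: one doc, not zero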
@@ -1280,7 +1294,9 @@ def __iter__(self):
                     yield self._construct_chunk(j, doc, n_chunks)
 
             # Check that epoch was non-empty
-            assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
+            assert (
+                self.has_yielded
+            ), f"Empty logical shard detected: {self.dataset, self.docset}"
 
     def load_state_dict(self, state_dicts, sharded_input=False):
         self.setup()
@@ -1378,7 +1394,7 @@ def setup(self):
         assert (
             sum(self.n_docs_remaining) > 0
         ), f"No documents detected in shard {self.rank} of {self.datapath}"
-
+
         self.generator = torch.Generator().manual_seed(self.rank)
 
     def __iter__(self):
@@ -1487,7 +1503,9 @@ def __init__(
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
         for d in datasets:
-            assert os.path.exists(os.path.join(datapath, d)), f"Invalid subdataset path: {os.path.join(datapath, d)}"
+            assert os.path.exists(
+                os.path.join(datapath, d)
+            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(