@@ -15,7 +15,6 @@
 
 from fms_fsdp.utils.checkpointing_utils import get_latest
 
-# TODO: long doc breaking
 # TODO: titan PR adds
 # TODO: zero-len file asserts/check
 
@@ -359,8 +358,8 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
+        assert index < reader.num_record_batches, f"Illegal index {index} in set of {reader.num_record_batches} documents"
         frame = reader.get_batch(index)
-
         doc = None
         for name in self.col_names:
             if name in frame.column_names:
@@ -407,7 +406,8 @@ def length(self, path: str):
         return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
-        doc = self.tokenizer(str(reader[index]))["input_ids"]
+        assert index < reader.length(), f"Illegal index {index} in set of {reader.length()} documents"
+        doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"]
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
         # Recheck len for edge case where doc=[eos]
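Note on the character cap above: slicing to the first 1,000,000 characters bounds tokenizer cost on pathologically long rows before chunking. A minimal standalone illustration of the same slicing (the toy tokenizer below is a stand-in, not the handler's real tokenizer):

```python
# Toy stand-in tokenizer; the real handler uses its configured self.tokenizer.
def toy_tokenizer(text: str):
    return {"input_ids": [ord(c) for c in text]}

raw = "a" * 5_000_000                                     # pathologically long document
doc = toy_tokenizer(str(raw)[:1_000_000])["input_ids"]    # cap characters before tokenizing
assert len(doc) == 1_000_000
```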
@@ -978,10 +978,10 @@ class StreamingDocDataset(_StatefulDataset):
         Documents below this length are skipped
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
+    max_consecutive_chunks : int
+        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
-    shuffle : bool
-        Shuffle shard file and document orders? (Disable for simple testing)
     """
 
     def __init__(
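With the defaults added below (max_chunksize=1024, max_consecutive_chunks=64), a single long document therefore contributes on the order of 64 × 1024 ≈ 65,000 tokens before an EOS is forced in and the document resumes later.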
@@ -996,6 +996,7 @@ def __init__(
         seed: int = 42,
         min_length: int = 1,
         max_chunksize: int = 1024,
+        max_consecutive_chunks: int = 64,
         verbose: bool = False,
     ):
         super().__init__(datapath, rank, worldsize)
@@ -1008,6 +1009,7 @@ def __init__(
         self.eos = delimiter_token
         self.bos = bos_token
         self.drop = strip_tokens
+        self.max_consec = max_consecutive_chunks
         self.verbose = verbose
         # Map of doc indices to (shardid, min docid, max docid)
         self.docset: List[Any] = []
@@ -1022,6 +1024,7 @@ def __init__(
         self.tokens_seen = 0
         self.docs_seen = 0
         self.percent_seen = 0
+        self.consec = 0
 
         self.state_params = [
             "dataset",
@@ -1032,6 +1035,7 @@ def __init__(
             "docs_seen",
             "percent_seen",
             "lcg_state",
+            "consec",
         ]
 
         # Setup flags
@@ -1064,73 +1068,81 @@ def setup(self):
                 if self.filehandler.is_legal(os.path.join(root, name))
             ]
             shards.sort()  # Ensure consistent sharding across machines
-            start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
-            end_frag = (
-                (self.rank + 1) * self.worldsize * len(shards)
-            ) // self.worldsize
-            shardfrags = [
-                (shards[i // self.worldsize], i % self.worldsize)
-                for i in range(start_frag, end_frag)
-            ]
-
-            # Assemble length of each owned shard file
 
+            # Find metadata file
             countfiles = []
             if os.path.exists(os.path.join(pardir, "meta")):
                 countfiles = [
                     x
                     for x in os.listdir(os.path.join(pardir, "meta"))
                     if "counts" in x and "csv" in x
                 ]
-            doc_counts = {}
             if len(countfiles) > 0:
                 # Count file exists, use it
                 countpath = os.path.join(pardir, "meta", countfiles[0])
+            else:
+                countpath = ""
+
+            # Use shard file sizes to perform partitioning
+            # Create shardlist of form shardid -> [start%, end%]
+            if len(countfiles) > 0:
+                sizes = {}
+                with open(countpath, "r") as csvfile:
+                    reader = csv.DictReader(csvfile)
+                    for row in reader:
+                        fullpath = row["dataset/filename"]
+                        prefix = fullpath.find(dataset + "/")
+                        if prefix >= 0:
+                            key = fullpath[prefix + len(dataset) + 1 :]
+                            sizes[key] = int(row["size"])
+                shard_sizes = [sizes[shard] for shard in shards]
+            else:
+                shard_sizes = [
+                    os.path.getsize(os.path.join(datapath, shard)) for shard in shards
+                ]
+            shard_sizes = [s / sum(shard_sizes) for s in shard_sizes]
+            start = self.rank / self.worldsize
+            end = (self.rank + 1) / self.worldsize
+            shardset = {}
+            tally = 0
+            for i in range(len(shards)):
+                if tally <= end and tally + shard_sizes[i] >= start:
+                    shardset[shards[i]] = [
+                        min(max((start - tally) / shard_sizes[i], 0), 1),
+                        min(max((end - tally) / shard_sizes[i], 0), 1),
+                    ]
+                tally += shard_sizes[i]
+            # Count file exists, use it
                 with open(countpath, "r") as csvfile:
                     reader = csv.DictReader(csvfile)
                     for row in reader:
                         fullpath = row["dataset/filename"]
-                        prefix = fullpath.find("/" + dataset) + 1
-                        if prefix > 0:
+                        prefix = fullpath.find(dataset)
+                        if prefix >= 0:
                             key = fullpath[prefix + len(dataset) + 1 :]
                             doc_counts[key] = int(row["documents"])
             else:
                 # Count file does not exist, touch every owned file for length
-                unique_shardfiles = set(shard for shard, frag in shardfrags)
                 doc_counts = {
                     shard: self.filehandler.length(os.path.join(datapath, shard))
-                    for shard in unique_shardfiles
+                    for shard in shardset
                 }
 
-            # Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
-            ndocs = -1
-            docset = {}  # shardid -> (min docid, max docid)
-            for i, (shard, frag) in enumerate(shardfrags):
-                ndocs = doc_counts[shard]
-                doc_start = (ndocs * frag) // self.worldsize
-                doc_end = (
-                    ndocs * frag + ndocs
-                ) // self.worldsize - 1  # Inclusive upper bound
-                if shard not in docset:
-                    docset[shard] = [doc_start, doc_end]
-                min_d, max_d = docset[shard]
-                if doc_start < min_d:
-                    docset[shard][0] = doc_start
-                if doc_end > max_d:
-                    docset[shard][1] = doc_end
-
-            # Add shard entries to self.docset
+            # Assemble doc list for each file shard
+            # Create docset of form [shardid, min docid, max docid]
             doccount = 0
-            for shardid in docset:
-                min_d = docset[shardid][0]
-                max_d = docset[shardid][1]
-                self.docset.append((shardid, min_d, max_d))
-                doccount += max_d - min_d + 1
+            for shard in shardset:
+                ndocs = doc_counts[shard]
+                if ndocs > 0:
+                    doc_start = int(ndocs * shardset[shard][0])
+                    doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
+                    self.docset.append([shard, doc_start, doc_end])
+                    doccount += doc_end - doc_start + 1
             self._len = doccount
 
             if self.verbose:
                 logging.info(
-                    f"  Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
+                    f"  Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}"
                 )
 
             # Shuffle shard files - guaranteed inconsistent across workers
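For reference, a standalone sketch of the size-proportional partitioning added above: each rank claims the fraction [rank/worldsize, (rank+1)/worldsize) of total bytes, and every shard it touches is recorded with the fractional range of that shard it owns. The name partition_shards is illustrative, not the PR's API.

```python
def partition_shards(shard_sizes, rank, worldsize):
    # Normalize sizes to fractions of the total corpus
    fracs = [s / sum(shard_sizes) for s in shard_sizes]
    start, end = rank / worldsize, (rank + 1) / worldsize
    owned, tally = {}, 0.0
    for i, frac in enumerate(fracs):
        # Shard i overlaps this rank's [start, end) window
        if tally <= end and tally + frac >= start:
            owned[i] = [
                min(max((start - tally) / frac, 0), 1),
                min(max((end - tally) / frac, 0), 1),
            ]
        tally += frac
    return owned

# Two ranks over shards of 100, 300, and 100 bytes:
# rank 0 owns all of shard 0 and roughly the first half of shard 1;
# rank 1 owns roughly the second half of shard 1 and all of shard 2.
print(partition_shards([100, 300, 100], rank=0, worldsize=2))
print(partition_shards([100, 300, 100], rank=1, worldsize=2))
```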
@@ -1185,8 +1197,11 @@ def _construct_chunk(self, j, doc, n_chunks):
         # Add bos/eos tokens if needed
         if self.bos is not None and j == 0:
             chunk = [self.bos] + chunk
-        if j == n_chunks - 1:
+        if j == n_chunks - 1 or self.consec == self.max_consec:
             chunk = chunk + [self.eos]
+            self.consec = 0
+        else:
+            self.consec += 1
         return chunk
 
     def _random_map_docid(self, size):
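A toy trace of the consecutive-chunk counter introduced in _construct_chunk above, showing how an EOS is forced after a bounded run of chunks from one long document (standalone sketch; EOS = -1 and max_consec = 2 are arbitrary stand-ins for the real delimiter token and the default of 64):

```python
import math

EOS = -1
max_consec = 2  # stand-in for the PR's default of 64

def emit_chunks(doc, chunksize):
    consec = 0
    n_chunks = math.ceil(len(doc) / chunksize)
    for j in range(n_chunks):
        chunk = doc[j * chunksize : (j + 1) * chunksize]
        # Same rule as the diff: EOS on the final chunk, or once the run hits max_consec
        if j == n_chunks - 1 or consec == max_consec:
            chunk = chunk + [EOS]
            consec = 0
        else:
            consec += 1
        yield chunk

print(list(emit_chunks(list(range(10)), chunksize=2)))
# [[0, 1], [2, 3], [4, 5, -1], [6, 7], [8, 9, -1]]
```

In the PR itself the counter lives on the dataset instance and is checkpointed via state_params, so the run length persists across documents and restarts.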
@@ -1231,10 +1246,8 @@ def __iter__(self):
                 doclcg = self._random_map_docid(docrange)
                 docid = doclcg + mindoc
                 doc = self.filehandler.get(reader, docid, self.drop)
-                if len(doc) == 0:
-                    continue
                 doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                if doclen >= self.min_length:
+                if len(doc) > 0 and doclen >= self.min_length:
                     n_chunks = math.ceil(doclen / self.chunksize)
                     for j in range(n_chunks):
                         if i == 0 and j < residual_chunks:
@@ -1345,12 +1358,12 @@ def setup(self):
         if not self.is_setup:
             _StatefulDataset.setup(self)
             n_logical_shards = self.total_shards
+            assert (
+                n_logical_shards % self.worldsize == 0
+            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
             logicals = list(range(n_logical_shards))
             self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
             self.n_logicals = n_logical_shards // self.worldsize
-            assert (
-                len(self.logicals_owned) == self.n_logicals
-            ), "(world size * num workers) does not divide logical shards evenly"
 
             # Build logical shards
             for i in range(self.n_logicals):
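The relocated assertion above requires the logical shard count to be a multiple of the world size up front. A hedged sketch of the even split this guarantees (even_split is illustrative; the repo's _shard_partition helper is not shown in this diff):

```python
def even_split(n_logical_shards, rank, worldsize):
    assert (
        n_logical_shards % worldsize == 0
    ), f"Total workers {worldsize} must divide n_logical_shards {n_logical_shards} evenly"
    per_worker = n_logical_shards // worldsize
    # Contiguous block of logical shard ids owned by this rank
    return list(range(rank * per_worker, (rank + 1) * per_worker))

print(even_split(8, rank=1, worldsize=4))           # [2, 3]
print(even_split(2048, rank=0, worldsize=64)[:3])   # [0, 1, 2]; 32 shards per worker
```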
@@ -1367,21 +1380,26 @@ def setup(self):
                 )
             [d.setup() for d in self.data]
             self.n_docs_remaining = [d._len for d in self.data]
-
+            assert (
+                sum(self.n_docs_remaining) > 0
+            ), f"No documents detected in shard {self.rank} of {self.datapath}"
+
             self.generator = torch.Generator().manual_seed(self.rank)
 
     def __iter__(self):
         self.setup()
         # Grab one doc at a time in random order
         data = [iter(d) for d in self.data]
+        # Reset if we're rescaling into a prematurely finished epoch
+        # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0])
+        if sum(self.n_docs_remaining) == 0:
+            self.n_docs_remaining = [d._len for d in self.data]
+            self.generator.manual_seed(self.rank)
         while True:
             # Sample logical shard (or load from ckp)
             if self.current_reader is not None:
                 ind = self.current_reader
             else:
-                assert (
-                    sum(self.n_docs_remaining) > 0
-                ), f"No documents detected in {self.datapath}"
                 ind = torch.multinomial(
                     torch.tensor(self.n_docs_remaining, dtype=torch.float),
                     1,
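The sampling step above draws the next logical shard with probability proportional to its remaining document count, and the new reset re-arms those counts when a rescaled job resumes into an already finished epoch. A minimal sketch of the weighted draw (values and seed are made up):

```python
import torch

n_docs_remaining = [5, 0, 3]  # per-logical-shard docs left; shard 1 is exhausted
g = torch.Generator().manual_seed(0)
ind = torch.multinomial(
    torch.tensor(n_docs_remaining, dtype=torch.float), 1, generator=g
).item()
print(ind)  # 0 or 2, never the exhausted shard 1
```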
@@ -1473,6 +1491,8 @@ def __init__(
             ]
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
+        for d in datasets:
+            assert os.path.exists(os.path.join(datapath, d)), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(