
Commit 5b76df2

Add everything else
1 parent 24a4667 commit 5b76df2

File tree

1 file changed: +75 -55 lines changed


fms_fsdp/utils/dataset_utils.py

Lines changed: 75 additions & 55 deletions
@@ -15,7 +15,6 @@
 
 from fms_fsdp.utils.checkpointing_utils import get_latest
 
-# TODO: long doc breaking
 # TODO: titan PR adds
 # TODO: zero-len file asserts/check
 
@@ -359,8 +358,8 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
+        assert index < reader.num_record_batches, f"Illegal index {index} in set of {reader.num_record_batches} documents"
         frame = reader.get_batch(index)
-
         doc = None
         for name in self.col_names:
             if name in frame.column_names:
@@ -407,7 +406,8 @@ def length(self, path: str):
         return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
-        doc = self.tokenizer(str(reader[index]))["input_ids"]
+        assert index < reader.length(), f"Illegal index {index} in set of {reader.length()} documents"
+        doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"]
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
         # Recheck len for edge case where doc=[eos]
@@ -978,10 +978,10 @@ class StreamingDocDataset(_StatefulDataset):
         Documents below this length are skipped
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
+    max_consecutive_chunks : int
+        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
-    shuffle : bool
-        Shuffle shard file and document orders? (Disable for simple testing)
     """
 
     def __init__(
@@ -996,6 +996,7 @@ def __init__(
         seed: int = 42,
         min_length: int = 1,
         max_chunksize: int = 1024,
+        max_consecutive_chunks: int = 64,
         verbose: bool = False,
     ):
         super().__init__(datapath, rank, worldsize)
@@ -1008,6 +1009,7 @@ def __init__(
         self.eos = delimiter_token
         self.bos = bos_token
         self.drop = strip_tokens
+        self.max_consec = max_consecutive_chunks
         self.verbose = verbose
         # Map of doc indices to (shardid, min docid, max docid)
         self.docset: List[Any] = []
@@ -1022,6 +1024,7 @@ def __init__(
         self.tokens_seen = 0
         self.docs_seen = 0
         self.percent_seen = 0
+        self.consec = 0
 
         self.state_params = [
             "dataset",
@@ -1032,6 +1035,7 @@ def __init__(
             "docs_seen",
             "percent_seen",
             "lcg_state",
+            "consec",
         ]
 
         # Setup flags
@@ -1064,73 +1068,81 @@ def setup(self):
                 if self.filehandler.is_legal(os.path.join(root, name))
             ]
             shards.sort()  # Ensure consistent sharding across machines
-            start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
-            end_frag = (
-                (self.rank + 1) * self.worldsize * len(shards)
-            ) // self.worldsize
-            shardfrags = [
-                (shards[i // self.worldsize], i % self.worldsize)
-                for i in range(start_frag, end_frag)
-            ]
-
-            # Assemble length of each owned shard file
 
+            # Find metadata file
             countfiles = []
             if os.path.exists(os.path.join(pardir, "meta")):
                 countfiles = [
                     x
                     for x in os.listdir(os.path.join(pardir, "meta"))
                     if "counts" in x and "csv" in x
                 ]
-            doc_counts = {}
             if len(countfiles) > 0:
                 # Count file exists, use it
                 countpath = os.path.join(pardir, "meta", countfiles[0])
+            else:
+                countpath = ""
+
+            # Use shard file sizes to perform partitioning
+            # Create shardlist of form shardid -> [start%, end%]
+            if len(countfiles) > 0:
+                sizes = {}
+                with open(countpath, "r") as csvfile:
+                    reader = csv.DictReader(csvfile)
+                    for row in reader:
+                        fullpath = row["dataset/filename"]
+                        prefix = fullpath.find(dataset + "/")
+                        if prefix >= 0:
+                            key = fullpath[prefix + len(dataset) + 1 :]
+                            sizes[key] = int(row["size"])
+                shard_sizes = [sizes[shard] for shard in shards]
+            else:
+                shard_sizes = [
+                    os.path.getsize(os.path.join(datapath, shard)) for shard in shards
+                ]
+            shard_sizes = [s / sum(shard_sizes) for s in shard_sizes]
+            start = self.rank / self.worldsize
+            end = (self.rank + 1) / self.worldsize
+            shardset = {}
+            tally = 0
+            for i in range(len(shards)):
+                if tally <= end and tally + shard_sizes[i] >= start:
+                    shardset[shards[i]] = [
+                        min(max((start - tally) / shard_sizes[i], 0), 1),
+                        min(max((end - tally) / shard_sizes[i], 0), 1),
+                    ]
+                tally += shard_sizes[i]
+                # Count file exists, use it
                 with open(countpath, "r") as csvfile:
                     reader = csv.DictReader(csvfile)
                     for row in reader:
                         fullpath = row["dataset/filename"]
-                        prefix = fullpath.find("/" + dataset) + 1
-                        if prefix > 0:
+                        prefix = fullpath.find(dataset)
+                        if prefix >= 0:
                             key = fullpath[prefix + len(dataset) + 1 :]
                             doc_counts[key] = int(row["documents"])
             else:
                 # Count file does not exist, touch every owned file for length
-                unique_shardfiles = set(shard for shard, frag in shardfrags)
                 doc_counts = {
                     shard: self.filehandler.length(os.path.join(datapath, shard))
-                    for shard in unique_shardfiles
+                    for shard in shardset
                 }
 
-            # Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
-            ndocs = -1
-            docset = {}  # shardid -> (min docid, max docid)
-            for i, (shard, frag) in enumerate(shardfrags):
-                ndocs = doc_counts[shard]
-                doc_start = (ndocs * frag) // self.worldsize
-                doc_end = (
-                    ndocs * frag + ndocs
-                ) // self.worldsize - 1  # Inclusive upper bound
-                if shard not in docset:
-                    docset[shard] = [doc_start, doc_end]
-                min_d, max_d = docset[shard]
-                if doc_start < min_d:
-                    docset[shard][0] = doc_start
-                if doc_end > max_d:
-                    docset[shard][1] = doc_end
-
-            # Add shard entries to self.docset
+            # Assemble doc list for each file shard
+            # Create docset of form [shardid, min docid, max docid]
             doccount = 0
-            for shardid in docset:
-                min_d = docset[shardid][0]
-                max_d = docset[shardid][1]
-                self.docset.append((shardid, min_d, max_d))
-                doccount += max_d - min_d + 1
+            for shard in shardset:
+                ndocs = doc_counts[shard]
+                if ndocs > 0:
+                    doc_start = int(ndocs * shardset[shard][0])
+                    doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
+                    self.docset.append([shard, doc_start, doc_end])
+                    doccount += doc_end - doc_start + 1
             self._len = doccount
 
             if self.verbose:
                 logging.info(
-                    f"    Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
+                    f"    Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}"
                 )
 
             # Shuffle shard files - guaranteed inconsistent across workers
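Note on the hunk above: the old fragment-counting split is replaced by size-weighted partitioning. Shard sizes come from the counts CSV when present, or from file sizes on disk otherwise, and each rank claims the fractional [start%, end%] slice of every shard overlapping its interval [rank/worldsize, (rank+1)/worldsize). A standalone sketch of that weighting step (the function name and example values below are illustrative, not from the repo):

# Standalone sketch of fractional shard partitioning (illustrative names only).
# Each rank owns the slice [rank/worldsize, (rank+1)/worldsize) of the total data,
# expressed per shard as fractional [start%, end%] bounds.
def partition_shards(shard_sizes, rank, worldsize):
    total = sum(shard_sizes)
    fracs = [s / total for s in shard_sizes]
    start = rank / worldsize
    end = (rank + 1) / worldsize
    shardset = {}
    tally = 0.0
    for i, frac in enumerate(fracs):
        # Keep shard i only if it overlaps this rank's [start, end) interval
        if tally <= end and tally + frac >= start:
            shardset[i] = [
                min(max((start - tally) / frac, 0), 1),
                min(max((end - tally) / frac, 0), 1),
            ]
        tally += frac
    return shardset

# Three equal shards over two ranks: rank 0 takes all of shard 0 and roughly the first half of shard 1.
print(partition_shards([100, 100, 100], rank=0, worldsize=2))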
@@ -1185,8 +1197,11 @@ def _construct_chunk(self, j, doc, n_chunks):
         # Add bos/eos tokens if needed
         if self.bos is not None and j == 0:
             chunk = [self.bos] + chunk
-        if j == n_chunks - 1:
+        if j == n_chunks - 1 or self.consec == self.max_consec:
             chunk = chunk + [self.eos]
+            self.consec = 0
+        else:
+            self.consec += 1
         return chunk
 
     def _random_map_docid(self, size):
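Note on the hunk above: _construct_chunk now tracks a running consec counter (checkpointed via state_params) and appends the EOS delimiter either on a document's final chunk or once max_consecutive_chunks chunks have passed without one. A toy trace of that rule, using a made-up token list and eos value rather than the repo's API:

# Toy trace of the EOS-capping rule (illustrative; doc and eos values are invented).
# An EOS lands on a doc's final chunk, or whenever `consec` hits max_consec, so a very
# long document cannot produce an unbounded run of chunks with no delimiter.
import math

def chunk_doc(doc, chunksize, max_consec, eos=-1, consec=0):
    n_chunks = math.ceil(len(doc) / chunksize)
    chunks = []
    for j in range(n_chunks):
        chunk = doc[j * chunksize : (j + 1) * chunksize]
        if j == n_chunks - 1 or consec == max_consec:
            chunk = chunk + [eos]
            consec = 0
        else:
            consec += 1
        chunks.append(chunk)
    return chunks, consec

chunks, _ = chunk_doc(list(range(10)), chunksize=2, max_consec=2)
print(chunks)  # [[0, 1], [2, 3], [4, 5, -1], [6, 7], [8, 9, -1]]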
@@ -1231,10 +1246,8 @@ def __iter__(self):
                 doclcg = self._random_map_docid(docrange)
                 docid = doclcg + mindoc
                 doc = self.filehandler.get(reader, docid, self.drop)
-                if len(doc) == 0:
-                    continue
                 doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                if doclen >= self.min_length:
+                if len(doc) > 0 and doclen >= self.min_length:
                     n_chunks = math.ceil(doclen / self.chunksize)
                     for j in range(n_chunks):
                         if i == 0 and j < residual_chunks:
@@ -1345,12 +1358,12 @@ def setup(self):
         if not self.is_setup:
             _StatefulDataset.setup(self)
             n_logical_shards = self.total_shards
+            assert (
+                n_logical_shards % self.worldsize == 0
+            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
             logicals = list(range(n_logical_shards))
             self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
             self.n_logicals = n_logical_shards // self.worldsize
-            assert (
-                len(self.logicals_owned) == self.n_logicals
-            ), "(world size * num workers) does not divide logical shards evenly"
 
             # Build logical shards
             for i in range(self.n_logicals):
@@ -1367,21 +1380,26 @@ def setup(self):
                 )
             [d.setup() for d in self.data]
             self.n_docs_remaining = [d._len for d in self.data]
-
+            assert (
+                sum(self.n_docs_remaining) > 0
+            ), f"No documents detected in shard {self.rank} of {self.datapath}"
+
             self.generator = torch.Generator().manual_seed(self.rank)
 
     def __iter__(self):
         self.setup()
         # Grab one doc at a time in random order
         data = [iter(d) for d in self.data]
+        # Reset if we're rescaling into a prematurely finished epoch
+        # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
+        if sum(self.n_docs_remaining) == 0:
+            self.n_docs_remaining = [d._len for d in self.data]
+            self.generator.manual_seed(self.rank)
         while True:
             # Sample logical shard (or load from ckp)
             if self.current_reader is not None:
                 ind = self.current_reader
             else:
-                assert (
-                    sum(self.n_docs_remaining) > 0
-                ), f"No documents detected in {self.datapath}"
                 ind = torch.multinomial(
                     torch.tensor(self.n_docs_remaining, dtype=torch.float),
                     1,
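Note on the hunk above: when a rescaled run resumes into an epoch that already finished, n_docs_remaining is reset and the generator is re-seeded; shard selection itself remains a multinomial draw weighted by remaining document counts, so exhausted shards are skipped. A hedged sketch of that draw, with invented counts and seed:

# Sketch of weighted logical-shard sampling (counts and seed below are made up).
# Shards are drawn in proportion to how many documents they still hold; a shard with
# zero remaining documents has zero weight and is never selected.
import torch

n_docs_remaining = [5, 0, 3]  # per-logical-shard documents left this epoch
generator = torch.Generator().manual_seed(0)
ind = torch.multinomial(
    torch.tensor(n_docs_remaining, dtype=torch.float), 1, generator=generator
).item()
print(ind)  # shard 1 is never picked; shards 0 and 2 appear roughly 5:3 over many draws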
@@ -1473,6 +1491,8 @@ def __init__(
             ]
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
+        for d in datasets:
+            assert os.path.exists(os.path.join(datapath, d)), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(
