Commit fa71972

Blacking
1 parent 62ce139 commit fa71972

3 files changed: +32 / -12 lines

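This commit appears to apply the Black code formatter across the repo: every hunk below is a formatting-only change (wrapping long assert statements, long default-argument lists, and one long string assignment onto multiple lines), with no intended change in behavior.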

fms_fsdp/config/training.py

Lines changed: 3 additions & 1 deletion
@@ -15,7 +15,9 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    datasets: str = (
+        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    )
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000
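For context on the wrapped field: datasets and weights are parallel comma-separated strings (13 sub-dataset paths, 13 sampling weights) that are split downstream. A minimal sketch of how the two line up, assuming the package imports as fms_fsdp.config.training (the import path is not shown in this diff):

from fms_fsdp.config.training import train_config

cfg = train_config()
datasets = [d.strip() for d in cfg.datasets.split(",")]
weights = [float(w.strip()) for w in cfg.weights.split(",")]
assert len(datasets) == len(weights) == 13  # one sampling weight per sub-dataset
for d, w in zip(datasets, weights):
    print(d, w)  # e.g. lang=en/dataset=commoncrawl 7725.0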

fms_fsdp/utils/dataloader_utils.py

Lines changed: 1 addition & 1 deletion
@@ -177,4 +177,4 @@ def splitstrip(x):
     datas = splitstrip(datas)
     weights = [float(x) for x in splitstrip(weights)]
     cols = splitstrip(cols)
-    return datas, weights, cols
+    return datas, weights, cols
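The removed and re-added return line appear to differ only in trailing whitespace or a final newline. splitstrip's body is not shown in this diff; judging from its usage above, it presumably splits a comma-separated config string and strips whitespace from each piece. A hedged sketch of that presumed behavior, not the repo's exact implementation:

def splitstrip(x: str) -> list:
    # Presumed behavior: split on commas and strip surrounding whitespace.
    return [piece.strip() for piece in x.split(",")]

datas = splitstrip("lang=en/dataset=commoncrawl, lang=en/dataset=arxiv")
weights = [float(w) for w in splitstrip("7725, 250")]
cols = splitstrip("tokens")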

fms_fsdp/utils/dataset_utils.py

Lines changed: 28 additions & 10 deletions
@@ -355,14 +355,18 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
-        assert index < reader.num_record_batches, f"Illegal index {index} in set of {reader.num_record_batches} documents"
+        assert (
+            index < reader.num_record_batches
+        ), f"Illegal index {index} in set of {reader.num_record_batches} documents"
         frame = reader.get_batch(index)
         doc = None
         for name in self.col_names:
             if name in frame.column_names:
                 doc = frame[name]
                 break
-        assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}"
+        assert (
+            doc is not None
+        ), f"None of column names {self.col_names} found in file headers {frame.column_names}"
         if len(doc) > 0 and doc[0].as_py() in drop_tokens:
             doc = doc.slice(1, len(doc) - 1)
         # Recheck len for edge case where doc=[eos]
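The calls visible in this hunk (num_record_batches, get_batch, column lookup by name, as_py) map onto PyArrow's IPC record-batch reader, which the handler's open() presumably wraps via pa.ipc.open_file or equivalent. A rough standalone sketch of the same access pattern; the shard path and column name below are placeholders:

import pyarrow as pa

path = "/path/to/shard.arrow"     # placeholder shard file
reader = pa.ipc.open_file(path)   # RecordBatchFileReader
print(reader.num_record_batches)  # what length() reports
batch = reader.get_batch(0)       # one record batch per document
tokens = batch["tokens"]          # column lookup, as in get()
first = tokens[0].as_py()         # Arrow scalar -> Python value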
@@ -382,7 +386,9 @@ class ParquetHandler(_ShardFileHandler):
     before getting/slicing. However, this is a standard and widely-used data format.
     """
 
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
+    def __init__(
+        self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]
+    ):
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.col_names = col_names
 
@@ -396,14 +402,18 @@ def open(self, path: str):
             if name in names:
                 match = name
                 break
-        assert match is not None, f"None of column names {self.col_names} found in file headers {names}"
+        assert (
+            match is not None
+        ), f"None of column names {self.col_names} found in file headers {names}"
         return pq.read_pandas(path, columns=[match], partitioning=None)[match]
 
     def length(self, path: str):
         return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
-        assert index < reader.length(), f"Illegal index {index} in set of {reader.length()} documents"
+        assert (
+            index < reader.length()
+        ), f"Illegal index {index} in set of {reader.length()} documents"
         doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"]
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
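The ParquetHandler edits in these two hunks are likewise formatting-only: the handler still opens a single matching column with pyarrow.parquet, reads the row count from file metadata, and tokenizes text on access. A hedged usage sketch; the shard path is a placeholder, the tokenizer path is the train_config default, and the import path is assumed from the file name:

from fms_fsdp.utils.dataset_utils import ParquetHandler

handler = ParquetHandler("/fsx/tokenizer")  # requires a real tokenizer at this path
path = "/path/to/shard.parquet"             # placeholder shard file

n_docs = handler.length(path)               # pq.read_metadata(path).num_rows
column = handler.open(path)                 # first matching column as a ChunkedArray
doc_ids = handler.get(column, 0, drop_tokens=set())  # token ids for document 0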
@@ -417,7 +427,9 @@ def slice(self, doc: List, index: int, n_pull: int) -> List:
 
 
 class AutoHandler(_ShardFileHandler):
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
+    def __init__(
+        self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]
+    ):
         self.PHandler = ParquetHandler(tokenizer_path, col_names)
         self.AHandler = ArrowHandler(col_names)
         self.current = _ShardFileHandler()
@@ -1132,7 +1144,9 @@ def setup(self):
             ndocs = doc_counts[shard]
             if ndocs > 0:
                 doc_start = int(ndocs * shardset[shard][0])
-                doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
+                doc_end = max(
+                    doc_start, int(ndocs * shardset[shard][1]) - 1
+                )  # inclusive upper bound
                 self.docset.append([shard, doc_start, doc_end])
                 doccount += doc_end - doc_start + 1
         self._len = doccount
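As a worked example of the bound being wrapped here: with ndocs = 1000 and a fractional ownership range of (0.25, 0.5) for this shard, doc_start = int(1000 * 0.25) = 250 and doc_end = max(250, int(1000 * 0.5) - 1) = 499, so the shard owns documents 250 through 499 inclusive and doccount increases by 499 - 250 + 1 = 250.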
@@ -1280,7 +1294,9 @@ def __iter__(self):
                     yield self._construct_chunk(j, doc, n_chunks)
 
             # Check that epoch was non-empty
-            assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
+            assert (
+                self.has_yielded
+            ), f"Empty logical shard detected: {self.dataset, self.docset}"
 
     def load_state_dict(self, state_dicts, sharded_input=False):
         self.setup()
@@ -1378,7 +1394,7 @@ def setup(self):
         assert (
             sum(self.n_docs_remaining) > 0
         ), f"No documents detected in shard {self.rank} of {self.datapath}"
-
+
         self.generator = torch.Generator().manual_seed(self.rank)
 
     def __iter__(self):
@@ -1487,7 +1503,9 @@ def __init__(
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
         for d in datasets:
-            assert os.path.exists(os.path.join(datapath, d)), f"Invalid subdataset path: {os.path.join(datapath, d)}"
+            assert os.path.exists(
+                os.path.join(datapath, d)
+            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(
