Skip to content

Commit 8324609

Browse files
authored
Expose doc_breakpoint
1 parent c967bb5 commit 8324609

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

fms_fsdp/config/training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class train_config:
2929
logical_shards: int = 1024
3030
num_workers: int = 1
3131
doc_cutoff: int = 1_000_000
32+
doc_breakpoint: int = 65_536
3233

3334
# fsdp policies
3435
sharding_strategy: str = "hsdp"

fms_fsdp/utils/dataloader_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from math import ceil
23

34
from fms_fsdp.utils.dataset_utils import (
45
ArrowHandler,
@@ -94,7 +95,7 @@ def get_data_loader(cfg, rank, world_size):
9495
)
9596
else:
9697
filehandler = _handler_map[cfg.file_type](cols)
97-
98+
9899
# Base reader layer
99100
data = StreamingDocDataset(
100101
cfg.data_path,
@@ -105,6 +106,7 @@ def get_data_loader(cfg, rank, world_size):
105106
bos_token=cfg.bos_token,
106107
strip_tokens=set(droplist),
107108
min_length=3,
109+
max_consecutive_chunks=ceil(cfg.doc_breakpoint/1024),
108110
seed=cfg.seed,
109111
)
110112
# Add rescaling/resharding

0 commit comments

Comments
 (0)