Skip to content

Commit 922d373

Browse files
committed
Make doc slicing flaggable
1 parent abd5ef1 commit 922d373

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

fms_fsdp/config/training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class train_config:
     doc_breakpoint: int = 65_536
     filter_exp: int = 2
     target_doclen: int = 8192
+    slice_rate: float = 0.0
 
     # FIM training
     psm_rate: float = 0.0

fms_fsdp/utils/dataloader_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,12 @@ def get_data_loader(cfg, rank, world_size, dp_degree):
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
     data = PreloadBufferDataset(data, 1000)
     # Slice and rearrange docs to force long-context retrieval
-    data = DocSliceDataset(
-        data,
-        cfg.eos_token,
-        slice_rate=.75,
-    )
+    if cfg.slice_rate > 0:
+        data = DocSliceDataset(
+            data,
+            cfg.eos_token,
+            slice_rate=cfg.slice_rate,
+        )
     # Apply FIM transformation if needed
     if fim_training:
         data = FIMDataset(

0 commit comments

Comments
 (0)