Commit 05030cc

Add FIM training (#139)
* Add fim
* Blacking
* Corrected fim/clm combo
* reblacking
* Rereblacking

Signed-off-by: Davis Wertheimer <[email protected]>
1 parent 570d557 commit 05030cc

File tree (6 files changed, +169 −18 lines):

fms_fsdp/config/training.py
fms_fsdp/utils/dataloader_utils.py
fms_fsdp/utils/dataset_utils.py
main_training_llama.py
main_training_mamba.py
tests/test_datasets.py

fms_fsdp/config/training.py

Lines changed: 8 additions & 0 deletions
@@ -72,3 +72,11 @@ class train_config:
     stage2_prompt_length: int = 64
     stage2_batch_size: int = 96
     stage2_seq_length: int = 256
+
+    # FIM training
+    fim_training: bool = False
+    psm_rate: float = 0.0
+    spm_rate: float = 0.0
+    fim_pre: int = 1
+    fim_mid: int = 2
+    fim_suf: int = 3
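The new fields slot into train_config alongside the existing options. A minimal sketch of enabling FIM, assuming train_config is importable from fms_fsdp.config and can be constructed directly (in practice it is usually populated from CLI kwargs); note that get_data_loader additionally requires bos_token to be None when fim_training is on:

from fms_fsdp.config import train_config  # assumed import path

cfg = train_config(
    fim_training=True,
    psm_rate=0.5,   # half of eligible documents rearranged in PSM mode
    spm_rate=0.25,  # a quarter in SPM mode; the two rates must sum to at most 1
    fim_pre=1,      # <PRE> sentinel token id (example value)
    fim_mid=2,      # <MID> sentinel token id (example value)
    fim_suf=3,      # <SUF> sentinel token id (example value)
)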

fms_fsdp/utils/dataloader_utils.py

Lines changed: 23 additions & 9 deletions
@@ -5,6 +5,7 @@
     AutoHandler,
     BufferDataset,
     CheckpointDataset,
+    FIMDataset,
     ParquetHandler,
     PreloadBufferDataset,
     PreprocessDataset,
@@ -57,9 +58,9 @@ def __iter__(self):
         return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)
 
 
-def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
+def get_data_loader(cfg, rank, world_size):
     """
-    Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
+    Pytorch dataloader for stateful, distributed, and rescalable language model training.
     Assumes underlying data is sequences of integer values.
     ...
     Args
@@ -70,10 +71,9 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         Rank of current distributed worker. Used for handling dataset sharding logic.
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
-    postprocess : List[Callable]
-        Any task-specific postprocessing to apply before handing over data. Steps will apply in
-        the order provided by the user. For CLM training, use postprocess=[causal_lm].
     """
+    if cfg.fim_training:
+        assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
 
     datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)
 
@@ -118,20 +118,34 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         verbose=(rank == 0),
     )
     # Wrap above dataset in packing logic to form constant-length lines.
+    # Increment seq len to counteract CLM's one token removal.
     data = BufferDataset(
         data,
-        cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
+        cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
     )
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
     data = PreloadBufferDataset(data, 10000)
 
-    # Apply desired postprocessing steps in sequence
+    # Apply FIM transformation if needed
+    if cfg.fim_training:
+        data = FIMDataset(
+            data,
+            cfg.eos_token,
+            cfg.psm_rate,
+            cfg.spm_rate,
+            pre_token=cfg.fim_pre,
+            mid_token=cfg.fim_mid,
+            suf_token=cfg.fim_suf,
+        )
+
+    # Transform to tensors
     data = PreprocessDataset(data, torch.IntTensor)
-    for p in postprocess:
-        data = PreprocessDataset(data, p)
+
+    # Apply CLM transformation
+    data = PreprocessDataset(data, causal_lm)
 
     # Enable auto-saving
     data = CheckpointDataset(
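The switch to a fixed cfg.seq_length + 1 reflects that the pipeline now always ends with the causal_lm step, which consumes one token when splitting each packed line into an input/label pair. A rough sketch of that idea, using a hypothetical causal_lm_sketch stand-in (the real causal_lm lives in this module and may differ in details):

import torch

def causal_lm_sketch(seq: torch.Tensor):
    # A length-(L+1) packed line yields a length-L input and a
    # length-L next-token label, hence packing seq_length + 1 tokens.
    return seq[:-1], seq[1:]

line = torch.arange(9)  # pretend cfg.seq_length == 8, so 9 packed tokens
inputs, labels = causal_lm_sketch(line)
assert inputs.shape[0] == labels.shape[0] == 8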

fms_fsdp/utils/dataset_utils.py

Lines changed: 124 additions & 3 deletions
@@ -712,6 +712,128 @@ def load_state_dict(self, state_dicts, sharded_input=False):
         return sharded_dicts
 
 
+class FIMDataset(_WrapperDataset):
+    """
+    Wrapper for a _StatefulDataset that implements Fill-In-the-Middle training
+    (https://arxiv.org/pdf/2207.14255).
+    Input should be a packed sequence (i.e. call BufferDataset before FIMDataset).
+    Breaks sequence apart into component document spans, and for each document span
+    of sufficient length, transforms with specified probability into:
+    PSM mode: <PRE> (prefix) <SUF> (suffix) <MID> (middle) <EOS>
+    SPM mode: <PRE> <SUF> (suffix) <MID> (prefix) (middle) <EOS>
+    The new delimiter tokens can be omitted by passing in None.
+    Any extra tokens after transformation are dropped from the end of the sequence.
+    ...
+    Args
+    ----
+    dataset : _StatefulDataset
+        Fully instantiated dataset
+    delimiter_token : any
+        Token used to indicate document boundaries
+    psm_rate : float
+        Chance to transform into PSM. Cannot exceed 1.
+    spm_rate : float
+        Chance to transform into SPM. Cannot exceed 1.
+    min_len : int
+        Minimum document length to perform FIM transformation
+    pre_token : any | None
+        Token used to indicate prefix section of the document
+    mid_token : any | None
+        Token used to indicate middle infill section of the document
+    suf_token : any | None
+        Token used to indicate suffix section of the document
+    """
+
+    def __init__(
+        self,
+        dataset: _StatefulDataset,
+        delimiter_token: Any,
+        psm_rate: float = 0.0,
+        spm_rate: float = 0.0,
+        min_len: int = 10,
+        pre_token=None,
+        mid_token=None,
+        suf_token=None,
+    ):
+        super().__init__(dataset)
+        assert (
+            psm_rate + spm_rate > 0
+        ), "FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
+        assert (
+            psm_rate + spm_rate <= 1
+        ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
+        self.psm = psm_rate
+        self.spm = spm_rate
+        self.delimiter = delimiter_token
+        self.min_len = min_len
+        self.pref = pre_token
+        self.suff = suf_token
+        self.midd = mid_token
+
+        self.g_state = None
+        self.generator = torch.Generator().manual_seed(self.rank)
+        self.state_params = ["g_state"]
+
+    def __iter__(self):
+        dataset = iter(self.dataset)
+        while True:
+            inp = next(dataset)
+            len_ = len(inp)
+            i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
+            docs = [
+                inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
+            ]  # list[list[any]]
+            out = []
+            for i in range(len(docs)):
+                doc = docs[i]
+                if len(docs[i]) >= self.min_len:
+                    # Decide psm, spm, or nothing
+                    thresh = torch.rand([1], generator=self.generator).item()
+                    if thresh < self.psm + self.spm:
+                        # Split doc
+                        doc = []
+                        if self.pref:
+                            doc = [self.pref]
+                        splits = torch.randint(
+                            0, len(docs[i]), [2], generator=self.generator
+                        ).tolist()
+                        pre = docs[i][: min(splits)]
+                        mid = docs[i][min(splits) : max(splits)]
+                        suf = docs[i][max(splits) :]
+
+                        if thresh < self.psm:
+                            # PSM transformation
+                            doc += pre
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += mid
+                        else:
+                            # SPM transformation
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += pre + mid
+                out += doc + [self.delimiter]
+            yield out[:len_]
+
+    def state_dict(self):
+        # Write generator state manually
+        self.g_state = self.generator.get_state()
+        return super().state_dict()
+
+    def load_state_dict(self, state_dicts, sharded_input=False):
+        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
+        # Manually set generator state if it exists
+        if self.g_state is not None:
+            self.generator.set_state(self.g_state)
+        return sharded_dicts
+
+
 class BufferDataset(_WrapperDataset):
     """
     Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them
@@ -888,9 +1010,8 @@ def __init__(
         self.bos = bos_token
         self.drop = strip_tokens
         self.verbose = verbose
-        self.docset: List[
-            Any
-        ] = []  # map of doc indices to (shardid, min docid, max docid)
+        # Map of doc indices to (shardid, min docid, max docid)
+        self.docset: List[Any] = []
 
         # Position
         self.docset_index = 0
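To make the PSM and SPM layouts above concrete, here is an illustrative walk-through of one document span's rearrangement, using made-up token ids and assuming the two random split points land at 4 and 8 (pre = doc[:4], mid = doc[4:8], suf = doc[8:], matching the slicing in __iter__). Because each transformed document gains up to four sentinel tokens, the packed line is trimmed back to its original length at the end (out[:len_]).

doc = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]  # one document span
pre, mid, suf = doc[:4], doc[4:8], doc[8:]
PRE, MID, SUF, EOS = 1, 2, 3, 0  # illustrative sentinel / delimiter ids

# PSM mode: <PRE> prefix <SUF> suffix <MID> middle <EOS>
psm = [PRE] + pre + [SUF] + suf + [MID] + mid + [EOS]
# SPM mode: <PRE> <SUF> suffix <MID> prefix middle <EOS>
spm = [PRE, SUF] + suf + [MID] + pre + mid + [EOS]

print(psm)  # [1, 10, 11, 12, 13, 3, 18, 19, 20, 21, 2, 14, 15, 16, 17, 0]
print(spm)  # [1, 3, 18, 19, 20, 21, 2, 10, 11, 12, 13, 14, 15, 16, 17, 0]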

main_training_llama.py

Lines changed: 5 additions & 3 deletions
@@ -119,9 +119,11 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
-        if not os.path.isfile(cfg.ckpt_load_path)
-        else cfg.ckpt_load_path,
+        path=(
+            os.path.join(cfg.ckpt_load_path, "checkpoints/")
+            if not os.path.isfile(cfg.ckpt_load_path)
+            else cfg.ckpt_load_path
+        ),
         strict=False,
     )
     if not is_resuming:

main_training_mamba.py

Lines changed: 5 additions & 3 deletions
@@ -142,9 +142,11 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
-        if not os.path.isfile(cfg.ckpt_load_path)
-        else cfg.ckpt_load_path,
+        path=(
+            os.path.join(cfg.ckpt_load_path, "checkpoints/")
+            if not os.path.isfile(cfg.ckpt_load_path)
+            else cfg.ckpt_load_path
+        ),
         strict=False,
     )
     if not is_resuming:

tests/test_datasets.py

Lines changed: 4 additions & 0 deletions
@@ -632,6 +632,10 @@ def test_multi_reload_stress():
     # preload / sample / scale / doc pipeline
     multi_reload_stress_check(lambda: d6(d5(d4())))
 
+    # Add FIM dataset
+    d7 = lambda x: [FIMDataset(d, -1, 0.25, 0.25, 10, -2, -3, -4) for d in x]
+    multi_reload_stress_check(lambda: d7(d6(d5(d4()))))
+
 
 # SCALABLEDATASET TESTS
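For reference, the positional arguments in the new d7 pipeline map onto FIMDataset's signature as dataset, delimiter_token=-1, psm_rate=0.25, spm_rate=0.25, min_len=10, pre_token=-2, mid_token=-3, suf_token=-4. A small self-contained sketch of the resulting per-document decision logic under those rates (re-created here for illustration; the real draw happens inside FIMDataset.__iter__ with the worker's own seeded generator):

import torch

psm_rate, spm_rate = 0.25, 0.25
g = torch.Generator().manual_seed(0)
counts = {"psm": 0, "spm": 0, "unchanged": 0}
for _ in range(10_000):
    thresh = torch.rand([1], generator=g).item()
    if thresh < psm_rate:
        counts["psm"] += 1        # rearranged as <PRE> pre <SUF> suf <MID> mid
    elif thresh < psm_rate + spm_rate:
        counts["spm"] += 1        # rearranged as <PRE> <SUF> suf <MID> pre mid
    else:
        counts["unchanged"] += 1  # document passes through untouched
print(counts)  # roughly 2500 / 2500 / 5000 under these rates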