@@ -712,6 +712,128 @@ def load_state_dict(self, state_dicts, sharded_input=False):
         return sharded_dicts


+class FIMDataset(_WrapperDataset):
+    """
+    Wrapper for a _StatefulDataset that implements Fill-In-the-Middle training
+    (https://arxiv.org/pdf/2207.14255).
+    Input should be a packed sequence (i.e. call BufferDataset before FIMDataset).
+    Breaks the sequence apart into its component document spans and, for each span of
+    sufficient length, transforms it with the specified probability into:
+    PSM mode: <PRE> (prefix) <SUF> (suffix) <MID> (middle) <EOS>
+    SPM mode: <PRE> <SUF> (suffix) <MID> (prefix) (middle) <EOS>
+    The new delimiter tokens can be omitted by passing in None.
+    Any extra tokens after transformation are dropped from the end of the sequence.
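+    For illustration, suppose a document span is the token list [d1, d2, d3, d4, d5] and the
+    two random cut points fall after d2 and d4, giving prefix=[d1, d2], middle=[d3, d4],
+    suffix=[d5]. The emitted spans would then be (token names here are purely illustrative):
+    PSM: <PRE> d1 d2 <SUF> d5 <MID> d3 d4 <EOS>
+    SPM: <PRE> <SUF> d5 <MID> d1 d2 d3 d4 <EOS>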
+    ...
+    Args
+    ----
+    dataset : _StatefulDataset
+        Fully instantiated dataset
+    delimiter_token : any
+        Token used to indicate document boundaries
+    psm_rate : float
+        Chance to transform into PSM. Cannot exceed 1.
+    spm_rate : float
+        Chance to transform into SPM. Cannot exceed 1.
+    min_len : int
+        Minimum document length to perform FIM transformation
+    pre_token : any | None
+        Token used to indicate the prefix section of the document
+    mid_token : any | None
+        Token used to indicate the middle infill section of the document
+    suf_token : any | None
+        Token used to indicate the suffix section of the document
+    """
+
+    def __init__(
+        self,
+        dataset: _StatefulDataset,
+        delimiter_token: Any,
+        psm_rate: float = 0.0,
+        spm_rate: float = 0.0,
+        min_len: int = 10,
+        pre_token=None,
+        mid_token=None,
+        suf_token=None,
+    ):
+        super().__init__(dataset)
+        assert (
+            psm_rate + spm_rate > 0
+        ), "FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
+        assert (
+            psm_rate + spm_rate <= 1
+        ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
+        self.psm = psm_rate
+        self.spm = spm_rate
+        self.delimiter = delimiter_token
+        self.min_len = min_len
+        self.pref = pre_token
+        self.suff = suf_token
+        self.midd = mid_token
+
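+        # Per-rank RNG; its state is checkpointed through g_state via state_dict/load_state_dict below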
+        self.g_state = None
+        self.generator = torch.Generator().manual_seed(self.rank)
+        self.state_params = ["g_state"]
+
+    def __iter__(self):
+        dataset = iter(self.dataset)
+        while True:
+            inp = next(dataset)
+            len_ = len(inp)
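+            # Locate delimiter positions and split the packed sequence into per-document spans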
+            i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
+            docs = [
+                inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
+            ]  # list[list[any]]
+            out = []
+            for i in range(len(docs)):
+                doc = docs[i]
+                if len(docs[i]) >= self.min_len:
+                    # decide psm, spm, or nothing
+                    thresh = torch.rand([1], generator=self.generator).item()
+                    if thresh < self.psm + self.spm:
+                        # Split doc
+                        doc = []
+                        if self.pref:
+                            doc = [self.pref]
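+                        # Two random cut points partition the doc into prefix / middle / suffix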
+                        splits = torch.randint(
+                            0, len(docs[i]), [2], generator=self.generator
+                        ).tolist()
+                        pre = docs[i][: min(splits)]
+                        mid = docs[i][min(splits) : max(splits)]
+                        suf = docs[i][max(splits) :]
+
+                        if thresh < self.psm:
+                            # PSM transformation
+                            doc += pre
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += mid
+                        else:
+                            # SPM transformation
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += pre + mid
+                out += doc + [self.delimiter]
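+            # FIM insertions can lengthen the output; drop extra tokens to restore the original length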
+            yield out[:len_]
+
+    def state_dict(self):
+        # Write generator state manually
+        self.g_state = self.generator.get_state()
+        return super().state_dict()
+
+    def load_state_dict(self, state_dicts, sharded_input=False):
+        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
+        # Manually set generator state if it exists
+        if self.g_state is not None:
+            self.generator.set_state(self.g_state)
+        return sharded_dicts
+
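+# Example wiring (illustrative sketch): FIMDataset expects an upstream _StatefulDataset that
+# already yields packed sequences (e.g. the output of BufferDataset). The names and token ids
+# used here are placeholders, not real special-token ids.
+#
+#   fim_data = FIMDataset(
+#       packed_data,
+#       delimiter_token=0,
+#       psm_rate=0.5,
+#       spm_rate=0.25,
+#       min_len=10,
+#       pre_token=1,
+#       mid_token=2,
+#       suf_token=3,
+#   )
+#   for seq in iter(fim_data):
+#       ...  # each seq keeps the original packed length after transformation
+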
+
 class BufferDataset(_WrapperDataset):
     """
     Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them
@@ -890,9 +1012,8 @@ def __init__(
         self.drop = strip_tokens
         self.max_consec = max_consecutive_chunks
         self.verbose = verbose
-        self.docset: List[
-            Any
-        ] = []  # map of doc indices to (shardid, min docid, max docid)
+        # Map of doc indices to (shardid, min docid, max docid)
+        self.docset: List[Any] = []

         # Position
         self.docset_index = 0