@@ -17,8 +17,8 @@ def __init__(
1717 tokenizer : PreTrainedTokenizerBase ,
1818 jinja_template : str ,
1919 num_shots : int = 0 ,
20- max_source_content_length : Optional [int ] = None ,
21- max_target_content_length : Optional [int ] = None ,
20+ max_source_length_per_shot : Optional [int ] = None ,
21+ max_target_length_per_shot : Optional [int ] = None ,
2222 truncation_strategy : Literal ["longest" , "uniform" ] = "longest" ,
2323 use_words : bool = True ,
2424 source_fields : Optional [Sequence [str ]] = None ,
@@ -35,14 +35,15 @@ def __init__(
3535 the source and target; we use promptsource to parse the
3636 template and extract the source and target fields; please
3737 see the promptsource documentation for more details.
38- max_source_content_length (Optional[int], optional): the maximum
39- length of the source content (i.e., the content that is given
40- as input to the model) . If not provided, no truncation will
41- be performed. Defaults to None.
38+ max_source_length_per_shot (Optional[int], optional): the maximum
39+ length of all the fields that are part of the source in a
40+ prompting shot. If not provided, no truncation will be
41+ performed. Defaults to None.
42- max_target_content_length (Optional[int], optional): the maximum
42+ max_target_length_per_shot (Optional[int], optional): the maximum
43- length of the target content (i.e., the content that is
44- expected as output from the model). If not provided, no
45- truncation will be performed. Defaults to None.
43+ length of all the fields that are part of the target in a
44+ prompting shot (that is, the text the model is asked to
45+ generate). If not provided, no truncation will be performed.
46+ Defaults to None.
4647 truncation_strategy ("longest" or "uniform", optional): how to
4748 perform truncation if the source or target content is longer
4849 than the maximum length. If "longest", the longest fields
@@ -124,17 +125,42 @@ def __init__(
124125 # if we don't use words, we just use the length of the prompt
125126 # in characters.
126127 length_src_prompt = len (source_text )
127- length_tgt_prompt = len (target_text )
128+ # for target, we actually take the max in case there are multiple,
129+ # and 0 if there are none.
130+ length_tgt_prompt = max ([len (t ) for t in target_text ] or [0 ])
131+
132+ # local helper that rounds up to the nearest integer for
133+ # non-negative floats; avoids importing math.ceil
134+ def ceil (x ):
135+ return int (x + (1 if x % 1 else 0 ))
128135
129- if max_source_content_length is not None :
136+ if max_source_length_per_shot is not None :
130137 # in case a max length for the source is provided, we need to
131- # truncate; first, we decrease the max length by the length of
132- # prompt text.
133- max_source_content_length -= length_src_prompt
138+ # truncate. The total max_length for source data in each shot
139+ # needs to be reduced by (a) the length of the target prompt
140+ # text when doing few-shot, and (b) the length of the fixed
141+ # prompt text.
142+ #
143+ # For both (a) and (b), we need to distribute the length by
144+ # the number of shots:
145+ # (a): recall that each prompt will contain n shots + the
146+ # prompt for the sequence we care about. So when doing
147+ # n shot, we are adding n target sequences, but are
148+ # truncating n + 1 target sequences. Therefore, we multiply
149+ # target length by n but divide by (n + 1)
150+ # (b): the text that is part of the prompt but is not variables
151+ # (e.g., instructions) must be divided over n + 1 sources.
152+ actual_source_context_length = (
153+ max_source_length_per_shot
154+ - ceil (
155+ (max_target_length_per_shot or 0 )
156+ * (num_shots / (num_shots + 1 ))
157+ )
158+ - ceil (length_src_prompt / (num_shots + 1 ))
159+ )
134160
135161 # we raise if the max length is less than one after accounting
136162 # for the length of the prompt text.
137- if max_source_content_length < 1 :
163+ if actual_source_context_length < 1 :
138164 raise ValueError (
139165 f"max_source_content_length must be at least equal to "
140166 f"the length of the source prompt ({ length_src_prompt } )!"
@@ -144,17 +170,17 @@ def __init__(
144170 self .chain (
145171 TruncateMultipleFieldsMapper (
146172 fields_to_truncate = source_fields ,
147- max_length = max_source_content_length ,
173+ max_length = actual_source_context_length ,
148174 strategy = truncation_strategy ,
149175 )
150176 )
151177
152- if len (target_text ) > 0 and max_target_content_length :
178+ if len (target_text ) > 0 and max_target_length_per_shot :
153179 # we operate here in the same way as for the source, but we
154180 # only do it if there is a target prompt.
155- max_target_content_length -= length_tgt_prompt
181+ max_target_length_per_shot -= length_tgt_prompt
156182
157- if max_target_content_length < 1 :
183+ if max_target_length_per_shot < 1 :
158184 raise ValueError (
159185 f"max_target_content_length must be at least equal to "
160186 f"the length of the target prompt ({ length_tgt_prompt } )!"
@@ -163,7 +189,7 @@ def __init__(
163189 self .chain (
164190 TruncateMultipleFieldsMapper (
165191 fields_to_truncate = target_fields ,
166- max_length = max_target_content_length ,
192+ max_length = max_target_length_per_shot ,
167193 strategy = truncation_strategy ,
168194 )
169195 )
0 commit comments