17
17
from scipy .linalg import block_diag
18
18
19
19
20
def generate_greedy_packs(examples, max_length):
    """Greedily group ``examples`` into packs of at most ``max_length`` tokens.

    Records are visited in order. Each record goes into the open pack that has
    the most room left (the lowest-indexed one on ties). If it does not fit
    there it cannot fit into any open pack, so a new pack is opened for it.

    Args:
        examples: Sequence of records; the size of a record is
            ``len(record["input_ids"])``.
        max_length: Capacity of every pack, in the same unit.

    Returns:
        A list with ``len(examples)`` slots. The opened packs come first and
        hold their records in insertion order; unused slots are empty lists.
        An empty ``examples`` yields an empty list.

    Raises:
        ValueError: If a record is longer than ``max_length`` and therefore
            cannot be placed into any pack.
    """
    num_examples = len(examples)
    packs = [[] for _ in range(num_examples)]
    if num_examples == 0:
        return packs

    # Remaining room per pack slot; -1 marks a slot that is not opened yet.
    left_len = np.zeros([num_examples]) - 1
    left_len[0] = max_length  # At the beginning, only the first pack is valid.
    last_opened = 0

    for record in examples:
        record_len = len(record["input_ids"])
        if record_len > max_length:
            raise ValueError(
                f"Record with {record_len} tokens does not fit into a pack of max_length={max_length}."
            )

        # argmax returns the first maximum, i.e. the lowest-indexed roomiest pack.
        target = left_len.argmax()
        if record_len > left_len[target]:
            # Even the roomiest open pack is too small: open the next slot.
            # At most one pack is opened per record, so this index stays in range.
            last_opened += 1
            left_len[last_opened] = max_length
            target = last_opened

        packs[target].append(record)
        left_len[target] -= record_len

    return packs
39
+
40
+
20
41
class ZeroPadding :
21
42
required_output_keys = ["input_ids" , "labels" , "attention_mask" ]
22
43
# Only the following keys are supported for ZeroPadding. Keys outside of this set will be ignored.
@@ -80,38 +101,66 @@ def _pad_batch_records(cls, batch_records):
80
101
81
102
82
103
class ZeroPaddingMapDataset (ZeroPadding , Dataset ):
83
def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
    """Store the packing configuration and eagerly pack ``data``.

    Args:
        data: Source records handed to ``_create_zero_padding_data``.
        tokenizer: Stored on the instance as ``self.tokenizer``.
        max_length: Upper bound on the summed ``input_ids`` length of a pack.
        greedy_zero_padding: When true, packs are built with the greedy
            strategy instead of the plain sequential one.
    """
    self.greedy_zero_padding = greedy_zero_padding
    self.max_length = max_length
    self.tokenizer = tokenizer
    # Must run last: packing reads max_length and greedy_zero_padding.
    self.new_data = self._create_zero_padding_data(data)
87
109
88
110
def _create_zero_padding_data(self, data):
    """Pack all records of ``data`` and return the list of padded packs.

    Records whose ``input_ids`` are longer than ``self.max_length`` are
    dropped. Every pack is turned into one output item by
    ``self._pad_batch_records``.

    * Sequential mode (``greedy_zero_padding`` false): records are appended
      to the current pack until the next one would overflow ``max_length``;
      the pack is then emitted and a new one is started with that record.
    * Greedy mode: records are collected in windows of up to 500 and each
      window is packed with ``generate_greedy_packs``.
    """
    packed = []

    if self.greedy_zero_padding:
        window_size = 500
        window = []
        for record in data:
            if len(record["input_ids"]) > self.max_length:
                continue
            if len(window) >= window_size:
                # Window is full: pack it, then start a new one with this record.
                packed.extend(
                    self._pad_batch_records(pack)
                    for pack in generate_greedy_packs(window, self.max_length)
                    if len(pack) > 0
                )
                window = []
            window.append(record)
        # Pack whatever is left in the last, possibly partial, window.
        if len(window) > 0:
            packed.extend(
                self._pad_batch_records(pack)
                for pack in generate_greedy_packs(window, self.max_length)
                if len(pack) > 0
            )
        return packed

    pending, pending_len = [], 0
    for position in range(len(data)):
        record = data[position]
        record_len = len(record["input_ids"])
        if record_len > self.max_length:
            continue
        if pending_len + record_len > self.max_length:
            # Current pack would overflow: emit it and start over.
            packed.append(self._pad_batch_records(pending))
            pending, pending_len = [], 0
        pending.append(record)
        pending_len += record_len

    # Emit the trailing, not yet flushed pack.
    if pending:
        packed.append(self._pad_batch_records(pending))
    return packed
116
165
117
166
def __getitem__ (self , idx ):
@@ -122,34 +171,61 @@ def __len__(self):
122
171
123
172
124
173
class ZeroPaddingIterableDataset(ZeroPadding, IterableDataset):
    """Streaming dataset that packs records into items of at most ``max_length`` tokens.

    Records are read lazily from ``data``; each finished pack is converted by
    ``_pad_batch_records`` and yielded. Records whose ``input_ids`` are longer
    than ``max_length`` are skipped in both packing modes.
    """

    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
        """Store the stream and the packing configuration.

        Args:
            data: Iterable of records; each record exposes ``record["input_ids"]``.
            tokenizer: Stored on the instance as ``self.tokenizer``.
            max_length: Upper bound on the summed ``input_ids`` length of a pack.
            greedy_zero_padding: When true, use the windowed greedy strategy
                instead of the plain sequential one.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Number of records accepted into a pack so far (across iterations).
        self.zero_padding_global_step = 0
        self.greedy_zero_padding = greedy_zero_padding

    def __iter__(self):
        """Yield padded packs built from ``self.data``.

        ``zero_padding_global_step`` is incremented once for every record that
        is accepted into a pack; skipped (over-long) records are not counted.
        """
        if not self.greedy_zero_padding:
            batch_records = []
            cur_len_so_far = 0
            for record in self.data:
                record_len = len(record["input_ids"])
                # Same filter as the greedy branch: a record that can never fit
                # is dropped. Without it an over-long first record would flush
                # an empty batch and then be emitted as an over-length pack.
                if record_len > self.max_length:
                    continue
                if cur_len_so_far + record_len > self.max_length:
                    # Current pack would overflow: emit it and start over.
                    yield self._pad_batch_records(batch_records)
                    batch_records = []
                    cur_len_so_far = 0
                batch_records.append(record)
                self.zero_padding_global_step += 1
                cur_len_so_far += record_len
            # Emit the trailing, not yet flushed pack.
            if batch_records:
                yield self._pad_batch_records(batch_records)
        else:
            buffer_size = 500
            examples = []
            for record in self.data:
                if len(record["input_ids"]) > self.max_length:
                    continue
                if len(examples) >= buffer_size:
                    # Buffer is full: run the greedy strategy on it, then
                    # start a new buffer with the current record.
                    for pack in generate_greedy_packs(examples, self.max_length):
                        if pack:
                            yield self._pad_batch_records(pack)
                    examples = []
                examples.append(record)
                self.zero_padding_global_step += 1
            # Pack whatever is left in the last, possibly partial, buffer.
            if examples:
                for pack in generate_greedy_packs(examples, self.max_length):
                    if pack:
                        yield self._pad_batch_records(pack)
0 commit comments