@@ -53,7 +53,42 @@ class ZeroPadding:
53
53
]
54
54
55
55
@classmethod
56
- def _pad_batch_records (cls , batch_records ):
56
+ def _pad_batch_records_to_max_length (cls , batch_records , max_length , pad_token = 0 ):
57
+ # confirm the at least one item in the pack
58
+ if len (batch_records ) == 0 :
59
+ return batch_records
60
+ # count all records total length
61
+ total_length = sum ([len (record ["input_ids" ]) for record in batch_records ])
62
+ reserved_length = max_length - total_length
63
+
64
+ # append padding to the max_length
65
+ if "attn_mask_startend_row_indices" in batch_records [0 ]:
66
+ # attn_mask_startend_row_indices is a list of row indices `0`,
67
+ # which indicates that all tokens are masked.
68
+ batch_records .append (
69
+ {
70
+ "input_ids" : [pad_token ] * reserved_length ,
71
+ "labels" : [- 100 ] * reserved_length ,
72
+ "attn_mask_startend_row_indices" : [0 ] * reserved_length ,
73
+ }
74
+ )
75
+ elif "attention_mask" in batch_records [0 ]:
76
+ # attention_mask is a fullly masked attention matrix (all False)
77
+ # which indicates that all tokens are masked.
78
+ batch_records .append (
79
+ {
80
+ "input_ids" : [pad_token ] * reserved_length ,
81
+ "labels" : [- 100 ] * reserved_length ,
82
+ "attention_mask" : np .zeros ((reserved_length , reserved_length ), dtype = bool ),
83
+ }
84
+ )
85
+
86
+ return batch_records
87
+
88
+ @classmethod
89
+ def _pad_batch_records (cls , batch_records , max_length ):
90
+ batch_records = cls ._pad_batch_records_to_max_length (batch_records , max_length )
91
+
57
92
# Only consider supported input keys
58
93
input_keys = [key for key in batch_records [0 ].keys () if key in cls .supported_input_keys ]
59
94
if "attn_mask_startend_row_indices" not in input_keys and "attention_mask" not in input_keys :
@@ -122,7 +157,7 @@ def _create_zero_padding_data(self, data):
122
157
cur_len_so_far += len (record ["input_ids" ])
123
158
else :
124
159
# exceed max length
125
- padded_list = self ._pad_batch_records (batch_records )
160
+ padded_list = self ._pad_batch_records (batch_records , self . max_length )
126
161
total_data .append (padded_list )
127
162
# reset
128
163
batch_records = []
@@ -133,7 +168,7 @@ def _create_zero_padding_data(self, data):
133
168
134
169
# remaining data
135
170
if batch_records :
136
- padded_list = self ._pad_batch_records (batch_records )
171
+ padded_list = self ._pad_batch_records (batch_records , self . max_length )
137
172
total_data .append (padded_list )
138
173
else :
139
174
examples = []
@@ -150,15 +185,15 @@ def _create_zero_padding_data(self, data):
150
185
generate_packs = generate_greedy_packs (examples , self .max_length )
151
186
for batch_records in generate_packs :
152
187
if len (batch_records ) > 0 :
153
- padded_list = self ._pad_batch_records (batch_records )
188
+ padded_list = self ._pad_batch_records (batch_records , self . max_length )
154
189
total_data .append (padded_list )
155
190
examples = [record ]
156
191
i = 1
157
192
if len (examples ) > 0 :
158
193
generate_packs = generate_greedy_packs (examples , self .max_length )
159
194
for batch_records in generate_packs :
160
195
if len (batch_records ) > 0 :
161
- padded_list = self ._pad_batch_records (batch_records )
196
+ padded_list = self ._pad_batch_records (batch_records , self . max_length )
162
197
total_data .append (padded_list )
163
198
164
199
return total_data
@@ -190,7 +225,7 @@ def __iter__(self):
190
225
cur_len_so_far += len (record ["input_ids" ])
191
226
else :
192
227
# exceed max length
193
- padded_list = self ._pad_batch_records (batch_records )
228
+ padded_list = self ._pad_batch_records (batch_records , self . max_length )
194
229
yield padded_list
195
230
# reset
196
231
batch_records = []
@@ -200,7 +235,7 @@ def __iter__(self):
200
235
self .zero_padding_global_step += 1
201
236
cur_len_so_far += len (record ["input_ids" ])
202
237
if batch_records :
203
- padded_list = self ._pad_batch_records (batch_records )
238
+ padded_list = self ._pad_batch_records (batch_records , self . max_length )
204
239
yield padded_list
205
240
else :
206
241
examples = []
@@ -218,7 +253,7 @@ def __iter__(self):
218
253
generate_packs = generate_greedy_packs (examples , self .max_length )
219
254
for batch_records in generate_packs :
220
255
if len (batch_records ) > 0 :
221
- padded_list = self ._pad_batch_records (batch_records )
256
+ padded_list = self ._pad_batch_records (batch_records , self . max_length )
222
257
yield padded_list
223
258
examples = [record ]
224
259
self .zero_padding_global_step += 1
@@ -227,5 +262,5 @@ def __iter__(self):
227
262
generate_packs = generate_greedy_packs (examples , self .max_length )
228
263
for batch_records in generate_packs :
229
264
if len (batch_records ) > 0 :
230
- padded_list = self ._pad_batch_records (batch_records )
265
+ padded_list = self ._pad_batch_records (batch_records , self . max_length )
231
266
yield padded_list
0 commit comments