Skip to content

Commit 30a2ac6

Browse files
authored
[ZeroPadding] padding to max_length for sequence parallel (#8973)
* fix zero_padding for sequence parallel
1 parent d505a97 commit 30a2ac6

File tree

1 file changed

+44
-9
lines changed

1 file changed

+44
-9
lines changed

paddlenlp/datasets/zero_padding_dataset.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,42 @@ class ZeroPadding:
5353
]
5454

5555
@classmethod
56-
def _pad_batch_records(cls, batch_records):
56+
def _pad_batch_records_to_max_length(cls, batch_records, max_length, pad_token=0):
    """Append one fully-masked padding record so the pack totals ``max_length`` tokens.

    The padding record carries ``pad_token`` as ``input_ids``, ``-100`` as
    ``labels``, and a mask in the same format as the first record of the pack
    (``attn_mask_startend_row_indices`` or ``attention_mask``).

    Args:
        batch_records: List of record dicts forming one pack; each record must
            have an ``"input_ids"`` sequence. The list is extended in place.
        max_length: Target total number of tokens for the pack.
        pad_token: Token id used to fill the padded ``input_ids``.

    Returns:
        The same ``batch_records`` list. It is returned unchanged when it is
        empty, when it already holds ``max_length`` or more tokens, or when
        the first record has neither supported mask key.
    """
    # Make sure there is at least one item in the pack.
    if not batch_records:
        return batch_records

    # Count the total length of all records.
    total_length = sum(len(record["input_ids"]) for record in batch_records)
    reserved_length = max_length - total_length

    # A full (or over-long) pack needs no padding. Without this guard a
    # zero-length record would be appended, and a negative length would make
    # np.zeros raise ValueError in the attention_mask branch.
    if reserved_length <= 0:
        return batch_records

    first_record = batch_records[0]
    if "attn_mask_startend_row_indices" in first_record:
        # attn_mask_startend_row_indices is a list of row indices `0`,
        # which indicates that all tokens are masked.
        batch_records.append(
            {
                "input_ids": [pad_token] * reserved_length,
                "labels": [-100] * reserved_length,
                "attn_mask_startend_row_indices": [0] * reserved_length,
            }
        )
    elif "attention_mask" in first_record:
        # attention_mask is a fully masked attention matrix (all False),
        # which indicates that all tokens are masked.
        batch_records.append(
            {
                "input_ids": [pad_token] * reserved_length,
                "labels": [-100] * reserved_length,
                "attention_mask": np.zeros((reserved_length, reserved_length), dtype=bool),
            }
        )

    return batch_records
87+
88+
@classmethod
89+
def _pad_batch_records(cls, batch_records, max_length):
90+
batch_records = cls._pad_batch_records_to_max_length(batch_records, max_length)
91+
5792
# Only consider supported input keys
5893
input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
5994
if "attn_mask_startend_row_indices" not in input_keys and "attention_mask" not in input_keys:
@@ -122,7 +157,7 @@ def _create_zero_padding_data(self, data):
122157
cur_len_so_far += len(record["input_ids"])
123158
else:
124159
# exceed max length
125-
padded_list = self._pad_batch_records(batch_records)
160+
padded_list = self._pad_batch_records(batch_records, self.max_length)
126161
total_data.append(padded_list)
127162
# reset
128163
batch_records = []
@@ -133,7 +168,7 @@ def _create_zero_padding_data(self, data):
133168

134169
# remaining data
135170
if batch_records:
136-
padded_list = self._pad_batch_records(batch_records)
171+
padded_list = self._pad_batch_records(batch_records, self.max_length)
137172
total_data.append(padded_list)
138173
else:
139174
examples = []
@@ -150,15 +185,15 @@ def _create_zero_padding_data(self, data):
150185
generate_packs = generate_greedy_packs(examples, self.max_length)
151186
for batch_records in generate_packs:
152187
if len(batch_records) > 0:
153-
padded_list = self._pad_batch_records(batch_records)
188+
padded_list = self._pad_batch_records(batch_records, self.max_length)
154189
total_data.append(padded_list)
155190
examples = [record]
156191
i = 1
157192
if len(examples) > 0:
158193
generate_packs = generate_greedy_packs(examples, self.max_length)
159194
for batch_records in generate_packs:
160195
if len(batch_records) > 0:
161-
padded_list = self._pad_batch_records(batch_records)
196+
padded_list = self._pad_batch_records(batch_records, self.max_length)
162197
total_data.append(padded_list)
163198

164199
return total_data
@@ -190,7 +225,7 @@ def __iter__(self):
190225
cur_len_so_far += len(record["input_ids"])
191226
else:
192227
# exceed max length
193-
padded_list = self._pad_batch_records(batch_records)
228+
padded_list = self._pad_batch_records(batch_records, self.max_length)
194229
yield padded_list
195230
# reset
196231
batch_records = []
@@ -200,7 +235,7 @@ def __iter__(self):
200235
self.zero_padding_global_step += 1
201236
cur_len_so_far += len(record["input_ids"])
202237
if batch_records:
203-
padded_list = self._pad_batch_records(batch_records)
238+
padded_list = self._pad_batch_records(batch_records, self.max_length)
204239
yield padded_list
205240
else:
206241
examples = []
@@ -218,7 +253,7 @@ def __iter__(self):
218253
generate_packs = generate_greedy_packs(examples, self.max_length)
219254
for batch_records in generate_packs:
220255
if len(batch_records) > 0:
221-
padded_list = self._pad_batch_records(batch_records)
256+
padded_list = self._pad_batch_records(batch_records, self.max_length)
222257
yield padded_list
223258
examples = [record]
224259
self.zero_padding_global_step += 1
@@ -227,5 +262,5 @@ def __iter__(self):
227262
generate_packs = generate_greedy_packs(examples, self.max_length)
228263
for batch_records in generate_packs:
229264
if len(batch_records) > 0:
230-
padded_list = self._pad_batch_records(batch_records)
265+
padded_list = self._pad_batch_records(batch_records, self.max_length)
231266
yield padded_list

0 commit comments

Comments
 (0)