
Commit 8599a53

Add greedy_zero_padding (#8933)

1 parent d2d4d92 commit 8599a53

File tree: 5 files changed, +144 −63 lines changed
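
This change set replaces the DPO-specific `greedy_intokens`/`buffer_size` options with a single `greedy_zero_padding` flag: it is declared in both `DPODataArgument` and the shared `DataArgument`, threaded from `run_dpo.py` and `run_finetune.py` into the zero-padding datasets, and implemented by a new module-level `generate_greedy_packs` helper in `paddlenlp/datasets/zero_padding_dataset.py`. With the flag off (the new default), sequences are packed sequentially as before: appended in order and flushed once the next one would overflow `max_length`. With it on, records are buffered 500 at a time and each sequence is placed into whichever open pack has the most space left. The diff also carries two incidental cleanups in `run_dpo.py` (merged `paddlenlp.transformers` imports, a dropped unused `import inspect`) and reflows the `flash_mask` field and warning.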

llm/alignment/dpo/dpo_argument.py

Lines changed: 4 additions & 7 deletions
```diff
@@ -63,11 +63,10 @@ class DPODataArgument:
         default=False,
         metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
     )
-    greedy_intokens: bool = field(
-        default=True,
-        metadata={"help": "Whether apply greedy intokens."},
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Greedy Zero Padding data stream."},
     )
-    buffer_size: int = field(default=500, metadata={"help": "Buffer size for greedy_intokens strategy."})


 @dataclass
@@ -87,9 +86,7 @@ class DPOModelArgument:
             "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
         },
     )
-    flash_mask: bool = field(
-        default=False, metadata={"help": "Whether to use flash mask in flash attention."}
-    )
+    flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
     virtual_pp_degree: int = field(
         default=1,
         metadata={"help": "virtual_pp_degree"},
```

llm/alignment/dpo/run_dpo.py

Lines changed: 9 additions & 9 deletions
```diff
@@ -17,7 +17,6 @@
 import os
 import sys
 import time
-import inspect
 from functools import partial

 import paddle
@@ -30,17 +29,19 @@
     get_last_checkpoint,
     set_seed,
 )
-from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
+)
 from paddlenlp.trl import (
     DPOTrainer,
     calculate_effective_tokens,
     preference_collate_fn,
     preprocess_preference_data,
 )
-from paddlenlp.transformers import (
-    LlamaForCausalLM,
-    LlamaForCausalLMPipe,
-)
 from paddlenlp.utils.log import logger

 flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
@@ -132,9 +133,7 @@ def main():
     model.set_state_dict(ref_model.state_dict())

     if model_args.flash_mask and not model.config.use_flash_attention:
-        logger.warning(
-            "`flash_mask` must use with zero padding and flash attention."
-        )
+        logger.warning("`flash_mask` must use with zero padding and flash attention.")
         model.config.use_flash_attention = True

     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
@@ -161,6 +160,7 @@ def main():
             train_ds.map(trans_func),
             tokenizer=tokenizer,
             max_length=data_args.max_seq_len,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if train_ds is not None
         else None
```

llm/run_finetune.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -391,6 +391,7 @@ def neft_post_hook(module, input, output):
             train_ds,
             tokenizer=tokenizer,
             max_length=data_args.max_length,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if train_ds is not None
         else None
@@ -400,6 +401,7 @@ def neft_post_hook(module, input, output):
             ptq_ds,
             tokenizer=tokenizer,
             max_length=data_args.max_length,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if ptq_ds is not None
         else None
```

llm/utils/argument.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -86,6 +86,12 @@ class DataArgument:
     dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
     task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
     zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use Greedy Zero Padding data stream, should be used together with `zero_padding=True`."
+        },
+    )
     pad_to_multiple_of: int = field(
         default=None, metadata={"help": "If set will pad the sequence to a multiple of the provided value."}
     )
```
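
Per the help string, `greedy_zero_padding` layers on top of the zero-padding data stream rather than implying it. The commit itself adds no validation; a defensive guard, purely illustrative and not part of this diff, could look like:

```python
if data_args.greedy_zero_padding and not data_args.zero_padding:
    raise ValueError("`greedy_zero_padding` should be used together with `zero_padding=True`.")
```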

paddlenlp/datasets/zero_padding_dataset.py

Lines changed: 123 additions & 47 deletions
```diff
@@ -17,6 +17,27 @@
 from scipy.linalg import block_diag


+def generate_greedy_packs(examples, max_length):
+    left_len = np.zeros([len(examples)]) - 1
+    left_len[0] = max_length  # At the beginning, only the first pack is valid.
+    generate_packs = [[] for i in range(len(examples))]
+    index, left_index = 0, 0
+
+    while index < len(examples):
+        record = examples[index]
+        max_left_index = left_len.argmax()
+        # Put the current sequence into the largest left space valid pack.
+        if len(record["input_ids"]) <= left_len[max_left_index]:
+            generate_packs[max_left_index].append(record)
+            left_len[max_left_index] -= len(record["input_ids"])
+            index += 1
+        else:
+            left_index += 1
+            left_len[left_index] = max_length
+
+    return generate_packs
+
+
 class ZeroPadding:
     required_output_keys = ["input_ids", "labels", "attention_mask"]
     # Only supported the following keys for ZeroPadding. Keys outside of the set will be ignored.
```
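
`generate_greedy_packs` is a worst-fit bin-packing pass over one buffer of examples: unopened packs sit at capacity −1, each sequence goes to the open pack with the largest remaining capacity (`left_len.argmax()`), and a new pack is opened only when it fits in none. Packs therefore never exceed `max_length`, and since every pack holds at least one record, `left_index` can never run past the end of `left_len`. A small standalone trace with toy lengths (not from the commit):

```python
from paddlenlp.datasets.zero_padding_dataset import generate_greedy_packs

# Toy records: only len(record["input_ids"]) matters here, and both call
# sites skip any record longer than max_length before buffering.
examples = [{"input_ids": [0] * n} for n in (5, 4, 3, 2)]

packs = generate_greedy_packs(examples, max_length=8)
print([[len(r["input_ids"]) for r in p] for p in packs if p])
# [[5, 2], [4, 3]] -- the 4 cannot fit the 3 slots left in pack 0, so it
# opens pack 1; the 3 then lands in pack 1 (4 left), the 2 in pack 0 (3 left).
```
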
```diff
@@ -80,38 +101,66 @@ def _pad_batch_records(cls, batch_records):


 class ZeroPaddingMapDataset(ZeroPadding, Dataset):
-    def __init__(self, data, tokenizer, max_length):
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.greedy_zero_padding = greedy_zero_padding
         self.new_data = self._create_zero_padding_data(data)

     def _create_zero_padding_data(self, data):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
-
         total_data = []
-        for i in range(len(data)):
-            record = data[i]
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for i in range(len(data)):
+                record = data[i]
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    total_data.append(padded_list)
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+
+            # remaining data
+            if batch_records:
                 padded_list = self._pad_batch_records(batch_records)
                 total_data.append(padded_list)
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-
-        # remaining data
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            total_data.append(padded_list)
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            total_data.append(padded_list)
+                    examples = [record]
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        total_data.append(padded_list)
+
         return total_data

     def __getitem__(self, idx):
@@ -122,34 +171,61 @@ def __len__(self):


 class ZeroPaddingIterableDataset(ZeroPadding, IterableDataset):
-    def __init__(self, data, tokenizer, max_length):
-
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.data = data
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.zero_padding_global_step = 0
+        self.greedy_zero_padding = greedy_zero_padding

     def __iter__(self):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
-        for record in self.data:
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for record in self.data:
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    yield padded_list
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+            if batch_records:
                 padded_list = self._pad_batch_records(batch_records)
                 yield padded_list
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            yield padded_list
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in self.data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    self.zero_padding_global_step += 1
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            yield padded_list
+                    examples = [record]
+                    self.zero_padding_global_step += 1
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        yield padded_list
```
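
The payoff shows up in pack counts. A standalone comparison under assumed toy lengths (the sequential loop mirrors the `greedy_zero_padding=False` branch above; illustrative, not a benchmark):

```python
from paddlenlp.datasets.zero_padding_dataset import generate_greedy_packs

max_length = 10
examples = [{"input_ids": [0] * n} for n in (8, 7, 3, 2)]

# Sequential packing, as in the non-greedy branch: flush on overflow.
packs, batch, cur_len = [], [], 0
for rec in examples:
    n = len(rec["input_ids"])
    if cur_len + n <= max_length:
        batch.append(rec)
        cur_len += n
    else:
        packs.append(batch)
        batch, cur_len = [rec], n
if batch:
    packs.append(batch)
print([[len(r["input_ids"]) for r in p] for p in packs])
# [[8], [7, 3], [2]] -> 3 packs, 10 idle slots

# Greedy packing over the buffered window.
greedy = [p for p in generate_greedy_packs(examples, max_length) if p]
print([[len(r["input_ids"]) for r in p] for p in greedy])
# [[8, 2], [7, 3]] -> 2 packs, none idle
```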
