
Commit 939a221

Support longlora for transformers 4.38 (#456)
1 parent c13cb37 · commit 939a221

File tree: 7 files changed (+364, -355 lines)


docs/source/LLM/支持的模型和数据集.md (Supported Models and Datasets)

Lines changed: 1 addition & 0 deletions

@@ -275,6 +275,7 @@
 |generated-chat-zh|[AI-ModelScope/generated_chat_0.4M](https://modelscope.cn/datasets/AI-ModelScope/generated_chat_0.4M/summary)|396004|0|273.3±52.0, min=32, max=873|chat, character-dialogue|
 |cls-fudan-news-zh|[damo/zh_cls_fudan-news](https://modelscope.cn/datasets/damo/zh_cls_fudan-news/summary)|4959|0|3234.4±2547.5, min=91, max=19548|chat, classification|
 |ner-jave-zh|[damo/zh_ner-JAVE](https://modelscope.cn/datasets/damo/zh_ner-JAVE/summary)|1266|0|118.3±45.5, min=44, max=223|chat, ner|
+|long-alpaca-12k|[AI-ModelScope/LongAlpaca-12k](https://modelscope.cn/datasets/AI-ModelScope/LongAlpaca-12k/summary)|11998|0|9619.0±8295.8, min=36, max=78925|longlora, QA|
 |coco-en|[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption/summary)|414113|40504|298.8±2.8, min=294, max=351|chat, multi-modal, vision|
 |🔥coco-mini-en|[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption/summary)|20000|200|298.8±2.8, min=294, max=339|chat, multi-modal, vision|
 |🔥coco-mini-en-2|[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption/summary)|20000|200|36.8±2.8, min=32, max=77|chat, multi-modal, vision|

swift/llm/sft.py

Lines changed: 3 additions & 2 deletions

@@ -182,8 +182,9 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     if val_dataset is not None:
         val_dataset = LazyLLMDataset(val_dataset, template)

-    padding_to = args.max_length if args.sft_type == 'longlora' else None
-    data_collator = partial(template.data_collator, padding_to=padding_to)
+    pad_to_multiple_of = 8 if args.sft_type == 'longlora' else None
+    data_collator = partial(
+        template.data_collator, pad_to_multiple_of=pad_to_multiple_of)

     # Trainer
     logger.info(f'training_args: {training_args}')
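
Note on this change: rather than padding every longlora batch to the fixed args.max_length, the collator now pads each batch to its longest sample rounded up to a multiple of 8, which keeps shapes aligned for fused attention kernels while wasting far less compute on padding. A minimal sketch of the rounding arithmetic; the helper name is ours, and DataCollatorForSeq2Seq performs the equivalent computation internally:

def padded_length(longest_in_batch: int, multiple: int = 8) -> int:
    # Round the longest sequence in the batch up to the next multiple.
    return ((longest_in_batch + multiple - 1) // multiple) * multiple

assert padded_length(517) == 520  # next multiple of 8 above 517
assert padded_length(512) == 512  # already aligned, no extra padding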

swift/llm/tuner.py

Lines changed: 0 additions & 1 deletion

@@ -96,7 +96,6 @@ def prepare_model(model, args: SftArguments):
         longlora_config = LongLoRAConfig(
             lora_dtype=args.lora_dtype,
             model_type=LongLoRAModelType.LLAMA,
-            use_flash_attn=args.use_flash_attn,
             **lora_kwargs)
         model = Swift.prepare_model(model, longlora_config)
         logger.info(f'longlora_config: {longlora_config}')
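
Note on this change: use_flash_attn is no longer forwarded to LongLoRAConfig, presumably because under transformers 4.38 the attention backend is selected when the model is loaded rather than inside the tuner. A hedged sketch of that loading pattern, not taken from this commit (model id is illustrative):

import torch
from transformers import AutoModelForCausalLM

# attn_implementation selects the backend in transformers >= 4.36;
# 'flash_attention_2' additionally requires a half-precision dtype.
model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-2-7b-hf',  # illustrative model id
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2')  # or 'sdpa' / 'eager'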

swift/llm/utils/dataset.py

Lines changed: 19 additions & 0 deletions

@@ -110,6 +110,7 @@ class DatasetName:
     # example dataset for specific model
     cls_fudan_news_zh = 'cls-fudan-news-zh'  # seqgpt-560m
     ner_java_zh = 'ner-jave-zh'  # seqgpt-560m
+    long_alpaca_12k = 'long-alpaca-12k'

     # multi-modal
     # for qwen-vl
@@ -457,6 +458,24 @@ def _repair_ms_bench(conversations: str) -> Dict[str, str]:
     return conversations


+def long_alpaca_preprocessor(dataset: HfDataset):
+
+    def map_row(row):
+        response = row['response']
+        if response and response.startswith('Answer:'):
+            response = response[len('Answer:') + 1:].strip()
+        return {'query': row['query'], 'response': response}
+    return dataset.rename_columns({'instruction': 'query', 'output': 'response'})\
+        .remove_columns(['input', 'file']).map(map_row).filter(lambda row: row['response'] is not None)
+
+
+register_dataset(
+    DatasetName.long_alpaca_12k,
+    'AI-ModelScope/LongAlpaca-12k', ['train'], [],
+    long_alpaca_preprocessor,
+    get_dataset_from_repo,
+    tags=['longlora', 'QA'])
+
 register_dataset(
     DatasetName.ms_bench,
     'iic/ms_bench', ['train'], [],
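
The preprocessor above renames LongAlpaca's instruction/output columns to swift's query/response schema and strips a leading 'Answer:' prefix (plus the character that follows it, normally a space) from each response. A standalone check of that string handling, independent of the HfDataset pipeline:

# Toy row illustrating the 'Answer:' stripping done in map_row above.
row = {'query': 'What does LongLoRA change?',
       'response': 'Answer: It fine-tunes attention over shifted groups.'}
response = row['response']
if response and response.startswith('Answer:'):
    response = response[len('Answer:') + 1:].strip()
assert response == 'It fine-tunes attention over shifted groups.'

Once registered, the dataset should be loadable by its name long-alpaca-12k through swift's dataset registry, following the same pattern as the neighboring register_dataset entries.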

swift/llm/utils/template.py

Lines changed: 35 additions & 54 deletions

@@ -8,7 +8,8 @@
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.utils.rnn import pad_sequence
-from transformers import PreTrainedTokenizerBase, StoppingCriteria
+from transformers import (DataCollatorForSeq2Seq, PreTrainedTokenizerBase,
+                          StoppingCriteria)

 from swift.llm.agent.utils import calculate_loss_scale

@@ -186,6 +187,10 @@ def _init_template(self,
         self.truncation_strategy = truncation_strategy
         self.model = kwargs.get('model', None)
         self.use_loss_scale = kwargs.get('use_loss_scale', False)
+        self._data_collator = DataCollatorForSeq2Seq(
+            tokenizer=self.tokenizer,
+            label_pad_token_id=self.tokenizer.pad_token_id,
+        )
         for key in [
                 'prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system'
         ]:
@@ -386,55 +391,28 @@ def concat_tokenizer_kwargs(
         assert len(old_tokenizer_kwargs) == 0
         return curr_tokenizer_kwargs

-    def data_collator(self,
-                      batch: List[Dict[str, Any]],
-                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+    def data_collator(
+            self,
+            batch: List[Dict[str, Any]],
+            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
         """
         Args:
             batch(`List[Dict[str, Any]]`): The input data in batch
-            padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
-                will be padded to the `longest`
+            pad_to_multiple_of(`int`, optional): Whether padding to the multiple of an integer value.
         """
-        tokenizer = self.tokenizer
-        assert tokenizer.pad_token_id is not None
-        input_ids = [torch.tensor(b['input_ids']) for b in batch]
-        labels = [torch.tensor(b['labels']) for b in batch]
-        loss_scale = [torch.tensor(b['loss_scale'])
+        self._data_collator.pad_to_multiple_of = pad_to_multiple_of
+        if pad_to_multiple_of:
+            self.tokenizer.padding_side = 'right'
+        loss_scale = [torch.tensor(b.pop('loss_scale'))
                       for b in batch] if 'loss_scale' in batch[0] else None
-        attention_mask = [
-            torch.ones(len(input_ids[i]), dtype=torch.int64)
-            for i in range(len(input_ids))
-        ]
-
-        if padding_to is not None:
-            padding_len = padding_to - input_ids[0].shape[-1]
-            if padding_len > 0:
-                input_ids[0] = F.pad(input_ids[0], (0, padding_len),
-                                     'constant', tokenizer.pad_token_id)
-                attention_mask[0] = F.pad(attention_mask[0], (0, padding_len),
-                                          'constant', 0)
-                labels[0] = F.pad(labels[0], (0, padding_len), 'constant',
-                                  -100)
-                if loss_scale:
-                    loss_scale[0] = F.pad(
-                        loss_scale[0], (0, padding_to - labels[0].shape[-1]),
-                        'constant', 0.)
-
-        input_ids = pad_sequence(
-            input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
-        attention_mask = pad_sequence(
-            attention_mask, batch_first=True, padding_value=0)
+        res = self._data_collator(batch, return_tensors='pt')
+        padding_to = res['input_ids'].shape[1]
         if loss_scale:
+            loss_scale[0] = F.pad(loss_scale[0],
+                                  (0, padding_to - loss_scale[0].shape[-1]),
+                                  'constant', 0.)
             loss_scale = pad_sequence(
                 loss_scale, batch_first=True, padding_value=0.)
-        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
-
-        res = {
-            'input_ids': input_ids,
-            'attention_mask': attention_mask,
-            'labels': labels,
-        }
-        if loss_scale is not None:
             res['loss_scale'] = loss_scale
         return res

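The rewrite above replaces roughly forty lines of manual torch padding with transformers' DataCollatorForSeq2Seq, keeping only the loss_scale handling by hand; it also forces tokenizer.padding_side = 'right' whenever pad_to_multiple_of is set, since padding now goes through tokenizer.pad. A minimal sketch of the delegated behavior, assuming a GPT-2 tokenizer purely for illustration:

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # illustrative tokenizer
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    label_pad_token_id=tokenizer.pad_token_id,  # mirrors _init_template above
    pad_to_multiple_of=8)
batch = [
    {'input_ids': [1, 2, 3], 'labels': [1, 2, 3]},
    {'input_ids': [4, 5, 6, 7, 8], 'labels': [4, 5, 6, 7, 8]},
]
res = collator(batch, return_tensors='pt')
# The longest sample has 5 tokens; with pad_to_multiple_of=8 everything is
# right-padded to length 8, and attention_mask marks the real tokens.
print(res['input_ids'].shape)  # torch.Size([2, 8])
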
@@ -601,10 +579,11 @@ def encode(
         inputs['images'] = image_tensor.to(model.dtype)
         return inputs, {}

-    def data_collator(self,
-                      batch: List[Dict[str, Any]],
-                      padding_to: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, padding_to)
+    def data_collator(
+            self,
+            batch: List[Dict[str, Any]],
+            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, pad_to_multiple_of)
         res['images'] = torch.concat([b['images'] for b in batch])
         return res

@@ -908,10 +887,11 @@ def encode(
         inputs['image_sizes'] = image_sizes
         return inputs, {}

-    def data_collator(self,
-                      batch: List[Dict[str, Any]],
-                      padding_to: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, padding_to)
+    def data_collator(
+            self,
+            batch: List[Dict[str, Any]],
+            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, pad_to_multiple_of)
         res['images'] = torch.concat([b['images'] for b in batch])
         res['image_sizes'] = sum([b['image_sizes'] for b in batch], start=[])
         return res
@@ -1093,10 +1073,11 @@ def encode(
             len(inputs['input_ids']) - len(token_type_ids))
         return inputs, {}

-    def data_collator(self,
-                      batch: List[Dict[str, Any]],
-                      padding_to: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, padding_to)
+    def data_collator(
+            self,
+            batch: List[Dict[str, Any]],
+            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, pad_to_multiple_of)
         is_cogagent = 'cross_images' in batch[0]
         keys = ['images', 'cross_images'] if is_cogagent else ['images']
         for key in keys:

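The three multimodal overrides above are mechanical: they forward pad_to_multiple_of to the base collator and then re-attach the image fields that DataCollatorForSeq2Seq knows nothing about. A self-contained sketch of the pattern; the class names are stand-ins, not swift's:

from typing import Any, Dict, List, Optional

import torch

class BaseTemplate:
    # Stand-in for Template.data_collator, which pads input_ids/labels.
    def data_collator(self, batch: List[Dict[str, Any]],
                      pad_to_multiple_of: Optional[int] = None
                      ) -> Dict[str, Any]:
        return {'input_ids': torch.zeros(len(batch), 8, dtype=torch.int64)}

class VisionTemplate(BaseTemplate):
    def data_collator(self, batch: List[Dict[str, Any]],
                      pad_to_multiple_of: Optional[int] = None
                      ) -> Dict[str, Any]:
        res = super().data_collator(batch, pad_to_multiple_of)
        # Image tensors bypass the seq2seq collator and are stacked here.
        res['images'] = torch.concat([b['images'] for b in batch])
        return res

batch = [{'images': torch.zeros(1, 3, 224, 224)} for _ in range(2)]
print(VisionTemplate().data_collator(batch)['images'].shape)
# torch.Size([2, 3, 224, 224])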