
Commit fa282f9

Add longlora for llama (#115)
1 parent a4ebea1 commit fa282f9

File tree: 10 files changed, +586 -24 lines changed


examples/pytorch/llm/src/llm_sft.py

Lines changed: 33 additions & 17 deletions
@@ -11,8 +11,8 @@
 from utils import (SftArguments, dataset_map, get_dataset, get_model_tokenizer,
                    get_preprocess)
 
-from swift import (LoraConfig, LoRAConfig, Seq2SeqTrainer,
-                   Seq2SeqTrainingArguments, Swift, get_logger)
+from swift import (LongLoRAConfig, LongLoRAModelType, LoraConfig, LoRAConfig,
+                   Seq2SeqTrainer, Seq2SeqTrainingArguments, Swift, get_logger)
 from swift.utils import (add_version_to_work_dir, broadcast_string,
                          check_json_format, compute_nlg_metrics,
                          data_collate_fn, find_all_linear_for_lora,
@@ -54,27 +54,40 @@ def llm_sft(args: SftArguments) -> None:
         args.model_type, torch_dtype=args.torch_dtype, **kwargs)
 
     # ### Preparing LoRA
-    if args.sft_type == 'lora':
+    if args.sft_type == 'lora' or args.sft_type == 'longlora':
         if args.resume_from_checkpoint is None:
             if 'ALL' in args.lora_target_modules:
                 assert len(args.lora_target_modules) == 1
                 args.lora_target_modules = find_all_linear_for_lora(
                     model, args.quantization_bit, args.model_type)
                 logger.info(
                     f'Setting lora_target_modules: {args.lora_target_modules}')
-            lora_kwargs = {}
-            if args.tuner_bankend == 'peft':
-                global LoRAConfig
-                LoRAConfig = LoraConfig
-                lora_kwargs['task_type'] = 'CAUSAL_LM'
-            lora_config = LoRAConfig(
-                r=args.lora_rank,
-                target_modules=args.lora_target_modules,
-                lora_alpha=args.lora_alpha,
-                lora_dropout=args.lora_dropout_p,
-                **lora_kwargs)
-            model = Swift.prepare_model(model, lora_config)
-            logger.info(f'lora_config: {lora_config}')
+            if args.sft_type == 'lora':
+                lora_kwargs = {}
+                if args.tuner_bankend == 'peft':
+                    global LoRAConfig
+                    LoRAConfig = LoraConfig
+                    lora_kwargs['task_type'] = 'CAUSAL_LM'
+                lora_config = LoRAConfig(
+                    r=args.lora_rank,
+                    target_modules=args.lora_target_modules,
+                    lora_alpha=args.lora_alpha,
+                    lora_dropout=args.lora_dropout_p,
+                    **lora_kwargs)
+                model = Swift.prepare_model(model, lora_config)
+                logger.info(f'lora_config: {lora_config}')
+            elif args.sft_type == 'longlora':
+                assert args.tuner_bankend != 'peft'
+                assert LongLoRAModelType.LLAMA in args.model_type
+                longlora_config = LongLoRAConfig(
+                    r=args.lora_rank,
+                    target_modules=args.lora_target_modules,
+                    lora_alpha=args.lora_alpha,
+                    lora_dropout=args.lora_dropout_p,
+                    model_type=LongLoRAModelType.LLAMA,
+                    use_flash_attn=args.use_flash_attn)
+                model = Swift.prepare_model(model, longlora_config)
+                logger.info(f'longlora_config: {longlora_config}')
         else:
             model = Swift.from_pretrained(
                 model, args.resume_from_checkpoint, is_trainable=True)
@@ -109,7 +122,10 @@ def llm_sft(args: SftArguments) -> None:
     if args.test_oom_error:
         train_dataset = sort_by_max_length(train_dataset, 20000)
     # Data analysis
-    data_collator = partial(data_collate_fn, tokenizer=tokenizer)
+    data_collator = partial(
+        data_collate_fn,
+        tokenizer=tokenizer,
+        padding_to=args.max_length if args.sft_type == 'longlora' else None)
     print_example(train_dataset[0], tokenizer)
     stat_dataset(train_dataset)
     stat_dataset(val_dataset)
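For reference, the new branch can also be exercised outside the training script. A minimal sketch, assuming a local LLaMA-family checkpoint; the checkpoint path and hyperparameter values are placeholders, while the LongLoRAConfig keywords and the Swift.prepare_model call mirror the hunk above:

import torch
from transformers import AutoModelForCausalLM
from swift import LongLoRAConfig, LongLoRAModelType, Swift

# Placeholder path; per the assertions above, this commit only wires LongLoRA up for
# LLaMA-family models and only with the 'swift' tuner backend (tuner_bankend != 'peft').
model = AutoModelForCausalLM.from_pretrained(
    'path/to/llama-checkpoint', torch_dtype=torch.bfloat16)

longlora_config = LongLoRAConfig(
    r=8,                                   # args.lora_rank (illustrative value)
    target_modules=['q_proj', 'k_proj', 'v_proj'],  # args.lora_target_modules
    lora_alpha=32,                         # args.lora_alpha
    lora_dropout=0.05,                     # args.lora_dropout_p
    model_type=LongLoRAModelType.LLAMA,    # the only model type handled by this commit
    use_flash_attn=False)                  # args.use_flash_attn
model = Swift.prepare_model(model, longlora_config)

The collator change pairs with this path: when sft_type is 'longlora', batches are padded out to args.max_length, and padding_to stays None otherwise.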

examples/pytorch/llm/src/utils/argument.py

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@ class SftArguments:
         default=ModelType.qwen_7b_chat,
         metadata={'choices': list(MODEL_MAPPING.keys())})
     sft_type: str = field(
-        default='lora', metadata={'choices': ['lora', 'full']})
+        default='lora', metadata={'choices': ['longlora', 'lora', 'full']})
     tuner_bankend: str = field(
         default='swift', metadata={'choices': ['swift', 'peft']})
     template_type: Optional[str] = field(
@@ -147,7 +147,7 @@ def init_argument(self):
             # Initialize in advance
             dist.init_process_group(backend=self.ddp_backend)
 
-        if self.sft_type == 'lora':
+        if self.sft_type == 'lora' or self.sft_type == 'longlora':
             if self.learning_rate is None:
                 self.learning_rate = 1e-4
             if self.only_save_model is None:
@@ -223,7 +223,7 @@ class InferArguments:
         default=ModelType.qwen_7b_chat,
         metadata={'choices': list(MODEL_MAPPING.keys())})
     sft_type: str = field(
-        default='lora', metadata={'choices': ['lora', 'full']})
+        default='lora', metadata={'choices': ['longlora', 'lora', 'full']})
     template_type: Optional[str] = field(
         default=None, metadata={'choices': list(TEMPLATE_MAPPING.keys())})
     ckpt_dir: str = '/path/to/your/vx_xxx/checkpoint-xxx'
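The user-visible effect of the argument change can be checked straight from the dataclass metadata. A small sketch, assuming examples/pytorch/llm/src is on sys.path so that utils.SftArguments imports cleanly:

import dataclasses
from utils import SftArguments  # assumption: run from examples/pytorch/llm/src

# Inspect the field touched by this commit.
sft_field = {f.name: f for f in dataclasses.fields(SftArguments)}['sft_type']
print(sft_field.metadata['choices'])  # ['longlora', 'lora', 'full']
print(sft_field.default)              # 'lora' (default unchanged)

The same 'longlora' choice is mirrored in InferArguments, and init_argument now gives it the same learning-rate and only_save_model defaults as plain LoRA.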

swift/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -13,7 +13,8 @@
     PrefixTuningConfig, PromptEncoderConfig, PromptLearningConfig,
     PromptTuningConfig, get_peft_config, get_peft_model,
     get_peft_model_state_dict, Prompt, PromptConfig, PromptModule,
-    SwiftConfig, SwiftOutput, Swift, SwiftTuners)
+    SwiftConfig, SwiftOutput, Swift, SwiftTuners, LongLoRAConfig, LongLoRA,
+    LongLoRAModelType)
 from .hub import snapshot_download, push_to_hub, push_to_hub_async, push_to_hub_in_queue
 from .trainers import (EvaluationStrategy, FSDPOption, HPSearchBackend,
                        HubStrategy, IntervalStrategy, SchedulerType,
@@ -38,7 +39,7 @@
         'PromptTuningConfig', 'get_peft_config', 'get_peft_model',
         'get_peft_model_state_dict', 'Prompt', 'PromptConfig',
         'PromptModule', 'SwiftConfig', 'SwiftOutput', 'Swift',
-        'SwiftTuners'
+        'SwiftTuners', 'LongLoRAConfig', 'LongLoRA', 'LongLoRAModelType'
     ],
     'trainers': [
         'EvaluationStrategy', 'FSDPOption', 'HPSearchBackend',
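With the root-level re-export in place, downstream code can import the new tuner the same way as the existing ones. A quick sanity-check sketch; it assumes the lazy-import machinery resolves both paths to the same class, which is what the re-export is for:

from swift import LongLoRA, LongLoRAConfig, LongLoRAModelType
from swift.tuners import LongLoRAConfig as TunersLongLoRAConfig

# Both import paths should hand back the class defined under swift/tuners/longlora/.
assert LongLoRAConfig is TunersLongLoRAConfig
print(LongLoRA, LongLoRAModelType.LLAMA)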

swift/tuners/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
 from .lora import LoRA, LoRAConfig
 from .mapping import SWIFT_MAPPING, SwiftTuners
 from .side import Side, SideConfig, SideModule
+from .longlora.longlora import LongLoRAModelType, LongLoRAConfig, LongLoRA
 from .restuning import ResTuning, ResTuningConfig, ResTuningBypassModule
 from .peft import (LoraConfig, PeftConfig, PeftModel, PeftModelForCausalLM,
                    PeftModelForSeq2SeqLM,
@@ -24,6 +25,8 @@
     'adapter': ['Adapter', 'AdapterConfig', 'AdapterModule'],
     'base': ['SwiftModel', 'Swift'],
     'lora': ['LoRA', 'LoRAConfig'],
+    'longlora.longlora':
+    ['LongLoRAModelType', 'LongLoRAConfig', 'LongLoRA'],
     'mapping': ['SWIFT_MAPPING', 'SwiftTuners'],
     'side': ['Side', 'SideConfig', 'SideModule'],
     'restuning': ['ResTuning', 'ResTuningConfig', 'ResTuningBypassModule'],
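The new mapping entry points the lazy loader at the longlora submodule, so importing straight from the module named in the hunk is an equivalent sketch:

# Same classes as the lazily exported swift.tuners names.
from swift.tuners.longlora.longlora import (LongLoRA, LongLoRAConfig,
                                            LongLoRAModelType)
print(LongLoRAConfig, LongLoRAModelType.LLAMA)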

swift/tuners/longlora/__init__.py

Whitespace-only changes.
