Commit 98033fa

Fix/longlora (#294)
1 parent acbd343 commit 98033fa

File tree

4 files changed: +28 additions, -25 deletions


swift/llm/infer.py

Lines changed: 2 additions & 1 deletion
@@ -138,7 +138,8 @@ def prepare_model_template(
     logger.info(f'generation_config: {generation_config}')
     set_generation_config(model, generation_config)
     # Preparing LoRA
-    if args.sft_type == 'lora' and args.ckpt_dir is not None:
+    if args.sft_type in ('lora', 'qalora',
+                         'longlora') and args.ckpt_dir is not None:
         model = Swift.from_pretrained(
             model, args.ckpt_dir, inference_mode=True)

swift/trainers/trainers.py

Lines changed: 6 additions & 1 deletion
@@ -104,7 +104,12 @@ def prediction_step(
         generate_inputs = inputs.copy()
         if has_labels:
             _labels = inputs['labels'][0]
-            n_mask = lower_bound(0, len(_labels), lambda i: _labels[i] != -100)
+            n_mask = 0
+            for i in range(len(_labels)):
+                if _labels[i] != -100:
+                    n_mask = i
+                    break
             for k in ['input_ids', 'attention_mask']:
                 generate_inputs[k] = generate_inputs[k][:, :n_mask]
             generate_inputs['labels'] = generate_inputs['labels'][:, n_mask:]
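
As a quick illustration of what the replacement loop computes, here is a minimal, self-contained sketch with hypothetical label values (not the committed tensors): the first index whose label is not -100 marks the prompt/response boundary that is then used to slice input_ids/attention_mask from labels.

```python
# Hypothetical label row; -100 marks prompt tokens that carry no loss.
labels = [-100, -100, -100, 42, 17, 9]

# Same logic as the new loop in prediction_step: index of the first non -100 label.
n_mask = 0
for i in range(len(labels)):
    if labels[i] != -100:
        n_mask = i
        break

assert n_mask == 3
prompt_part, target_part = labels[:n_mask], labels[n_mask:]
print(n_mask, target_part)  # 3 [42, 17, 9]
```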

swift/tuners/longlora/llama.py

Lines changed: 18 additions & 12 deletions
@@ -14,6 +14,10 @@
 from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb,
                                                        repeat_kv, rotate_half)
 
+from swift.utils import get_logger
+
+logger = get_logger()
+
 
 def forward_flashattn(
     self,
@@ -306,8 +310,8 @@ def forward_flashattn_inference(
     ))  # noqa
 
     kv_seq_len = k.shape[1]
-    if past_key_value is not None:
-        past_kv_len = past_key_value[0].shape[2]
+    if past_key_value is not None and len(past_key_value):
+        past_kv_len = past_key_value.seen_tokens
         kv_seq_len += past_kv_len
 
     cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
@@ -316,15 +320,13 @@ def forward_flashattn_inference(
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
 
-    if past_key_value is not None:
-        assert (flash_attn_version >=
-                '2.1.0'), 'past_key_value support requires flash-attn >= 2.1.0'
-        # reuse k, v
-        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
-        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
-
-    past_key_value = (k.transpose(1, 2),
-                      v.transpose(1, 2)) if use_cache else None
+    if use_cache:
+        k, v = past_key_value.update(
+            k.transpose(1, 2), v.transpose(1, 2), layer_idx=self.idx)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+    else:
+        past_key_value = None
 
     if attention_mask is None:
         output = flash_attn_func(
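
The two hunks above switch the patched forward from tuple-style past_key_value entries to the Cache object that newer transformers versions pass in: the code now reads seen_tokens for the cached length and calls update(..., layer_idx=...). A standalone sketch of that API, assuming transformers >= 4.36 with transformers.cache_utils.DynamicCache:

```python
# Standalone illustration (not the committed code) of the Cache API the
# patched forward now relies on: update() appends the new key/value states
# for a given layer along the sequence dimension and returns the full cache.
import torch
from transformers.cache_utils import DynamicCache  # assumes transformers >= 4.36

cache = DynamicCache()
bsz, n_heads, seq_len, head_dim = 1, 2, 4, 8
k_new = torch.randn(bsz, n_heads, seq_len, head_dim)
v_new = torch.randn(bsz, n_heads, seq_len, head_dim)

k_all, v_all = cache.update(k_new, v_new, layer_idx=0)
print(k_all.shape)             # torch.Size([1, 2, 4, 8]) on the first step

k_all, v_all = cache.update(k_new, v_new, layer_idx=0)
print(k_all.shape)             # torch.Size([1, 2, 8, 8]) after a second step
print(cache.get_seq_length())  # 8 cached tokens (exposed as seen_tokens on older versions)
```

The layer_idx argument is also why the next hunk stamps m.self_attn.idx on every decoder layer.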
@@ -405,12 +407,13 @@ def forward_flashattn_inference_s2_attn(
 
 def patch_llama_forward(model: nn.Module, forward_function) -> None:
     # Compatible with transformers device_map
-    for m in model.model.layers:
+    for idx, m in enumerate(model.model.layers):
         new_forward = MethodType(forward_function, m.self_attn)
         if hasattr(model, '_old_forward'):
             m.self_attn._old_forward = new_forward
         else:
             m.self_attn.forward = new_forward
+        m.self_attn.idx = idx
 
 
 def replace_llama_attn(model: nn.Module, use_flash_attn=True):
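
A toy, self-contained sketch of the pattern this hunk extends (hypothetical modules, not swift's LlamaAttention patching): a replacement forward is bound per instance with MethodType, and the enumeration index is stamped on the module so the bound forward can pass it as layer_idx.

```python
# Toy illustration of per-instance forward patching plus a stored layer index.
from types import MethodType

import torch
import torch.nn as nn


class ToyAttention(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


def patched_forward(self, x):
    # self.idx exists because the loop below attaches it, just as the patched
    # LLaMA forward can rely on self.idx when updating the cache.
    return self.proj(x), self.idx


layers = nn.ModuleList([ToyAttention() for _ in range(3)])
for idx, m in enumerate(layers):
    m.forward = MethodType(patched_forward, m)
    m.idx = idx

x = torch.zeros(1, 4)
print([m(x)[1] for m in layers])  # [0, 1, 2]
```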
@@ -425,4 +428,7 @@ def replace_llama_attn(model: nn.Module, use_flash_attn=True):
             _prepare_decoder_attention_mask)
         patch_llama_forward(model, forward_flashattn_inference_s2_attn)
     else:
+        logger.warn(
+            'The source code of LongLoRA without flash '
+            'attention may have some problems, please use with caution.')
         patch_llama_forward(model, forward_noflashattn)

swift/tuners/longlora/longlora.py

Lines changed: 2 additions & 11 deletions
@@ -8,6 +8,7 @@
 
 from swift import LoRA, LoRAConfig, SwiftOutput
 from swift.tuners.lora import lora_state_dict, mark_lora_as_trainable
+from swift.tuners.lora_layers import LoraModel
 
 
 class LongLoRAModelType:
@@ -59,17 +60,7 @@ class LongLoRA(LoRA):
     def prepare_model(model: nn.Module, config: LongLoRAConfig,
                       adapter_name: str):
         """Prepare a model with `LongLoRAConfig`"""
-        LoRA._dynamic_patch_lora(
-            model,
-            target_modules=config.target_modules,
-            r=config.r,
-            adapter_name=adapter_name,
-            lora_alpha=config.lora_alpha,
-            lora_dropout=config.lora_dropout,
-            merge_weights=config.merge_weights,
-            use_merged_linear=config.use_merged_linear,
-            enable_lora=config.enable_lora,
-            fan_in_fan_out=config.fan_in_fan_out)
+        LoraModel(model, config, adapter_name)
 
     def state_dict_callback(state_dict, adapter_name):
         _state_dict = lora_state_dict(state_dict, adapter_name,
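
For orientation, here is an illustrative sketch of the delegation shape using peft's LoraModel directly (an assumption for the example; the diff uses swift.tuners.lora_layers.LoraModel, which follows the same constructor form): constructing the wrapper walks the model and swaps every matching target module for a LoRA-wrapped one in place, which is what replaces the long keyword-by-keyword _dynamic_patch_lora call.

```python
# Illustrative only: uses peft's LoraModel as a stand-in for the
# swift.tuners.lora_layers.LoraModel imported in this commit; the toy module
# and target names below are hypothetical.
import torch.nn as nn
from peft import LoraConfig
from peft.tuners.lora import LoraModel


class ToyBlock(nn.Module):

    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)
        self.v_proj = nn.Linear(16, 16)


config = LoraConfig(r=8, lora_alpha=16, target_modules=['q_proj', 'v_proj'])
wrapped = LoraModel(ToyBlock(), config, 'default')
print(hasattr(wrapped.model.q_proj, 'lora_A'))  # True: q_proj now carries LoRA weights
```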
