Commit fb8f2be

[bloom] Add kv cache support for flash attention & fix bugs (#7735)
* Add kv cache support for flash attention
* Update chatglm flash attention version check
* Add test for flash attention
* Fix unitest bug
* Add flash attention to predictor
* Add flash attention2
* Add flash attention unitests
* fix prefix decoder
* remove unused comments
* Update unitest
* Update unitest
1 parent fda20a7 commit fb8f2be

File tree

10 files changed (+67, -57 lines)

llm/predictor.py

Lines changed: 9 additions & 1 deletion
@@ -75,6 +75,10 @@ class PredictorArgument:
             "help": "the decoding strategy of generation, which should be one of ['sampling', 'greedy_search', 'beam_search']. Default to sampling"
         },
     )
+    use_flash_attention: bool = field(
+        default=False,
+        metadata={"help": "Whether to use flash attention"},
+    )

     mode: str = field(
         default="dynamic", metadata={"help": "the type of predictor, it should be one of [dynamic, static]"}
@@ -241,6 +245,7 @@ def __init__(
         if self.model is None:
             self.model = AutoModelForCausalLM.from_pretrained(
                 config.model_name_or_path,
+                use_flash_attention=config.use_flash_attention,
                 dtype=dtype,
                 tensor_parallel_degree=self.tensor_parallel_degree,
                 tensor_parallel_rank=self.tensor_parallel_rank,
@@ -685,7 +690,9 @@ def create_predictor(
     tensor_parallel_degree: int = 1,
     tensor_parallel_rank: int = 0,
 ):
-    tokenizer = AutoTokenizer.from_pretrained(predictor_args.model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(
+        predictor_args.model_name_or_path,
+    )
     # init chat_template for tokenizer
     init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template)

@@ -727,6 +734,7 @@ def create_predictor(
         model = AutoModelForCausalLM.from_pretrained(
             predictor_args.model_name_or_path,
             dtype=predictor_args.dtype,
+            use_flash_attention=predictor_args.use_flash_attention,
             tensor_parallel_degree=tensor_parallel_degree,
             tensor_parallel_rank=tensor_parallel_rank,
         )
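
Taken together, the predictor changes just thread one new boolean from PredictorArgument through to from_pretrained. A minimal sketch of that path, assuming a hypothetical checkpoint name and dtype (the use_flash_attention kwarg itself comes straight from the diff above):

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint and dtype, not taken from the diff; the point is only
# that the flag is forwarded as a from_pretrained kwarg and ends up on the
# model config as config.use_flash_attention.
model_name = "bigscience/bloom-560m"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype="float16",
    use_flash_attention=True,  # new PredictorArgument field, default False
)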

paddlenlp/data/data_collator.py

Lines changed: 0 additions & 1 deletion
@@ -405,7 +405,6 @@ def __call__(self, features, return_tensors=None):
             return_tensors=return_tensors,
             return_attention_mask=self.return_attention_mask,
         )
-
         # prepare decoder_input_ids
         if (
             labels is not None

paddlenlp/peft/prefix/utils.py

Lines changed: 2 additions & 2 deletions
@@ -17,10 +17,10 @@

 def bloom_postprocess_past_key_value(past_key_values):
     # (layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim)*2
-    past_key_values = paddle.transpose(past_key_values, perm=[2, 0, 3, 1, 4]).split(2)
+    keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
     # keys: [layer_num, bs, head_num/tensor_parallel_degree, head_dim, prefixlen]
     # value: [layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim]
-    keys, values = past_key_values[0].transpose([0, 1, 2, 4, 3]), past_key_values[1]
+    # keys, values = past_key_values[0].transpose([0, 1, 2, 4, 3]), past_key_values[1]
     return tuple(zip(keys, values))
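
The perm change is easiest to see as a shape check. A small sketch, assuming (this is an assumption, not stated in the diff) that the prefix encoder hands this helper a tensor laid out as (batch_size, prefix_len, 2 * num_layers, num_heads, head_dim):

import paddle

# Toy shapes for the assumed prefix-encoder layout.
bs, prefix_len, num_layers, num_heads, head_dim = 2, 8, 4, 3, 16
past_key_values = paddle.randn([bs, prefix_len, 2 * num_layers, num_heads, head_dim])

# perm=[2, 0, 1, 3, 4] moves the layer axis to the front and keeps the sequence
# axis right after batch, so each split half is already ordered as
# [layer, batch, kv_length, num_heads, head_dim], matching the flash-attention
# Bloom cache that is concatenated on axis=1 in modeling.py.
keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
print(keys.shape)    # [4, 2, 8, 3, 16]
print(values.shape)  # [4, 2, 8, 3, 16]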

paddlenlp/transformers/bloom/modeling.py

Lines changed: 18 additions & 33 deletions
@@ -378,9 +378,23 @@ def forward(
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

         batch_size, q_length, _, _ = query_layer.shape
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            # concatenate along seq_length dimension:
+            # - key: [batch_size, kv_length, self.num_heads, head_dim]
+            # - value: [batch_size, kv_length, self.num_heads, head_dim]
+            key_layer = paddle.concat((past_key, key_layer), axis=1)
+            value_layer = paddle.concat((past_value, value_layer), axis=1)
+
+        if use_cache is True:
+            present = (key_layer, value_layer)
+        else:
+            present = None
+
         version = paddle.version.full_version
         version_check = True
-        if version != "0.0.0" and version <= "2.5.2":
+        if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
             logger.warning(
                 "PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
             )
@@ -397,46 +411,19 @@ def forward(
                 key_states,
                 value_states,
                 attn_mask=attention_mask,
+                dropout_p=self.config.attention_dropout,
+                training=self.training,
                 is_causal=False,
             )
             attn_weights = None
             # [batch_size, seq_len, num_heads, head_dim] = > [batch_size, seq_len, hidden_size]
             attn_output = attn_output.reshape([attn_output.shape[0], attn_output.shape[1], -1])
             output_tensor = self.dense(attn_output)

-            query_layer = query_layer.transpose([0, 2, 1, 3])
-            key_layer = key_layer.transpose([0, 2, 3, 1])
-            value_layer = value_layer.transpose([0, 2, 1, 3])
-            if layer_past is not None:
-                past_key, past_value = layer_past
-                # concatenate along seq_length dimension:
-                # - key: [batch_size, self.num_heads, head_dim, kv_length]
-                # - value: [batch_size, self.num_heads, kv_length, head_dim]
-                key_layer = paddle.concat((past_key, key_layer), axis=3)
-                value_layer = paddle.concat((past_value, value_layer), axis=2)
-
-            if use_cache:
-                present = (key_layer, value_layer)
-            else:
-                present = None
         else:
-
             query_layer = query_layer.transpose([0, 2, 1, 3])
             key_layer = key_layer.transpose([0, 2, 3, 1])
             value_layer = value_layer.transpose([0, 2, 1, 3])
-            if layer_past is not None:
-                past_key, past_value = layer_past
-                # concatenate along seq_length dimension:
-                # - key: [batch_size, self.num_heads, head_dim, kv_length]
-                # - value: [batch_size, self.num_heads, kv_length, head_dim]
-                key_layer = paddle.concat((past_key, key_layer), axis=3)
-                value_layer = paddle.concat((past_value, value_layer), axis=2)
-
-            if use_cache is True:
-                present = (key_layer, value_layer)
-            else:
-                present = None
-
             _, _, _, kv_length = key_layer.shape

             query_layer = query_layer.reshape([batch_size * self.num_heads, q_length, self.head_dim])
@@ -449,7 +436,6 @@ def forward(
             attention_scores = baddbmm(
                 alibi, batch1=query_layer, batch2=key_layer, beta=self.beta, alpha=self.inv_norm_factor
             )
-
             # change view to [batch_size, num_heads, q_length, kv_length]
             # attention_scores = matmul_result.reshape([batch_size, self.num_heads, q_length, kv_length])

@@ -949,14 +935,13 @@ def forward(
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[3]
+            past_key_values_length = past_key_values[0][0].shape[1]
             seq_length_with_past = seq_length_with_past + past_key_values_length

         if attention_mask is None:
             attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
         elif attention_mask.dtype != paddle.bool:
             attention_mask = paddle.cast(attention_mask, "bool")
-
         if len(attention_mask.shape) > 2:
             _attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
             alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
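
The core of the Bloom change is the cache layout: on the flash-attention path, key/value stay in [batch, kv_length, num_heads, head_dim], so the cache grows along axis=1 before attention is computed, and the cached length is read from shape[1] instead of shape[3]. A rough, GPU-only sketch with toy shapes (the alibi and mask handling of the real model are omitted; scaled_dot_product_attention needs a PaddlePaddle >= 2.5.3 GPU build with flash attention support):

import paddle
import paddle.nn.functional as F

# Toy shapes only; this is a sketch of the cache layout, not the model code.
bs, num_heads, head_dim, past_len, q_len = 2, 4, 32, 6, 1
past_key = paddle.randn([bs, past_len, num_heads, head_dim], dtype="float16")
past_value = paddle.randn([bs, past_len, num_heads, head_dim], dtype="float16")
query = paddle.randn([bs, q_len, num_heads, head_dim], dtype="float16")
new_key = paddle.randn([bs, q_len, num_heads, head_dim], dtype="float16")
new_value = paddle.randn([bs, q_len, num_heads, head_dim], dtype="float16")

# Concatenate the cache along the sequence axis (axis=1) in the flash layout.
key = paddle.concat([past_key, new_key], axis=1)
value = paddle.concat([past_value, new_value], axis=1)
past_key_values_length = key.shape[1] - q_len  # 6, read from shape[1] as in the diff

# is_causal=False because the model passes its own mask via attn_mask upstream.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=False)
print(out.shape)  # [2, 1, 4, 32]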

paddlenlp/transformers/chatglm/modeling.py

Lines changed: 11 additions & 19 deletions
@@ -245,25 +245,26 @@ def forward(
         q_layer, k_layer, v_layer = paddle.split(mixed_layer, 3, axis=-1)
         # [s, b, n, h/n]
         q_layer, k_layer = self._core_attention(q_layer, k_layer, position_ids, rotary_embeds)
+
+        if cache is not None:
+            cache_k, cache_v = cache[0], cache[1]
+            # [s + c, b, n, h/n]
+            k_layer = paddle.concat([cache_k, k_layer], axis=0)
+            v_layer = paddle.concat([cache_v, v_layer], axis=0)
+
+        cache_kv = None
+        if use_cache:
+            cache_kv = (k_layer, v_layer)
         version = paddle.version.full_version
         version_check = True
-        if version != "0.0.0" and version <= "2.5.2":
+        if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
             logger.warning(
                 "PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
             )
             version_check = False
         if self.config.use_flash_attention and version_check:
             # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
             # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
-            if cache is not None:
-                cache_k, cache_v = cache[0], cache[1]
-                # [s + c, b, n, h/n]
-                k_layer = paddle.concat([cache_k, k_layer], axis=0)
-                v_layer = paddle.concat([cache_v, v_layer], axis=0)
-            cache_kv = None
-            if use_cache:
-                cache_kv = (k_layer, v_layer)
-
             # [s, b, n, h/n] = > [batch_size, seq_len, num_heads, head_dim]
             q_layer = paddle.transpose(q_layer, [1, 0, 2, 3])
             k_layer = paddle.transpose(k_layer, [1, 0, 2, 3])
@@ -286,18 +287,9 @@ def forward(

             output, attention_probs = attn_output, attn_weights
         else:
-            if cache is not None:
-                cache_k, cache_v = cache[0], cache[1]
-                # [s + c, b, n, h/n]
-                k_layer = paddle.concat([cache_k, k_layer], axis=0)
-                v_layer = paddle.concat([cache_v, v_layer], axis=0)

             seq_length, batch_size, num_heads, hidden_size = k_layer.shape

-            cache_kv = None
-            if use_cache:
-                cache_kv = (k_layer, v_layer)
-
             attention_scale_coeff = float(layer_id) + 1.0
             if self.attention_scale:
                 # [s, b, n, h/n]
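
In ChatGLM the cache concatenation is hoisted above the branch because activations are kept as [seq_len, batch, num_heads, head_dim]: the cache grows along axis=0, and only the flash-attention branch afterwards transposes to the batch-first layout Paddle flash attention expects. A toy-shape sketch of that ordering (shapes assumed for illustration):

import paddle

# s = current step length, c = cached length, b = batch, n = heads, h = head_dim.
s, c, b, n, h = 3, 5, 2, 4, 16
k_layer = paddle.randn([s, b, n, h])
cache_k = paddle.randn([c, b, n, h])

k_layer = paddle.concat([cache_k, k_layer], axis=0)    # [s + c, b, n, h/n]
k_for_flash = paddle.transpose(k_layer, [1, 0, 2, 3])  # [b, s + c, n, h/n]
print(k_layer.shape, k_for_flash.shape)  # [8, 2, 4, 16] [2, 8, 4, 16]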

tests/fixtures/llm/finetune.yaml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ finetune:
   fp16_opt_level: "O2"
   do_train: true
   do_eval: true
+  use_flash_attention: true
   disable_tqdm: true
   load_best_model_at_end: true
   eval_with_do_generation: false

tests/fixtures/llm/predictor.yaml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ inference-predict:
   mode: dynamic
   max_length: 40
   batch_size: 2
+  use_flash_attention: false
   decode_strategy: greedy_search
   dtype: float16
   data_file: tests/fixtures/llm/data/train.json

tests/fixtures/llm/prefix_tuning.yaml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ prefix_tuning:
   do_train: true
   do_eval: true
   disable_tqdm: true
+  use_flash_attention: false
   load_best_model_at_end: true
   eval_with_do_generation: false
   metric_for_best_model: "accuracy"

tests/llm/test_predictor.py

Lines changed: 24 additions & 0 deletions
@@ -77,6 +77,30 @@ def test_predictor(self):
         else:
             self.assertGreaterEqual(count / len(result_0), 0.4)

+    def test_flash_attention(self):
+        self.run_predictor({"inference_model": False, "use_flash_attention": False})
+        result_0 = self._read_result(os.path.join(self.output_dir, "predict.json"))
+
+        self.run_predictor({"inference_model": False, "use_flash_attention": True})
+        result_1 = self._read_result(os.path.join(self.output_dir, "predict.json"))
+
+        # compare the generation result of dygraph & flash attention model
+        assert len(result_0) == len(result_1)
+
+        count, full_match = 0, 0
+        for inference_item, no_inference_item in zip(result_0, result_1):
+            if self.model_name_or_path == "__internal_testing__/tiny-random-llama":
+                min_length = 5
+            else:
+                min_length = min(len(inference_item), len(no_inference_item))
+            count += int(inference_item[: min_length // 2] == no_inference_item[: min_length // 2])
+            full_match += int(inference_item[:min_length] == no_inference_item[:min_length])
+
+        if self.model_name_or_path == "__internal_testing__/tiny-random-llama":
+            self.assertGreaterEqual(count / len(result_0), 0.2)
+        else:
+            self.assertEqual(full_match / len(result_0), 1.0)
+
     def test_wint8(self):
         self.run_predictor({"inference_model": True, "quant_type": "weight_only_int8"})
         result_0 = self._read_result(os.path.join(self.output_dir, "predict.json"))
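
The new test runs the same dynamic-graph predictor with and without the flag and checks that the generations match on a prefix of each output. A plausible way to exercise just this case locally (the runner invocation is assumed, not part of the diff; only the test path comes from it):

# python -m pytest tests/llm/test_predictor.py -k test_flash_attention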

tests/llm/test_prefix_tuning.py

Lines changed: 0 additions & 1 deletion
@@ -55,7 +55,6 @@ def test_prefix_tuning(self):

         prefix_tuning_config["dataset_name_or_path"] = self.data_dir
         prefix_tuning_config["output_dir"] = self.output_dir
-
         with argv_context_guard(prefix_tuning_config):
             from finetune_generation import main

