
Commit 6a39d63

cann: update the FlashAttention with PSEShift
1 parent: 3a73182
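For context, the PSE shift (pse_shift) consumed by Ascend's incremental FlashAttention is an additive bias applied to the scaled Q·Kᵀ attention scores before the softmax, which is how an attention mask can be handed to the fused kernel. A minimal CPU reference sketch under that assumption (the function name and shapes are illustrative, not from this commit):

import torch

def attention_with_pse_shift(q, k, v, pse_shift, scale):
    # q: (B, N, Sq, D); k, v: (B, N, Skv, D); pse_shift broadcasts to (B, N, Sq, Skv)
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # scaled Q.K^T
    scores = scores + pse_shift                            # additive PSE bias
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(probs, v)                          # (B, N, Sq, D)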

File tree

9 files changed: 609 additions & 50 deletions

ggml/src/ggml-cann/CMakeLists.txt

File mode changed: 100644 → 100755

ggml/src/ggml-cann/Doxyfile

File mode changed: 100644 → 100755

ggml/src/ggml-cann/acl_tensor.cpp

File mode changed: 100644 → 100755
Lines changed: 2 additions & 0 deletions

@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
             return ACL_FLOAT;
         case GGML_TYPE_F16:
             return ACL_FLOAT16;
+        case GGML_TYPE_BF16:
+            return ACL_BF16;
         case GGML_TYPE_I8:
             return ACL_INT8;
         case GGML_TYPE_I16:
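The new GGML_TYPE_BF16 → ACL_BF16 case gives bf16 ggml tensors a direct ACL dtype. The distinction matters even though bf16 and fp16 are both 16-bit: bf16 keeps fp32's exponent range at the cost of mantissa precision. A standalone PyTorch illustration (not part of the commit):

import torch

x = torch.tensor([70000.0])
print(x.to(torch.float16))              # inf: fp16 overflows above 65504
print(x.to(torch.bfloat16))             # 70144.0: bf16 keeps fp32's exponent range
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e+38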

ggml/src/ggml-cann/acl_tensor.h

File mode changed: 100644 → 100755

ggml/src/ggml-cann/aclnn_ops.cpp

File mode changed: 100644 → 100755
Lines changed: 564 additions & 50 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cann/aclnn_ops.h

File mode changed: 100644 → 100755

ggml/src/ggml-cann/common.h

File mode changed: 100644 → 100755

ggml/src/ggml-cann/ggml-cann.cpp

File mode changed: 100644 → 100755

ggml/src/ggml-cann/ifa.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# Single-operator invocation (standalone call to the IFA kernel)
import torch
import torch_npu
import math

def load_float_array_to_tensor(file_path, shape, dtype):
    with open(file_path, 'r') as file:
        # Read the file contents and split on whitespace
        data = file.read().strip().split()
        # Convert the strings to floats
        float_array = [float(num) for num in data]
        # Build a PyTorch tensor and move it to the NPU
        tensor = torch.tensor(float_array, dtype=dtype).reshape(shape).npu()
        return tensor

batch = 1
nhead_q = 4
nhead_kv = nhead_q
seq_q = 1
dims = 64
seq_kv = 512
layout = "BNSD"

scale_value = 1 / math.sqrt(dims)

q_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_0_q.txt",
                                      (batch, nhead_q, seq_q, dims), torch.float16)
k_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_3_k.txt",
                                      (batch, nhead_kv, seq_kv, dims), torch.float16)
v_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_4_v.txt",
                                      (batch, nhead_kv, seq_kv, dims), torch.float16)
pse_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_1_mask.txt",
                                        (1, 1, -1, seq_kv), torch.float16)

print(q_tensor.shape, k_tensor.shape, v_tensor.shape, pse_tensor.shape)

# Call the IFA (incremental FlashAttention) operator
out = torch_npu.npu_incre_flash_attention(q_tensor, k_tensor, v_tensor, pse_shift=pse_tensor,
                                          num_heads=nhead_q, num_key_value_heads=nhead_kv,
                                          input_layout=layout, scale_value=scale_value)
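Since ifa.py replays tensors dumped by the llama.cpp CANN backend through torch_npu's IFA kernel, a natural last step is to diff its result against the backend's own output. A hypothetical check (the reference dump file name is an assumption, not from this commit):

# Hypothetical: compare against a dumped llama.cpp CANN result;
# the reference file name below is assumed, not taken from the commit.
ref_tensor = load_float_array_to_tensor("/data/home/2101111451/pr/llama.cpp/output_acl_short_5_out.txt",
                                        (batch, nhead_q, seq_q, dims), torch.float16)
print(torch.allclose(out.cpu().float(), ref_tensor.cpu().float(), atol=1e-3))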
