Skip to content

Commit 82d745e

Browse files
committed
vulkan: Fix validation failure in quantized flash attention
1 parent 3f81b4e commit 82d745e

File tree

1 file changed

+29
-11
lines changed

1 file changed

+29
-11
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,30 +65,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
6565
#if defined(A_TYPE_PACKED16)
6666
#define BINDING_IDX_K 0
6767
#define BINDING_IDX_V 1
68-
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
68+
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
69+
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
6970
#endif
7071

7172
#if defined(DATA_A_Q4_0)
7273
#define BLOCK_BYTE_SIZE 18
7374

7475
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
75-
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
76-
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
77-
uint shift = (iqs & 0x10) >> 2;
78-
vui_lo >>= shift;
79-
vui_hi >>= shift;
80-
81-
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
76+
if (binding_idx == BINDING_IDX_K) {
77+
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
78+
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
79+
uint shift = (iqs & 0x10) >> 2;
80+
vui_lo >>= shift;
81+
vui_hi >>= shift;
82+
83+
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
84+
} else {
85+
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
86+
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
87+
uint shift = (iqs & 0x10) >> 2;
88+
vui_lo >>= shift;
89+
vui_hi >>= shift;
90+
91+
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
92+
}
8293
}
8394
#endif
8495

8596
#if defined(DATA_A_Q8_0)
8697
#define BLOCK_BYTE_SIZE 34
8798
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
88-
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
89-
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
99+
if (binding_idx == BINDING_IDX_K) {
100+
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
101+
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
90102

91-
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
103+
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
104+
} else {
105+
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
106+
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
107+
108+
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
109+
}
92110
}
93111
#endif
94112

0 commit comments

Comments
 (0)