
Commit 2081b3f

ikawrakow and Iwan Kawrakow authored
Vulkan: a fresh start (#608)
* It compiles
* Seems to be working with coopmat
* Vulkan needs f32 precision for flash attention
* Vulkan: fix u_batch > 4096/n_active_experts for coopmat1. Without this fix we get an assert. We get the same assert in mainline too.

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 45fae1a commit 2081b3f

35 files changed: +2259 -513 lines

ggml/include/ggml-vulkan.h

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ extern "C" {
 #define GGML_VK_NAME "Vulkan"
 #define GGML_VK_MAX_DEVICES 16
 
-GGML_API GGML_CALL void ggml_vk_instance_init(void);
+//GGML_API GGML_CALL void ggml_vk_instance_init(void);
 
 // backend API
 GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);

ggml/include/ggml.h

Lines changed: 9 additions & 0 deletions
@@ -884,6 +884,15 @@ extern "C" {
     GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
     GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
 
+    // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
+    GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor);
+
+    // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
+    GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
+
+    // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+    GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
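
The three new predicates are finer-grained than plain ggml_is_contiguous(). A minimal sketch of what they report, assuming ggml.h from this tree and a built libggml; the tensor shapes are illustrative only and not taken from this commit:

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // a is allocated channels-first: ne = {C=8, W=64, H=32}
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 8, 64, 32);
    // p is a permuted view of a that presents the same data as W x H x C
    struct ggml_tensor * p = ggml_permute(ctx, a, 2, 0, 1, 3);

    // a: fully contiguous, so every predicate holds
    printf("a: contiguous=%d alloc=%d rows=%d\n",
           ggml_is_contiguous(a), ggml_is_contiguously_allocated(a), ggml_is_contiguous_rows(a));

    // p: not contiguous (it is a permutation), but its elements still occupy one
    // gap-free block, and it matches the CxWxHxN-permuted-to-WxHxCxN pattern
    printf("p: contiguous=%d alloc=%d channels=%d\n",
           ggml_is_contiguous(p), ggml_is_contiguously_allocated(p), ggml_is_contiguous_channels(p));

    ggml_free(ctx);
    return 0;
}
```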

ggml/src/ggml-impl.h

Lines changed: 10 additions & 0 deletions
@@ -748,6 +748,16 @@ static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct g
     GGML_ABORT("fatal error");
 }
 
+static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+    return ((const int32_t *)(tensor->op_params))[i];
+}
+
+static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
+    return ((const float *)(tensor->op_params))[i];
+}
+
 #ifdef __cplusplus
 }
 #endif
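
Moving these getters into the internal header lets backend code read a node's operator parameters without re-declaring the helpers; the main consumer here is the rewritten ggml-vulkan.cpp. A hedged sketch of the read side, assuming a graph node produced by ggml_scale(), which keeps its factor as the first f32 op param (the function below is my own illustration, not code from this commit):

```c
// Sketch only: requires ggml.h / ggml-impl.h from this tree.
#include "ggml.h"
#include "ggml-impl.h"

static void handle_scale_node(const struct ggml_tensor * node) {
    if (node->op != GGML_OP_SCALE) {
        return;
    }
    // ggml_scale() stashes its factor in op_params; read it back as a float.
    const float scale = ggml_get_op_params_f32(node, 0);
    // ... a backend would now pass `scale` on as a push constant / kernel argument ...
    (void) scale;
}
```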

ggml/src/ggml-vulkan.cpp

Lines changed: 1303 additions & 334 deletions
Large diffs are not rendered by default.

ggml/src/ggml.c

Lines changed: 18 additions & 10 deletions
@@ -4682,6 +4682,24 @@ GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
 }
 
+GGML_CALL bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
+    return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+}
+
+GGML_CALL bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
+    return
+        tensor->nb[0] > tensor->nb[2] &&
+        tensor->nb[1] > tensor->nb[0] &&
+        tensor->nb[2] == ggml_type_size(tensor->type);
+}
+
+GGML_CALL bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
+    return
+        tensor->ne[0] == ggml_blck_size(tensor->type) ||
+        tensor->nb[0] == ggml_type_size(tensor->type);
+}
+
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
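
The interesting case for ggml_is_contiguously_allocated() is quantized data, where a row is measured in blocks rather than elements. A worked example of the right-hand side of that comparison, using Q4_0-style block parameters (32 elements and 18 bytes per block, assumed here for illustration):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t nelements = 4096 * 32;   // e.g. one 4096 x 32 matrix
    const int64_t blck_size = 32;          // elements per quant block
    const int64_t type_size = 18;          // bytes per quant block

    // 131072 elements -> 4096 blocks -> 73728 bytes when packed back to back
    const int64_t packed = nelements * type_size / blck_size;
    printf("densely packed size = %lld bytes\n", (long long) packed);

    // ggml_is_contiguously_allocated() simply compares ggml_nbytes(tensor) to this
    // value: a permuted view of a dense buffer still passes, while a sliced view
    // with gaps between rows does not.
    return 0;
}
```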

@@ -5195,16 +5213,6 @@ static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params,
     memcpy(tensor->op_params, params, params_size);
 }
 
-static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
-    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
-    return ((const int32_t *)(tensor->op_params))[i];
-}
-
-static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
-    assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
-    return ((const float *)(tensor->op_params))[i];
-}
-
 static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
     assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
     ((int32_t *)(tensor->op_params))[i] = value;
ggml/src/vulkan-shaders/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -27,5 +27,5 @@ endif()
 set(TARGET vulkan-shaders-gen)
 add_executable(${TARGET} vulkan-shaders-gen.cpp)
 install(TARGETS ${TARGET} RUNTIME)
-target_compile_features(${TARGET} PRIVATE cxx_std_11)
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
 target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
#version 450

#include "types.comp"

layout (push_constant) uniform parameter
{
    uint ne;
    uint batches;
    uint channels;
    uint dst_w;
    uint dst_h;
    uint src_w;
    uint src_h;
    uint knl_w;
    uint knl_h;
    int stride_x;
    int stride_y;
    int pad_x;
    int pad_y;
    int dilation_x;
    int dilation_y;
} p;

layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
    uint i0 = idx / p.dst_w;
    uint dst_x = idx - i0 * p.dst_w;
    uint i1 = i0 / p.dst_h;
    uint dst_y = i0 - i1 * p.dst_h;
    uint n = i1 / p.channels;
    uint c = i1 - n * p.channels;

    uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
    uint knl_i = c * p.knl_h * p.knl_w;

    FLOAT_TYPE sum = 0.0;
    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
            continue;
        }
        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
                continue;
            }
            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
            sum = fma(v, k, sum);
        }
    }
    return sum;
}

FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
    uint i0 = idx / p.channels;
    uint c = idx - i0 * p.channels;
    uint i1 = i0 / p.dst_w;
    uint dst_x = i0 - i1 * p.dst_w;
    uint n = i1 / p.dst_h;
    uint dst_y = i1 - n * p.dst_h;

    uint src_i = n * p.channels * p.src_h * p.src_w;
    uint src_row = p.src_w * p.channels;
    uint knl_row = p.knl_w * p.channels;

    FLOAT_TYPE sum = 0.0;
    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
            continue;
        }
        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
                continue;
            }
            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
            FLOAT_TYPE k = FLOAT_TYPE(knl_data[         knl_y * knl_row + knl_x * p.channels + c]);
            sum = fma(v, k, sum);
        }
    }
    return sum;
}

void main() {
    uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
    if (idx >= p.ne) {
        return;
    }

    FLOAT_TYPE result =
#ifdef WHCN
        conv_2d_dw_whcn(idx);
#else
        conv_2d_dw_cwhn(idx);
#endif
    dst_data[idx] = D_TYPE(result);
}
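
Each invocation of this new shader computes one output element: the flat index is peeled into x, y, channel and batch (WHCN, x fastest), and out-of-range taps are skipped by letting the unsigned subtraction wrap. A plain-C reference of the same WHCN arithmetic, useful for checking the indexing; float data and the function name are my own choices, not part of the backend:

```c
#include <stddef.h>

// CPU reference mirroring conv_2d_dw_whcn(): depthwise 2D convolution, one kernel per channel.
void conv_2d_dw_whcn_ref(float *dst, const float *src, const float *knl,
                         int batches, int channels,
                         int dst_w, int dst_h, int src_w, int src_h,
                         int knl_w, int knl_h,
                         int stride_x, int stride_y, int pad_x, int pad_y,
                         int dilation_x, int dilation_y) {
    for (int n = 0; n < batches; ++n)
    for (int c = 0; c < channels; ++c)
    for (int dst_y = 0; dst_y < dst_h; ++dst_y)
    for (int dst_x = 0; dst_x < dst_w; ++dst_x) {
        const float *src_c = src + ((size_t)n*channels + c)*src_h*src_w;  // this channel's input plane
        const float *knl_c = knl + (size_t)c*knl_h*knl_w;                 // this channel's kernel
        float sum = 0.0f;
        for (int ky = 0; ky < knl_h; ++ky) {
            const int sy = dst_y*stride_y + ky*dilation_y - pad_y;
            if (sy < 0 || sy >= src_h) continue;   // the shader relies on unsigned wrap-around instead
            for (int kx = 0; kx < knl_w; ++kx) {
                const int sx = dst_x*stride_x + kx*dilation_x - pad_x;
                if (sx < 0 || sx >= src_w) continue;
                sum += src_c[sy*src_w + sx] * knl_c[ky*knl_w + kx];
            }
        }
        dst[(((size_t)n*channels + c)*dst_h + dst_y)*dst_w + dst_x] = sum;
    }
}
```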

ggml/src/vulkan-shaders/copy_to_quant.comp

Lines changed: 58 additions & 7 deletions
@@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
 #endif // RTE16
 
 #include "types.comp"
-#include "generic_unary_head.comp"
 
-#if defined(DATA_A_IQ4_NL)
-// 16 invocations needed for init_iq4nl_shmem
-layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
+#if defined(SET_ROWS) && QUANT_K == 1
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+const uint BLOCK_SIZE = 512;
 #else
-layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+const uint BLOCK_SIZE = 32;
 #endif
 
 layout (binding = 0) readonly buffer S {float data_s[];};
+
+#if defined(SET_ROWS)
+#include "generic_binary_head.comp"
+layout (binding = 1) readonly buffer C {uvec2 data_i[];};
+layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
+#else
+#include "generic_unary_head.comp"
 layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
+#endif
 
 #if defined(DATA_A_Q4_0)
 void quantize(uint dst_idx, uint src_idx)

@@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx)
 }
 #endif
 
+#if defined(DATA_A_F32) || defined(DATA_A_F16)
+void quantize(uint dst_idx, uint src_idx)
+{
+    data_q[dst_idx] = A_TYPE(data_s[src_idx]);
+}
+#endif
+
+#if defined(DATA_A_BF16)
+void quantize(uint dst_idx, uint src_idx)
+{
+    data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
+}
+#endif
+
+#if defined(SET_ROWS)
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
-    if (gl_LocalInvocationIndex.x != 0) {
+#endif
+
+    const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
+
+    if (idx >= p.ne) {
         return;
     }
+
+    uint i00, i01, i02, i03;
+    get_indices(idx, i00, i01, i02, i03);
+
+    uint i12 = fastmod(i03, p.ne12);
+    uint i11 = fastmod(i02, p.ne11);
+    uint i10 = i01;
+
+    uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
+
+    uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
+    uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
+
+    quantize(dst_idx, src0_idx);
+}
+
+#else
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+    init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-    const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
+    const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
 
     if (idx >= p.ne) {
         return;

@@ -240,3 +289,5 @@ void main() {
 
     quantize(dst_idx, src_idx);
 }
+
+#endif
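
With SET_ROWS defined, this shader becomes a fused "set rows + quantize": each invocation takes one QUANT_K-sized block of a source row, looks up the destination row in the index buffer data_i, and writes the quantized block there. A simplified CPU sketch of the row-scatter part (float to float only, no quantization; the names are illustrative, not the backend's):

```c
#include <stdint.h>
#include <string.h>

// Copy row i of src into row row_idx[i] of dst (the shader additionally quantizes
// each QUANT_K-sized block while writing).
void set_rows_f32(float *dst, const float *src, const int64_t *row_idx,
                  int64_t n_rows, int64_t row_len) {
    for (int64_t i = 0; i < n_rows; ++i) {
        memcpy(dst + row_idx[i]*row_len, src + i*row_len, row_len*sizeof(float));
    }
}
```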

ggml/src/vulkan-shaders/flash_attn.comp

Lines changed: 6 additions & 2 deletions
@@ -100,6 +100,10 @@ void main() {
         uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
         uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1 || p.nem3 != 1) {
+        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+    }
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {

@@ -145,13 +149,13 @@
             }
         }
 
-        if (p.mask != 0) {
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
 
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
                 if (idx + tid < Bc * Br) {
-                    masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+                    masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                }
             }
             barrier();
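
The new m_offset selects which mask slice a given query head and sequence read when the mask is not a single broadcast slice: iq2 and iq3 are reduced modulo the mask's extents nem2 and nem3, and each slice is nem1*KV elements. The same arithmetic in plain C (parameter names follow the shader; the values in main are made up):

```c
#include <stdint.h>
#include <stdio.h>

// Element offset of the mask slice used by query head iq2 / sequence iq3,
// mirroring the shader's m_offset computation.
static uint32_t mask_offset(uint32_t iq2, uint32_t iq3,
                            uint32_t nem1, uint32_t nem2, uint32_t nem3,
                            uint32_t KV) {
    if (nem2 == 1 && nem3 == 1) {
        return 0;   // one mask slice shared by all heads and sequences
    }
    return ((iq3 % nem3) * nem2 + (iq2 % nem2)) * nem1 * KV;
}

int main(void) {
    // e.g. 8 heads sharing 2 mask slices of nem1 = 32 rows by KV = 256 columns
    for (uint32_t iq2 = 0; iq2 < 8; ++iq2) {
        printf("head %u -> mask offset %u\n", (unsigned) iq2,
               (unsigned) mask_offset(iq2, 0, 32, 2, 1, 256));
    }
    return 0;
}
```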

ggml/src/vulkan-shaders/flash_attn_base.comp

Lines changed: 10 additions & 5 deletions
@@ -24,6 +24,8 @@ layout (push_constant) uniform parameter {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
+    uint32_t nem3;
 
     uint32_t nb01;
     uint32_t nb02;

@@ -34,14 +36,12 @@ layout (push_constant) uniform parameter {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
     float max_bias;
     float logit_softcap;
 
-    uint32_t mask;
-    uint32_t n_head_log2;
+    uint32_t mask_n_head_log2;
     float m0;
     float m1;

@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
     uint32_t k_num;
 } p;
 
+#define MASK_ENABLE_BIT (1<<16)
+#define N_LOG2_MASK 0xFFFF
+
 layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
 
 #if defined(A_TYPE_PACKED16)

@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
 {
     const uint32_t h = iq2 + (r % p.gqa_ratio);
 
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+    uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
+
+    const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
+    const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
 
     return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }
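
Folding the mask flag and n_head_log2 into a single push-constant word drops two 32-bit fields (mask and n_head_log2, plus the now-unused nb31) while keeping both values readable: the low 16 bits hold n_head_log2 and bit 16 is the mask-enable flag. A host-side sketch of the packing and the two unpack operations the shader performs (the packing helper is my naming; the real packing happens in ggml-vulkan.cpp, which is not rendered in this diff):

```c
#include <stdint.h>
#include <stdio.h>

#define MASK_ENABLE_BIT (1u << 16)
#define N_LOG2_MASK     0xFFFFu

// Pack n_head_log2 (low 16 bits) and the mask-enable flag (bit 16) into one word.
static uint32_t pack_mask_n_head_log2(uint32_t n_head_log2, int mask_enabled) {
    return (n_head_log2 & N_LOG2_MASK) | (mask_enabled ? MASK_ENABLE_BIT : 0u);
}

int main(void) {
    const uint32_t packed = pack_mask_n_head_log2(/*n_head_log2=*/32, /*mask_enabled=*/1);

    const uint32_t n_head_log2 = packed & N_LOG2_MASK;             // read by perElemOpComputeSlope
    const int      has_mask    = (packed & MASK_ENABLE_BIT) != 0;  // tested before loading the mask tile

    printf("packed=0x%08x n_head_log2=%u has_mask=%d\n",
           (unsigned) packed, (unsigned) n_head_log2, has_mask);
    return 0;
}
```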
