
Commit b915fc0

Merge branch 'concedo_experimental' into crokeso
2 parents: 2470e7e + 37c7f7d

File tree

11 files changed: +610 −570 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 95 additions & 30 deletions
@@ -517,6 +517,8 @@ struct vk_device_struct {
 
     ggml_backend_buffer_type buffer_type;
 
+    bool disable_fusion;
+
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
 #endif
@@ -652,6 +654,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev3;
     uint32_t nem1;
     uint32_t nem2;
+    uint32_t nem3;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -667,8 +670,7 @@ struct vk_flash_attn_push_constants {
     float max_bias;
     float logit_softcap;
 
-    uint32_t mask;
-    uint32_t n_head_log2;
+    uint32_t mask_n_head_log2;
     float m0;
     float m1;
 
@@ -1107,8 +1109,8 @@ static size_t vk_skip_checks;
 static size_t vk_output_tensor;
 
 static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_tensor * tensor);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
 #endif
 
 typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -3531,6 +3533,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
     device->idx = idx;
 
+    device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
+
     return device;
 }
 
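The fused RMS_NORM+MUL path introduced below can be switched off at runtime for debugging or A/B comparison. A minimal sketch of the toggle's semantics (standalone, not the backend code itself); only the variable's presence is checked, so even GGML_VK_DISABLE_FUSION=0 disables fusion:

    #include <cstdlib>

    // Mirrors device->disable_fusion above: presence of the variable wins.
    static bool vk_fusion_enabled() {
        return std::getenv("GGML_VK_DISABLE_FUSION") == nullptr;
    }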
@@ -6135,6 +6139,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
+    const uint32_t nem3 = mask ? mask->ne[3] : 0;
 
     const uint32_t HSK = nek0;
     const uint32_t HSV = nev0;
@@ -6202,7 +6207,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
-        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
         // and change addressing calculations to index Q's dimension 2.
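The old gate additionally required neq3 == 1, nek3 == 1, and nev3 == 1; the relaxed form instead requires the mask, if present, to broadcast across heads (nem2 <= 1). A standalone sketch of the new condition, with assumed head counts:

    #include <cstdint>

    // Sketch of the relaxed GQA gate; shapes are assumptions for illustration
    // (e.g. 32 Q heads sharing 8 KV heads gives qk_ratio = 4).
    static bool use_gqa_path(uint32_t N, uint32_t neq2, uint32_t nek2,
                             uint32_t nev2, uint32_t nem2, uint32_t max_gqa) {
        const uint32_t qk_ratio = neq2 / nek2;
        return N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
               qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1;
    }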
@@ -6372,17 +6377,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
+    uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
                                               (uint32_t)neq2, (uint32_t)neq3,
                                               (uint32_t)nek2, (uint32_t)nek3,
                                               (uint32_t)nev2, (uint32_t)nev3,
-                                              nem1, nem2,
+                                              nem1, nem2, nem3,
                                               q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
                                               k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
                                               v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
                                               scale, max_bias, logit_softcap,
-                                              mask != nullptr, n_head_log2, m0, m1,
+                                              mask_n_head_log2, m0, m1,
                                               gqa_ratio, split_kv, split_k };
 
     ggml_vk_sync_buffers(subctx);
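The former mask and n_head_log2 push constants now share one 32-bit word: bit 16 carries the mask-enable flag and the low 16 bits carry n_head_log2, matching the MASK_ENABLE_BIT and N_LOG2_MASK constants defined in flash_attn_base.comp. A host-side sketch of the packing and the corresponding unpacking:

    #include <cstdint>

    constexpr uint32_t MASK_ENABLE_BIT = 1u << 16; // same values as the shader
    constexpr uint32_t N_LOG2_MASK     = 0xFFFF;

    static uint32_t pack_mask_n_head_log2(bool has_mask, uint32_t n_head_log2) {
        // n_head_log2 is far below 2^16 in practice, so 16 bits suffice
        return (uint32_t(has_mask) << 16) | (n_head_log2 & N_LOG2_MASK);
    }
    static bool     unpack_has_mask(uint32_t v)    { return (v & MASK_ENABLE_BIT) != 0; }
    static uint32_t unpack_n_head_log2(uint32_t v) { return v & N_LOG2_MASK; }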
@@ -7675,8 +7682,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
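Passing op_params explicitly matters for the fused case: there dst is the trailing MUL node, while the epsilon still lives in the RMS_NORM node's op_params, so the callee can no longer read it from dst. For GGML_OP_RMS_NORM, eps is stored as the first float of op_params; a small helper sketch (the name and memcpy-style read are illustrative, equivalent to the (float *)node->op_params cast at the call sites below):

    #include <cstring>
    #include "ggml.h"

    static float rms_norm_eps(const ggml_tensor * node) { // node: the RMS_NORM op
        float eps;
        memcpy(&eps, node->op_params, sizeof(float));
        return eps;
    }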
@@ -8906,7 +8912,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
+static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
 
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
@@ -9167,9 +9173,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             // fused rms_norm + mul
             ggml_tensor *mul = cgraph->nodes[node_idx + 1];
             ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
         } else {
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
         }
         break;
     case GGML_OP_RMS_NORM_BACK:
@@ -9329,7 +9335,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         ctx->compute_ctx.reset();
 
-        bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
+        bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
         if (!ok) {
             if (node->op == GGML_OP_UNARY) {
                 std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -9344,7 +9350,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     return true;
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+    GGML_UNUSED(cgraph);
     ggml_backend_buffer * buf = nullptr;
 
     switch (tensor->op) {
@@ -9454,7 +9461,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         // Only run if ctx hasn't been submitted yet
         if (!subctx->seqs.empty()) {
 #ifdef GGML_VULKAN_CHECK_RESULTS
-            ggml_vk_check_results_0(tensor);
+            ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
             use_fence = true;
 #endif
 
@@ -9474,7 +9481,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
             ggml_vk_wait_for_fence(ctx);
         }
 #ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_1(tensor);
+        ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
 #endif
     }
 
@@ -9921,6 +9928,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
     return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
 }
 
+static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            mul->src[0]->ne[1] != rms_norm->ne[1]) {
+            return false;
+        }
+        // rms_norm shader assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+    return true;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
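For reference, a minimal graph fragment that satisfies these checks (a sketch only: ctx is a ggml_context, w an F32 weight tensor, n_embd and n_tokens assumed; everything F32 and contiguous):

    // RMS_NORM immediately followed by a MUL that consumes it, i.e. the
    // pattern ggml_vk_can_fuse looks for at node_idx / node_idx + 1.
    ggml_tensor * x    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
    ggml_tensor * norm = ggml_rms_norm(ctx, x, 1e-6f);
    ggml_tensor * out  = ggml_mul(ctx, norm, w);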
@@ -9934,7 +9972,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            ctx->num_additional_fused_ops = 1;
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -10004,7 +10042,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
             ctx->num_additional_fused_ops = 1;
         }
 
@@ -10327,12 +10365,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                 return false;
             }
-            // TODO: support broadcast
-            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
-            // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-            if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
-                return false;
-            }
             // It's straightforward to support different K/V dequant, but would
             // significantly increase the number of pipelines
             if (op->src[1]->type != op->src[2]->type) {
@@ -10787,11 +10819,21 @@ void * comp_result;
 size_t comp_size;
 size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_tensor * tensor) {
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
 
+    bool fused_rms_norm_mul = false;
+    int rms_norm_idx = -1;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     check_counter++;
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
@@ -10819,6 +10861,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 
     for (int i = 0; i < 6; i++) {
         ggml_tensor * srci = tensor->src[i];
+        if (fused_rms_norm_mul) {
+            rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
+            ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
+            switch (i) {
+            case 0: srci = rms_norm->src[0]; break;
+            case 1: srci = tensor->src[1 - rms_norm_idx]; break;
+            default: continue;
+            }
+        }
         if (srci == nullptr) {
             continue;
         }
@@ -10876,7 +10927,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_SUB) {
         tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL) {
-        tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        if (fused_rms_norm_mul) {
+            tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
+            tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
+        } else {
+            tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        }
     } else if (tensor->op == GGML_OP_DIV) {
         tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONCAT) {
@@ -11067,10 +11123,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         GGML_ABORT("fatal error");
     }
 
-    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-    ggml_build_forward_expand(cgraph, tensor_clone);
+    ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
+    ggml_build_forward_expand(cgraph_cpu, tensor_clone);
 
-    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
+    ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
 
     if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
         ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -11093,10 +11149,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
 }
 
-static void ggml_vk_check_results_1(ggml_tensor * tensor) {
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
+    bool fused_rms_norm_mul = false;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
     }

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 3 deletions
@@ -101,8 +101,8 @@ void main() {
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
     uint32_t m_offset = 0;
-    if (p.nem2 != 1) {
-        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    if (p.nem2 != 1 || p.nem3 != 1) {
+        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
     }
 
     [[dont_unroll]]
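The new offset adds mask broadcast over both the head (nem2) and batch (nem3) dimensions. The same arithmetic as a host-side sketch, with the mask laid out as [KV, nem1, nem2, nem3]:

    #include <cstdint>

    // Mirror of the shader expression above. The modulos implement broadcast:
    // nem2 == 1 and/or nem3 == 1 map every head/batch index back onto the
    // single stored slice, so a fully broadcast mask always yields offset 0.
    static uint32_t mask_offset(uint32_t iq2, uint32_t iq3, uint32_t nem1,
                                uint32_t nem2, uint32_t nem3, uint32_t KV) {
        return ((iq3 % nem3) * nem2 + (iq2 % nem2)) * nem1 * KV;
    }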
@@ -149,7 +149,7 @@ void main() {
         }
     }
 
-    if (p.mask != 0) {
+    if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
 
         [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
             uint32_t c = (idx + tid) % Bc;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 9 additions & 4 deletions
@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
     uint32_t nev3;
     uint32_t nem1;
     uint32_t nem2;
+    uint32_t nem3;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
     float max_bias;
     float logit_softcap;
 
-    uint32_t mask;
-    uint32_t n_head_log2;
+    uint32_t mask_n_head_log2;
     float m0;
     float m1;
 
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
     uint32_t k_num;
 } p;
 
+#define MASK_ENABLE_BIT (1<<16)
+#define N_LOG2_MASK 0xFFFF
+
 layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
 
 #if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
 {
     const uint32_t h = iq2 + (r % p.gqa_ratio);
 
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+    uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
+
+    const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
+    const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
 
     return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }
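The slope helper now unpacks n_head_log2 from the shared word before computing the per-head ALiBi slope. For reference, the same computation on the host (a sketch; std::pow stands in for the GLSL pow):

    #include <cmath>
    #include <cstdint>

    static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? int(h) + 1 : 2 * int(h - n_head_log2) + 1;
        return std::pow(base, float(exph));
    }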

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 3 additions & 3 deletions
@@ -126,8 +126,8 @@ void main() {
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
     uint32_t m_offset = 0;
-    if (p.nem2 != 1) {
-        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    if (p.nem2 != 1 || p.nem3 != 1) {
+        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
     }
 
     [[dont_unroll]]
@@ -182,7 +182,7 @@ void main() {
         barrier();
     }
 
-    if (p.mask != 0) {
+    if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
         [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
             uint32_t c = (idx + tid) % Bc;
             uint32_t r = (idx + tid) / Bc;
