52 changes: 44 additions & 8 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5653,8 +5653,12 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
ggml_vk_queue_command_pools_cleanup(dst->device);
}

static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");

if (disable_split_k) {
return 1;
}

uint32_t split_k = 1;
if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
@@ -5979,7 +5983,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_sync_buffers(ctx, subctx);
}

static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
@@ -5999,6 +6003,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub

const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];
const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
const uint32_t stride_batch_d = stride_d*ne21;

const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03;
@@ -6067,7 +6073,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const int y_ne = padded_n * ne10;
const int d_ne = ne11 * ne01;

const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);

const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -6226,13 +6232,16 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
}

// No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange.
VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});

// compute
ggml_vk_matmul(
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
{ d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ne01, ne11, ne10,
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
); // NOLINT

@@ -6712,7 +6721,34 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con

static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&

// Handle a huge A matrix by splitting the M dimension. This works well for convolution use cases
// where the M dimension is very large.
// Split_k doesn't work with M splitting.
const size_t nbytes = ggml_nbytes(src0);
const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
if (needs_split) {
// Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
uint32_t m_offset = 0;
while (m_offset < dst->ne[0]) {
const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
ggml_tensor dst2 = *dst;
ggml_tensor src02 = *src0;

dst2.view_src = dst->view_src ? dst->view_src : (ggml_tensor *)dst;
src02.view_src = src0->view_src ? src0->view_src : (ggml_tensor *)src0;

dst2.view_offs += m_offset * dst->nb[0];
src02.view_offs += m_offset * src0->nb[1];
dst2.ne[0] = cur_M_size;
src02.ne[1] = cur_M_size;

ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true, dryrun);

m_offset += cur_M_size;
}
} else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
// detect 0213 permutation, and batch size of 1
src0->nb[0] <= src0->nb[2] &&
src0->nb[2] <= src0->nb[1] &&
@@ -6732,7 +6768,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
} else {
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false, dryrun);
}
}
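For reference, here is a minimal standalone sketch of the chunking arithmetic in ggml_vk_mul_mat above. The device limit and matrix shape below are illustrative assumptions (a 4 GiB maxStorageBufferRange and an F16 A matrix with K = 2592 and 1,700,000 rows), not values queried from a real device or taken from this PR:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t max_range = 4ull * 1024 * 1024 * 1024; // assumed maxStorageBufferRange (4 GiB)
    const uint64_t nb1       = 2592 * 2;                   // bytes per row of an F16 A matrix with K = 2592
    const uint64_t M         = 1700000;                    // rows in A

    // Rows per chunk: halve the budget to leave headroom for any additional offsets,
    // mirroring the M_split computation in ggml_vk_mul_mat.
    const uint64_t M_split = max_range / (2 * nb1);

    for (uint64_t m_offset = 0; m_offset < M; ) {
        const uint64_t cur_M = std::min(M_split, M - m_offset);
        std::printf("chunk: rows [%llu, %llu)\n",
                    (unsigned long long) m_offset,
                    (unsigned long long) (m_offset + cur_M));
        m_offset += cur_M;
    }
    return 0;
}
```

With these assumed numbers, M_split comes out to roughly 414,000 rows, so a 1,700,000-row A matrix would be processed in five chunks.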

3 changes: 2 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
@@ -265,7 +265,6 @@ void main() {
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

#if QUANT_K > 1
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
@@ -281,6 +280,8 @@
tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);

tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if !defined(MUL_MAT_ID)
8 changes: 8 additions & 0 deletions tests/test-backend-ops.cpp
@@ -6200,6 +6200,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));

#if 0
// > 4GB A matrix. Too slow to be enabled by default.
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 96, 2592, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 3, 2592, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 1, 2592, {1, 1}, {1, 1}));
#endif

for (ggml_type type_a : all_types) {
for (int i = 1; i < 10; ++i) {
test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
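As a rough sanity check on the #if 0 cases above (treating maxStorageBufferRange as about 4 GiB, which is an assumption about the target device rather than a value from this PR): an F16 A matrix of 1,700,000 × 2592 is roughly 8.8 GB and 900,000 × 2592 is roughly 4.7 GB, so all four disabled cases exceed the limit and exercise the new M-splitting path in ggml_vk_mul_mat.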