250 changes: 244 additions & 6 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -582,6 +582,10 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_ssm_scan_f32_d16;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
@@ -1087,6 +1091,19 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t C;
uint32_t H;
};
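
// Byte strides and shape parameters passed to the SSM scan shader.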
struct vk_op_ssm_scan_push_constants {
uint32_t src0_nb2, src0_nb3, src1_nb2, src1_nb3;
uint32_t src2_nb1, src2_nb2, src3_nb1;
uint32_t src4_nb2, src4_nb3, src5_nb2, src5_nb3;
uint32_t s_off;
uint32_t n_head, d_head, n_group, n_tok;
};
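
// nc = conv kernel width, ncs = padded source width, nr = rows (channels),
// n_t = tokens, n_s = sequences, following the CPU ssm_conv naming.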
struct vk_op_ssm_conv_push_constants {
uint32_t src0_nb1, src0_nb2;
uint32_t src1_nb1;
uint32_t dst_nb0, dst_nb1, dst_nb2;
uint32_t nc, ncs, nr, n_t, n_s;
};

struct vk_op_conv2d_push_constants {
uint32_t Cout;
@@ -3588,6 +3605,12 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

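// One ssm_scan variant per supported d_state (16 for Mamba-1, 128/256 for
// Mamba-2); the first specialization constant carries the d_state value the
// variant is selected for in ggml_vk_op_get_pipeline().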
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d16, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {16, device->subgroup_size, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);

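// ssm_conv runs a fixed workgroup of 32 threads, one output row per thread.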
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {1, 1, 1}, {32}, 1);

ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -8087,6 +8110,23 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_SSM_SCAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t d_state = src0->ne[0];
if (d_state == 16) {
return ctx->device->pipeline_ssm_scan_f32_d16;
} else if (d_state == 128) {
return ctx->device->pipeline_ssm_scan_f32_d128;
} else if (d_state == 256) {
return ctx->device->pipeline_ssm_scan_f32_d256;
}
}
return nullptr;
case GGML_OP_SSM_CONV:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_ssm_conv_f32;
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
@@ -8581,6 +8621,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
break;
case GGML_OP_SSM_CONV:
{
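// This pipeline is created with {1, 1, 1} wg_denoms, so these are workgroup
// counts: ceil(nr / 32) groups of 32 threads across rows, one token per y
// and one sequence per z.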
const uint32_t nr = src0->ne[1];
const uint32_t n_t = dst->ne[1];
const uint32_t n_s = dst->ne[2];
elements = { CEIL_DIV(nr, 32), n_t, n_s };
}
break;
default:
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
break;
@@ -9027,6 +9075,122 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}

static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const ggml_tensor * src3 = dst->src[3];
const ggml_tensor * src4 = dst->src[4];
const ggml_tensor * src5 = dst->src[5];

GGML_ASSERT(dst->buffer != nullptr);

const uint32_t head_dim = src0->ne[1];
const uint32_t n_head = src1->ne[1];
const uint32_t n_group = src4->ne[1];
const uint32_t n_tok = src1->ne[2];
const uint32_t n_seq = src1->ne[3];

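// Distinguish Mamba-2 from Mamba-1 by the layout of A (src3): in Mamba-2, A
// holds a single scalar per head, so its row stride is one float.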
bool is_mamba2 = (src3->nb[1] == sizeof(float));

vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
GGML_ASSERT(pipeline != nullptr);

if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}

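// dst packs the y output (same element count as src1) first, followed by the
// updated states; s_off is the byte offset of the state region.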
const int64_t s_off = ggml_nelements(src1) * sizeof(float);

const vk_op_ssm_scan_push_constants pc = {
(uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
(uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
(uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
(uint32_t)src3->nb[1],
(uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
(uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
(uint32_t)s_off,
n_head, head_dim, n_group, n_tok
};

ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}

vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };

if (ctx->device->uma) {
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
dst_uma = d_D != nullptr;
}

if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
}

size_t dst_size = ggml_nbytes(dst);
size_t src_sizes[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
}

std::array<uint32_t, 3> elements;

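// Mamba-2: tile the n_head * head_dim rows by splitH = 16 per workgroup, one
// sequence per y workgroup. Mamba-1: one sequence per x workgroup, with the
// heads split along y in chunks of 128.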
if (is_mamba2) {
const int splitH = 16;
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
const uint32_t num_workgroups_y = n_seq;
elements = { num_workgroups_x, num_workgroups_y, 1 };
} else {
const uint32_t num_workgroups_x = n_seq;
const uint32_t num_workgroups_y = CEIL_DIV(n_head, 128);
elements = { num_workgroups_x, num_workgroups_y, 1 };
}

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, pc, elements);
}

static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
(uint32_t)src1->nb[1],
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
(uint32_t)src1->ne[0],
(uint32_t)src0->ne[0],
(uint32_t)src0->ne[1],
(uint32_t)dst->ne[1],
(uint32_t)dst->ne[2],
}, dryrun);
}

static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * x = dst->src[0];
const ggml_tensor * g = dst->src[1];
@@ -10859,6 +11023,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
@@ -11276,6 +11442,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

break;

case GGML_OP_SSM_SCAN:
ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun);

break;

case GGML_OP_SSM_CONV:
ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun);

break;

case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);

@@ -11387,6 +11563,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
@@ -12867,6 +13045,61 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_SSM_SCAN:
{
if (ggml_is_quantized(op->src[0]->type) || ggml_is_quantized(op->src[1]->type) || ggml_is_quantized(op->src[2]->type)) {
return false;
}
if (op->src[3] && ggml_is_quantized(op->src[3]->type)) {
return false;
}
if (op->src[4] && ggml_is_quantized(op->src[4]->type)) {
return false;
}
if (op->src[5] && ggml_is_quantized(op->src[5]->type)) {
return false;
}
if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
return false;
}
if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
return false;
}

const uint32_t d_state = op->src[0]->ne[0];
const uint32_t head_dim = op->src[0]->ne[1];
const uint32_t n_head = op->src[1]->ne[1];
const uint32_t n_group = op->src[4] ? op->src[4]->ne[1] : 1;

bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
if (is_mamba2) {
if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
return false;
}
} else {
if (n_head % 128 != 0 || head_dim != 1 || n_group != 1 || d_state != 16) {
return false;
}
}

ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);

const uint32_t splitH = 16;

size_t stateC_size = splitH * d_state * sizeof(float);
size_t subgroup_sdata_size = d_state * sizeof(float);
size_t total_shared_memory = stateC_size + subgroup_sdata_size;

// Check that the stateC tile (splitH * d_state floats) plus the subgroup
// scratch buffer fits in shared memory (splitH is fixed at 16 here).
if (total_shared_memory > device->properties.limits.maxComputeSharedMemorySize) {
return false;
}

return true;
}
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
@@ -13211,14 +13444,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *

struct ggml_context * ggml_ctx = ggml_init(iparams);

std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, GGML_MAX_SRC> src_size = {};
std::array<void *, GGML_MAX_SRC> src_buffer = {};
const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};

struct ggml_tensor * tensor_clone = nullptr;

for (int i = 0; i < 6; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
ggml_tensor * srci = tensor->src[i];
if (fused_rms_norm_mul) {
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
@@ -13525,6 +13758,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[2]);
} else if (tensor->op == GGML_OP_ADD_ID) {
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
} else if (tensor->op == GGML_OP_SSM_SCAN) {
tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_SSM_CONV) {
tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -13546,7 +13784,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
memcpy(comp_result, tensor_clone->data, comp_size);
memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);

for (int i = 0; i < 6; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (src_buffer[i] != nullptr) {
free(src_buffer[i]);
}
44 changes: 44 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp
@@ -0,0 +1,44 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

#include "types.glsl"

layout(constant_id = 0) const int BLOCK_SIZE = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
layout(binding = 2) buffer Dst { float dst[]; };

layout(push_constant) uniform PushConstants {
uint src0_nb1; uint src0_nb2;
uint src1_nb1;
uint dst_nb0; uint dst_nb1; uint dst_nb2;
uint nc; uint ncs; uint nr; uint n_t; uint n_s;
};

void main() {
const uint global_thread_id = gl_WorkGroupID.x * gl_WorkGroupSize.x + gl_LocalInvocationID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i3 = gl_WorkGroupID.z;

if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
return;
}

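// Push-constant strides are in bytes; divide by 4 to index the float buffers.
// Each thread computes one output element: the dot product of the nc-wide
// window of src0 starting at column i2 with this row's kernel in src1.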
const uint i1 = global_thread_id;
const uint src0_base = i3 * (src0_nb2 / 4) + i2 + i1 * (src0_nb1 / 4);
const uint src1_base = i1 * (src1_nb1 / 4);
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;

float sum = 0.0;
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
const uint src0_idx = src0_base + i0;
const uint src1_idx = src1_base + i0;
sum += src0[src0_idx] * src1[src1_idx];
}

dst[dst_idx] = sum;
}