Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Legend:
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
Expand All @@ -42,6 +42,7 @@ Legend:
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
Expand Down Expand Up @@ -84,7 +85,7 @@ Legend:
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
Expand All @@ -111,6 +112,6 @@ Legend:
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
8 changes: 8 additions & 0 deletions docs/ops/Vulkan.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5434,12 +5434,20 @@
"Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan"
"Vulkan0","SIN","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
"Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
"Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
"Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","0","no","Vulkan"
"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
"Vulkan0","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan"
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan"
Expand Down
108 changes: 108 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,10 @@ struct vk_device_struct {
vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;
vk_pipeline pipeline_floor_f32;
vk_pipeline pipeline_ceil_f32;
vk_pipeline pipeline_round_f32;
vk_pipeline pipeline_trunc_f32;

// [src/dst 0=fp32,1=fp16]
vk_pipeline pipeline_exp[2];
Expand Down Expand Up @@ -3712,6 +3716,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_floor_f32, "floor_f32", floor_f32_len, floor_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ceil_f32, "ceil_f32", ceil_f32_len, ceil_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_round_f32, "round_f32", round_f32_len, round_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_trunc_f32, "trunc_f32", trunc_f32_len, trunc_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

#define CREATE_UNARY(name) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
Expand Down Expand Up @@ -8223,6 +8232,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_cos_f32;
}
return nullptr;
case GGML_OP_FLOOR:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_floor_f32;
}
return nullptr;
case GGML_OP_CEIL:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_ceil_f32;
}
return nullptr;
case GGML_OP_ROUND:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_round_f32;
}
return nullptr;
case GGML_OP_TRUNC:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_trunc_f32;
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
Expand Down Expand Up @@ -8320,6 +8349,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_HARDSWISH:
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_FLOOR:
if (dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_floor_f32;
}
break;
case GGML_UNARY_OP_CEIL:
if (dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_ceil_f32;
}
break;
case GGML_UNARY_OP_ROUND:
if (dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_round_f32;
}
break;
case GGML_UNARY_OP_TRUNC:
if (dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_trunc_f32;
}
break;
default:
break;
}
Expand Down Expand Up @@ -8640,6 +8689,10 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_FLOOR:
case GGML_OP_CEIL:
case GGML_OP_ROUND:
case GGML_OP_TRUNC:
return true;
default:
return false;
Expand Down Expand Up @@ -8930,6 +8983,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_UNARY:
case GGML_OP_GLU:
case GGML_OP_CONV_2D_DW:
case GGML_OP_FLOOR:
case GGML_OP_CEIL:
case GGML_OP_ROUND:
case GGML_OP_TRUNC:
{
uint32_t ne = ggml_nelements(dst);
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
Expand Down Expand Up @@ -9527,6 +9584,25 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst));
}

// Submits an element-wise floor of src0 into dst.
//
// The op passed down is GGML_OP_FLOOR (not GGML_OP_UNARY) so that it matches
// the GGML_OP_FLOOR cases in ggml_vk_op_get_pipeline and the other op
// switches. The argument layout mirrors ggml_vk_cos: src0, three unused
// source slots, dst, the op, then the unary push constants.
//
// dryrun is kept only so existing call sites that still pass it keep
// compiling; it is not forwarded.
static void ggml_vk_floor(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    (void) dryrun;
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_FLOOR, vk_op_unary_push_constants_init(src0, dst));
}

// Submits an element-wise ceil of src0 into dst.
//
// The op passed down is GGML_OP_CEIL (not GGML_OP_UNARY) so that it matches
// the GGML_OP_CEIL cases in ggml_vk_op_get_pipeline and the other op
// switches. The argument layout mirrors ggml_vk_cos: src0, three unused
// source slots, dst, the op, then the unary push constants.
//
// dryrun is kept only so existing call sites that still pass it keep
// compiling; it is not forwarded.
static void ggml_vk_ceil(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    (void) dryrun;
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CEIL, vk_op_unary_push_constants_init(src0, dst));
}

// Submits an element-wise round of src0 into dst.
//
// The op passed down is GGML_OP_ROUND (not GGML_OP_UNARY) so that it matches
// the GGML_OP_ROUND cases in ggml_vk_op_get_pipeline and the other op
// switches. The argument layout mirrors ggml_vk_cos: src0, three unused
// source slots, dst, the op, then the unary push constants.
//
// dryrun is kept only so existing call sites that still pass it keep
// compiling; it is not forwarded.
static void ggml_vk_round(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    (void) dryrun;
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROUND, vk_op_unary_push_constants_init(src0, dst));
}

// Submits an element-wise trunc of src0 into dst.
//
// The op passed down is GGML_OP_TRUNC (not GGML_OP_UNARY) so that it matches
// the GGML_OP_TRUNC cases in ggml_vk_op_get_pipeline and the other op
// switches. The argument layout mirrors ggml_vk_cos: src0, three unused
// source slots, dst, the op, then the unary push constants.
//
// dryrun is kept only so existing call sites that still pass it keep
// compiling; it is not forwarded.
static void ggml_vk_trunc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    (void) dryrun;
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRUNC, vk_op_unary_push_constants_init(src0, dst));
}

static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
Expand Down Expand Up @@ -11271,6 +11347,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
break;
default:
return false;
Expand Down Expand Up @@ -11361,6 +11441,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
case GGML_OP_FLOOR:
case GGML_OP_CEIL:
case GGML_OP_ROUND:
case GGML_OP_TRUNC:
break;
default:
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
Expand Down Expand Up @@ -11543,6 +11627,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_COS:
ggml_vk_cos(ctx, compute_ctx, src0, node);

break;
case GGML_OP_FLOOR:
ggml_vk_floor(ctx, compute_ctx, src0, node, dryrun);

break;
case GGML_OP_CEIL:
ggml_vk_ceil(ctx, compute_ctx, src0, node, dryrun);

break;
case GGML_OP_ROUND:
ggml_vk_round(ctx, compute_ctx, src0, node, dryrun);

break;
case GGML_OP_TRUNC:
ggml_vk_trunc(ctx, compute_ctx, src0, node, dryrun);

break;
case GGML_OP_CLAMP:
ggml_vk_clamp(ctx, compute_ctx, src0, node);
Expand Down Expand Up @@ -11601,6 +11701,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
ggml_vk_unary(ctx, compute_ctx, src0, node);
break;
default:
Expand Down Expand Up @@ -13655,6 +13759,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
case GGML_OP_FLOOR:
case GGML_OP_CEIL:
case GGML_OP_ROUND:
case GGML_OP_TRUNC:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT:
return op->ne[0] <= max_argsort_cols;
Expand Down
17 changes: 17 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#version 450

#include "types.glsl"
#include "generic_unary_head.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// Reads the source element addressed by this invocation's index, applies
// ceil() to it, and writes the result to the matching destination slot.
void main() {
    const uint elem = get_idx();

    // Indices at or beyond p.ne are left untouched.
    if (elem < p.ne) {
        const FLOAT_TYPE src_val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(elem)]);

        data_d[get_doffset() + dst_idx(elem)] = D_TYPE(ceil(src_val));
    }
}
17 changes: 17 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/floor.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#version 450

#include "types.glsl"
#include "generic_unary_head.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// Reads the source element addressed by this invocation's index, applies
// floor() to it, and writes the result to the matching destination slot.
void main() {
    const uint elem = get_idx();

    // Indices at or beyond p.ne are left untouched.
    if (elem < p.ne) {
        const FLOAT_TYPE src_val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(elem)]);

        data_d[get_doffset() + dst_idx(elem)] = D_TYPE(floor(src_val));
    }
}
17 changes: 17 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/round.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#version 450

#include "types.glsl"
#include "generic_unary_head.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// Rounds the source element addressed by this invocation's index to the
// nearest integer and writes it to the matching destination slot.
//
// GLSL's built-in round() lets the implementation choose the direction for
// exact .5 ties, so its output may differ between drivers. Ties are resolved
// away from zero here (the rule C's roundf uses) so the result is the same
// everywhere.
// NOTE(review): presumably this should agree with the CPU reference
// implementation of the op - confirm that it also rounds ties away from zero.
void main() {
    const uint idx = get_idx();

    // Indices at or beyond p.ne are left untouched.
    if (idx >= p.ne) {
        return;
    }

    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);

    // Compare the fractional part against 0.5 instead of using
    // floor(val + 0.5): the addition can itself round up (e.g. for the
    // largest value below 0.5), whereas val - trunc(val) is exact.
    // NaN and infinities fall through to `whole` unchanged because the
    // comparison is false for them.
    const FLOAT_TYPE whole = trunc(val);
    const FLOAT_TYPE rounded = (abs(val - whole) >= FLOAT_TYPE(0.5)) ? whole + sign(val) : whole;

    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(rounded);
}
17 changes: 17 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#version 450

#include "types.glsl"
#include "generic_unary_head.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// Reads the source element addressed by this invocation's index, applies
// trunc() to it, and writes the result to the matching destination slot.
void main() {
    const uint elem = get_idx();

    // Indices at or beyond p.ne are left untouched.
    if (elem < p.ne) {
        const FLOAT_TYPE src_val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(elem)]);

        data_d[get_doffset() + dst_idx(elem)] = D_TYPE(trunc(src_val));
    }
}
8 changes: 8 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,14 @@ void process_shaders() {

string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
Expand Down
Loading
Loading