@@ -549,6 +549,10 @@ struct vk_device_struct {
549549 vk_pipeline pipeline_rms_norm_mul_partials_f32;
550550 vk_pipeline pipeline_rms_norm_back_f32;
551551 vk_pipeline pipeline_l2_norm_f32;
552+ vk_pipeline pipeline_floor_f32;
553+ vk_pipeline pipeline_ceil_f32;
554+ vk_pipeline pipeline_round_f32;
555+ vk_pipeline pipeline_trunc_f32;
552556
553557 // [src/dst 0=fp32,1=fp16]
554558 vk_pipeline pipeline_exp[2];
@@ -3516,6 +3520,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
35163520 ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
35173521 ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
35183522
3523+ ggml_vk_create_pipeline(device, device->pipeline_floor_f32, "floor_f32", floor_f32_len, floor_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3524+ ggml_vk_create_pipeline(device, device->pipeline_ceil_f32, "ceil_f32", ceil_f32_len, ceil_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3525+ ggml_vk_create_pipeline(device, device->pipeline_round_f32, "round_f32", round_f32_len, round_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3526+ ggml_vk_create_pipeline(device, device->pipeline_trunc_f32, "trunc_f32", trunc_f32_len, trunc_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3527+
35193528#define CREATE_UNARY(name) \
35203529 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
35213530 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -7891,6 +7900,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
78917900 return ctx->device->pipeline_cos_f32;
78927901 }
78937902 return nullptr;
7903+ case GGML_OP_FLOOR:
7904+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7905+ return ctx->device->pipeline_floor_f32;
7906+ }
7907+ return nullptr;
7908+ case GGML_OP_CEIL:
7909+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7910+ return ctx->device->pipeline_ceil_f32;
7911+ }
7912+ return nullptr;
7913+ case GGML_OP_ROUND:
7914+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7915+ return ctx->device->pipeline_round_f32;
7916+ }
7917+ return nullptr;
7918+ case GGML_OP_TRUNC:
7919+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7920+ return ctx->device->pipeline_trunc_f32;
7921+ }
7922+ return nullptr;
78947923 case GGML_OP_CLAMP:
78957924 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
78967925 return ctx->device->pipeline_clamp_f32;
@@ -8268,6 +8297,10 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
82688297 case GGML_OP_SUM:
82698298 case GGML_OP_SUM_ROWS:
82708299 case GGML_OP_MEAN:
8300+ case GGML_OP_FLOOR:
8301+ case GGML_OP_CEIL:
8302+ case GGML_OP_ROUND:
8303+ case GGML_OP_TRUNC:
82718304 return true;
82728305 default:
82738306 return false;
@@ -8608,6 +8641,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
86088641 case GGML_OP_UNARY:
86098642 case GGML_OP_GLU:
86108643 case GGML_OP_CONV_2D_DW:
8644+ case GGML_OP_FLOOR:
8645+ case GGML_OP_CEIL:
8646+ case GGML_OP_ROUND:
8647+ case GGML_OP_TRUNC:
86118648 {
86128649 uint32_t ne = ggml_nelements(dst);
86138650 if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
@@ -9394,6 +9431,22 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
93949431 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
93959432}
93969433
9434+ static void ggml_vk_floor(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { // Element-wise floor(src0) -> dst; thin wrapper dispatching GGML_OP_FLOOR through the shared f32 unary path.
9435+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_FLOOR, vk_op_unary_push_constants_init(src0, dst), dryrun); // No op params beyond the standard unary push constants (shape/stride info from src0/dst).
9436+ }
9437+
9438+ static void ggml_vk_ceil(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { // Element-wise ceil(src0) -> dst; thin wrapper dispatching GGML_OP_CEIL through the shared f32 unary path.
9439+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CEIL, vk_op_unary_push_constants_init(src0, dst), dryrun); // Same no-parameter unary dispatch pattern as ggml_vk_floor.
9440+ }
9441+
9442+ static void ggml_vk_round(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { // Element-wise round(src0) -> dst; rounding mode is whatever the round_f32 shader implements — confirm it matches ggml CPU round semantics.
9443+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROUND, vk_op_unary_push_constants_init(src0, dst), dryrun); // Same no-parameter unary dispatch pattern as ggml_vk_floor.
9444+ }
9445+
9446+ static void ggml_vk_trunc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { // Element-wise trunc(src0) -> dst (round toward zero); thin wrapper dispatching GGML_OP_TRUNC through the shared f32 unary path.
9447+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TRUNC, vk_op_unary_push_constants_init(src0, dst), dryrun); // Same no-parameter unary dispatch pattern as ggml_vk_floor.
9448+ }
9449+
93979450static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
93989451 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
93999452 p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11140,6 +11193,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1114011193 case GGML_OP_FLASH_ATTN_EXT:
1114111194 case GGML_OP_OPT_STEP_ADAMW:
1114211195 case GGML_OP_OPT_STEP_SGD:
11196+ case GGML_OP_FLOOR:
11197+ case GGML_OP_CEIL:
11198+ case GGML_OP_ROUND:
11199+ case GGML_OP_TRUNC:
1114311200 break;
1114411201 default:
1114511202 std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -11208,6 +11265,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1120811265 case GGML_OP_CONV_2D_DW:
1120911266 case GGML_OP_LEAKY_RELU:
1121011267 case GGML_OP_OPT_STEP_SGD:
11268+ case GGML_OP_FLOOR:
11269+ case GGML_OP_CEIL:
11270+ case GGML_OP_ROUND:
11271+ case GGML_OP_TRUNC:
1121111272 {
1121211273 // These operations all go through ggml_vk_op_f32, so short-circuit and
1121311274 // do the only thing needed for the dryrun.
@@ -11364,6 +11425,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1136411425 case GGML_OP_COS:
1136511426 ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun);
1136611427
11428+ break;
11429+ case GGML_OP_FLOOR:
11430+ ggml_vk_floor(ctx, compute_ctx, src0, node, dryrun);
11431+
11432+ break;
11433+ case GGML_OP_CEIL:
11434+ ggml_vk_ceil(ctx, compute_ctx, src0, node, dryrun);
11435+
11436+ break;
11437+ case GGML_OP_ROUND:
11438+ ggml_vk_round(ctx, compute_ctx, src0, node, dryrun);
11439+
11440+ break;
11441+ case GGML_OP_TRUNC:
11442+ ggml_vk_trunc(ctx, compute_ctx, src0, node, dryrun);
11443+
1136711444 break;
1136811445 case GGML_OP_CLAMP:
1136911446 ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun);
@@ -13275,6 +13352,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1327513352 case GGML_OP_LEAKY_RELU:
1327613353 case GGML_OP_OPT_STEP_ADAMW:
1327713354 case GGML_OP_OPT_STEP_SGD:
13355+ case GGML_OP_FLOOR:
13356+ case GGML_OP_CEIL:
13357+ case GGML_OP_ROUND:
13358+ case GGML_OP_TRUNC:
1327813359 return op->src[0]->type == GGML_TYPE_F32;
1327913360 case GGML_OP_ARGSORT:
1328013361 return op->ne[0] <= max_argsort_cols;
@@ -13788,12 +13869,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1378813869 tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
1378913870 } else if (tensor->op == GGML_OP_COS) {
1379013871 tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
13872+ } else if (tensor->op == GGML_OP_FLOOR) {
13873+ tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
13874+ } else if (tensor->op == GGML_OP_CEIL) {
13875+ tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
13876+ } else if (tensor->op == GGML_OP_ROUND) {
13877+ tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
13878+ } else if (tensor->op == GGML_OP_TRUNC) {
13879+ tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
1379113880 } else if (tensor->op == GGML_OP_CLAMP) {
1379213881 const float * params = (const float *)tensor->op_params;
1379313882 tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
1379413883 } else if (tensor->op == GGML_OP_PAD) {
1379513884 tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
13796- tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
13885+ tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
1379713886 } else if (tensor->op == GGML_OP_REPEAT) {
1379813887 tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
1379913888 } else if (tensor->op == GGML_OP_REPEAT_BACK) {
0 commit comments