Skip to content

Commit b43abde

Browse files
committed
Fix: Use GGML_UNARY_OP_* instead of GGML_OP_* for FLOOR/CEIL/ROUND/TRUNC in Vulkan backend
1 parent 20664ef commit b43abde

File tree

2 files changed

+51
-117
lines changed

2 files changed

+51
-117
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 51 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -7900,26 +7900,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
79007900
return ctx->device->pipeline_cos_f32;
79017901
}
79027902
return nullptr;
7903-
case GGML_OP_FLOOR:
7904-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7905-
return ctx->device->pipeline_floor_f32;
7906-
}
7907-
return nullptr;
7908-
case GGML_OP_CEIL:
7909-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7910-
return ctx->device->pipeline_ceil_f32;
7911-
}
7912-
return nullptr;
7913-
case GGML_OP_ROUND:
7914-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7915-
return ctx->device->pipeline_round_f32;
7916-
}
7917-
return nullptr;
7918-
case GGML_OP_TRUNC:
7919-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7920-
return ctx->device->pipeline_trunc_f32;
7921-
}
7922-
return nullptr;
79237903
case GGML_OP_CLAMP:
79247904
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
79257905
return ctx->device->pipeline_clamp_f32;
@@ -8017,6 +7997,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
80177997
return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
80187998
case GGML_UNARY_OP_HARDSWISH:
80197999
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
8000+
case GGML_UNARY_OP_FLOOR:
8001+
if (dst->type == GGML_TYPE_F32) {
8002+
return ctx->device->pipeline_floor_f32;
8003+
}
8004+
break;
8005+
case GGML_UNARY_OP_CEIL:
8006+
if (dst->type == GGML_TYPE_F32) {
8007+
return ctx->device->pipeline_ceil_f32;
8008+
}
8009+
break;
8010+
case GGML_UNARY_OP_ROUND:
8011+
if (dst->type == GGML_TYPE_F32) {
8012+
return ctx->device->pipeline_round_f32;
8013+
}
8014+
break;
8015+
case GGML_UNARY_OP_TRUNC:
8016+
if (dst->type == GGML_TYPE_F32) {
8017+
return ctx->device->pipeline_trunc_f32;
8018+
}
8019+
break;
80208020
default:
80218021
break;
80228022
}
@@ -8297,10 +8297,6 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
82978297
case GGML_OP_SUM:
82988298
case GGML_OP_SUM_ROWS:
82998299
case GGML_OP_MEAN:
8300-
case GGML_OP_FLOOR:
8301-
case GGML_OP_CEIL:
8302-
case GGML_OP_ROUND:
8303-
case GGML_OP_TRUNC:
83048300
return true;
83058301
default:
83068302
return false;
@@ -8641,10 +8637,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
86418637
case GGML_OP_UNARY:
86428638
case GGML_OP_GLU:
86438639
case GGML_OP_CONV_2D_DW:
8644-
case GGML_OP_FLOOR:
8645-
case GGML_OP_CEIL:
8646-
case GGML_OP_ROUND:
8647-
case GGML_OP_TRUNC:
86488640
{
86498641
uint32_t ne = ggml_nelements(dst);
86508642
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
@@ -9432,19 +9424,19 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
94329424
}
94339425

94349426
static void ggml_vk_floor(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
9435-
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_FLOOR, vk_op_unary_push_constants_init(src0, dst), dryrun);
9427+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, vk_op_unary_push_constants_init(src0, dst), dryrun);
94369428
}
94379429

94389430
static void ggml_vk_ceil(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
9439-
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CEIL, vk_op_unary_push_constants_init(src0, dst), dryrun);
9431+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, vk_op_unary_push_constants_init(src0, dst), dryrun);
94409432
}
94419433

94429434
static void ggml_vk_round(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
9443-
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROUND, vk_op_unary_push_constants_init(src0, dst), dryrun);
9435+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, vk_op_unary_push_constants_init(src0, dst), dryrun);
94449436
}
94459437

94469438
static void ggml_vk_trunc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
9447-
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TRUNC, vk_op_unary_push_constants_init(src0, dst), dryrun);
9439+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, vk_op_unary_push_constants_init(src0, dst), dryrun);
94489440
}
94499441

94509442
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -11104,6 +11096,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1110411096
case GGML_UNARY_OP_SIGMOID:
1110511097
case GGML_UNARY_OP_HARDSIGMOID:
1110611098
case GGML_UNARY_OP_HARDSWISH:
11099+
case GGML_UNARY_OP_FLOOR:
11100+
case GGML_UNARY_OP_CEIL:
11101+
case GGML_UNARY_OP_ROUND:
11102+
case GGML_UNARY_OP_TRUNC:
1110711103
break;
1110811104
default:
1110911105
return false;
@@ -11193,10 +11189,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1119311189
case GGML_OP_FLASH_ATTN_EXT:
1119411190
case GGML_OP_OPT_STEP_ADAMW:
1119511191
case GGML_OP_OPT_STEP_SGD:
11196-
case GGML_OP_FLOOR:
11197-
case GGML_OP_CEIL:
11198-
case GGML_OP_ROUND:
11199-
case GGML_OP_TRUNC:
1120011192
break;
1120111193
default:
1120211194
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -11265,10 +11257,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1126511257
case GGML_OP_CONV_2D_DW:
1126611258
case GGML_OP_LEAKY_RELU:
1126711259
case GGML_OP_OPT_STEP_SGD:
11268-
case GGML_OP_FLOOR:
11269-
case GGML_OP_CEIL:
11270-
case GGML_OP_ROUND:
11271-
case GGML_OP_TRUNC:
1127211260
{
1127311261
// These operations all go through ggml_vk_op_f32, so short-circuit and
1127411262
// do the only thing needed for the dryrun.
@@ -11425,22 +11413,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1142511413
case GGML_OP_COS:
1142611414
ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun);
1142711415

11428-
break;
11429-
case GGML_OP_FLOOR:
11430-
ggml_vk_floor(ctx, compute_ctx, src0, node, dryrun);
11431-
11432-
break;
11433-
case GGML_OP_CEIL:
11434-
ggml_vk_ceil(ctx, compute_ctx, src0, node, dryrun);
11435-
11436-
break;
11437-
case GGML_OP_ROUND:
11438-
ggml_vk_round(ctx, compute_ctx, src0, node, dryrun);
11439-
11440-
break;
11441-
case GGML_OP_TRUNC:
11442-
ggml_vk_trunc(ctx, compute_ctx, src0, node, dryrun);
11443-
1144411416
break;
1144511417
case GGML_OP_CLAMP:
1144611418
ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun);
@@ -11506,6 +11478,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1150611478
case GGML_UNARY_OP_SIGMOID:
1150711479
case GGML_UNARY_OP_HARDSIGMOID:
1150811480
case GGML_UNARY_OP_HARDSWISH:
11481+
case GGML_UNARY_OP_FLOOR:
11482+
case GGML_UNARY_OP_CEIL:
11483+
case GGML_UNARY_OP_ROUND:
11484+
case GGML_UNARY_OP_TRUNC:
1150911485
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
1151011486
break;
1151111487
default:
@@ -13352,10 +13328,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1335213328
case GGML_OP_LEAKY_RELU:
1335313329
case GGML_OP_OPT_STEP_ADAMW:
1335413330
case GGML_OP_OPT_STEP_SGD:
13355-
case GGML_OP_FLOOR:
13356-
case GGML_OP_CEIL:
13357-
case GGML_OP_ROUND:
13358-
case GGML_OP_TRUNC:
1335913331
return op->src[0]->type == GGML_TYPE_F32;
1336013332
case GGML_OP_ARGSORT:
1336113333
return op->ne[0] <= max_argsort_cols;
@@ -13869,20 +13841,30 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1386913841
tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
1387013842
} else if (tensor->op == GGML_OP_COS) {
1387113843
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
13872-
} else if (tensor->op == GGML_OP_FLOOR) {
13873-
tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
13874-
} else if (tensor->op == GGML_OP_CEIL) {
13875-
tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
13876-
} else if (tensor->op == GGML_OP_ROUND) {
13877-
tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
13878-
} else if (tensor->op == GGML_OP_TRUNC) {
13879-
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
13844+
} else if (tensor->op == GGML_OP_UNARY) {
13845+
switch (ggml_get_unary_op(tensor)) {
13846+
case GGML_UNARY_OP_FLOOR:
13847+
tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
13848+
break;
13849+
case GGML_UNARY_OP_CEIL:
13850+
tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
13851+
break;
13852+
case GGML_UNARY_OP_ROUND:
13853+
tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
13854+
break;
13855+
case GGML_UNARY_OP_TRUNC:
13856+
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
13857+
break;
13858+
default:
13859+
std::cerr << "Unsupported unary op: " << ggml_unary_op_name(ggml_get_unary_op(tensor)) << std::endl;
13860+
GGML_ABORT("fatal error");
13861+
}
1388013862
} else if (tensor->op == GGML_OP_CLAMP) {
1388113863
const float * params = (const float *)tensor->op_params;
1388213864
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
1388313865
} else if (tensor->op == GGML_OP_PAD) {
1388413866
tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
13885-
tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
13867+
tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
1388613868
} else if (tensor->op == GGML_OP_REPEAT) {
1388713869
tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
1388813870
} else if (tensor->op == GGML_OP_REPEAT_BACK) {

tests/test-backend-ops.cpp

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3753,18 +3753,6 @@ struct test_floor : public test_case {
37533753
init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
37543754
}
37553755
}
3756-
3757-
double max_maa_err() override {
3758-
return 1e-3;
3759-
}
3760-
3761-
float grad_eps() override {
3762-
return 0.2f;
3763-
}
3764-
3765-
bool grad_precise() override {
3766-
return true;
3767-
}
37683756
};
37693757
// GGML_OP_CEIL
37703758
struct test_ceil : public test_case {
@@ -3795,18 +3783,6 @@ struct test_ceil : public test_case {
37953783
init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
37963784
}
37973785
}
3798-
3799-
double max_maa_err() override {
3800-
return 1e-3;
3801-
}
3802-
3803-
float grad_eps() override {
3804-
return 0.2f;
3805-
}
3806-
3807-
bool grad_precise() override {
3808-
return true;
3809-
}
38103786
};
38113787

38123788
// GGML_OP_ROUND
@@ -3838,18 +3814,6 @@ struct test_round : public test_case {
38383814
init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
38393815
}
38403816
}
3841-
3842-
double max_maa_err() override {
3843-
return 1e-3;
3844-
}
3845-
3846-
float grad_eps() override {
3847-
return 0.2f;
3848-
}
3849-
3850-
bool grad_precise() override {
3851-
return true;
3852-
}
38533817
};
38543818

38553819
// GGML_OP_TRUNC
@@ -3881,18 +3845,6 @@ struct test_trunc : public test_case {
38813845
init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
38823846
}
38833847
}
3884-
3885-
double max_maa_err() override {
3886-
return 1e-3;
3887-
}
3888-
3889-
float grad_eps() override {
3890-
return 0.2f;
3891-
}
3892-
3893-
bool grad_precise() override {
3894-
return true;
3895-
}
38963848
};
38973849

38983850
// GGML_OP_CLAMP

0 commit comments

Comments
 (0)