Skip to content

Commit ff90529

Browse files
authored
cuda/vulkan: specify fp32-only support for some operations in supports_op (#1129)
* cuda: restrict SILU_BACK to fp32, since fp16 exceeds the desired test threshold

* vulkan: specify fp32-only support for certain ops (that are now tested for fp16 as well)

* f32 sigmoid in vulkan supports op

* Revert "f32 sigmoid in vulkan supports op"

This reverts commit c6f04b3.
1 parent 0d1ea2e commit ff90529

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31533153
return false;
31543154
} break;
31553155
case GGML_OP_SILU_BACK:
3156-
return ggml_is_contiguous(op->src[0]);
3156+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
31573157
break;
31583158
case GGML_OP_NORM:
31593159
case GGML_OP_RMS_NORM:

src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8371,7 +8371,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
83718371
case GGML_UNARY_OP_SILU:
83728372
case GGML_UNARY_OP_RELU:
83738373
case GGML_UNARY_OP_TANH:
8374-
return ggml_is_contiguous(op->src[0]);
8374+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
83758375
default:
83768376
return false;
83778377
}
@@ -8571,17 +8571,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
85718571
case GGML_OP_RMS_NORM:
85728572
return ggml_is_contiguous(op->src[0]);
85738573
case GGML_OP_ADD:
8574-
case GGML_OP_ACC:
85758574
case GGML_OP_SUB:
85768575
case GGML_OP_MUL:
85778576
case GGML_OP_DIV:
8578-
case GGML_OP_CONCAT:
8579-
case GGML_OP_UPSCALE:
8580-
case GGML_OP_SCALE:
85818577
case GGML_OP_SQR:
85828578
case GGML_OP_SIN:
85838579
case GGML_OP_COS:
85848580
case GGML_OP_CLAMP:
8581+
return op->src[0]->type == GGML_TYPE_F32;
8582+
case GGML_OP_ACC:
8583+
case GGML_OP_CONCAT:
8584+
case GGML_OP_UPSCALE:
8585+
case GGML_OP_SCALE:
85858586
case GGML_OP_PAD:
85868587
case GGML_OP_DIAG_MASK_INF:
85878588
case GGML_OP_SOFT_MAX:

tests/test-backend-ops.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3980,10 +3980,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39803980

39813981
test_cases.emplace_back(new test_add1());
39823982
test_cases.emplace_back(new test_scale());
3983-
3984-
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
3985-
test_cases.emplace_back(new test_silu_back());
3986-
}
3983+
test_cases.emplace_back(new test_silu_back());
39873984

39883985
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
39893986
for (bool v : {false, true}) {

0 commit comments

Comments
 (0)