Skip to content

Commit e6f51f9

Browse files
cmdr2ggerganov
authored andcommitted
cuda/vulkan: specify fp32-only support for some operations in supports_op (ggml/1129)
ggml-ci
1 parent 5e079c1 commit e6f51f9

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31533153
return false;
31543154
} break;
31553155
case GGML_OP_SILU_BACK:
3156-
return ggml_is_contiguous(op->src[0]);
3156+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
31573157
break;
31583158
case GGML_OP_NORM:
31593159
case GGML_OP_RMS_NORM:

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8450,7 +8450,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
84508450
case GGML_UNARY_OP_RELU:
84518451
case GGML_UNARY_OP_TANH:
84528452
case GGML_UNARY_OP_SIGMOID:
8453-
return ggml_is_contiguous(op->src[0]);
8453+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
84548454
default:
84558455
return false;
84568456
}
@@ -8651,19 +8651,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
86518651
case GGML_OP_RMS_NORM:
86528652
return ggml_is_contiguous(op->src[0]);
86538653
case GGML_OP_ADD:
8654-
case GGML_OP_ACC:
86558654
case GGML_OP_SUB:
86568655
case GGML_OP_MUL:
86578656
case GGML_OP_DIV:
8658-
case GGML_OP_CONCAT:
86598657
case GGML_OP_SILU_BACK:
86608658
case GGML_OP_RMS_NORM_BACK:
8661-
case GGML_OP_UPSCALE:
8662-
case GGML_OP_SCALE:
86638659
case GGML_OP_SQR:
86648660
case GGML_OP_SIN:
86658661
case GGML_OP_COS:
86668662
case GGML_OP_CLAMP:
8663+
return op->src[0]->type == GGML_TYPE_F32;
8664+
case GGML_OP_ACC:
8665+
case GGML_OP_CONCAT:
8666+
case GGML_OP_UPSCALE:
8667+
case GGML_OP_SCALE:
86678668
case GGML_OP_PAD:
86688669
case GGML_OP_DIAG_MASK_INF:
86698670
case GGML_OP_SOFT_MAX:

tests/test-backend-ops.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3980,10 +3980,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39803980

39813981
test_cases.emplace_back(new test_add1());
39823982
test_cases.emplace_back(new test_scale());
3983-
3984-
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
3985-
test_cases.emplace_back(new test_silu_back());
3986-
}
3983+
test_cases.emplace_back(new test_silu_back());
39873984

39883985
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
39893986
for (bool v : {false, true}) {

0 commit comments

Comments
 (0)