Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false;
} break;
case GGML_OP_SILU_BACK:
return ggml_is_contiguous(op->src[0]);
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
Expand Down
11 changes: 6 additions & 5 deletions src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8371,7 +8371,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
return ggml_is_contiguous(op->src[0]);
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
}
Expand Down Expand Up @@ -8571,17 +8571,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_RMS_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_PAD:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
Expand Down
5 changes: 1 addition & 4 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3980,10 +3980,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

test_cases.emplace_back(new test_add1());
test_cases.emplace_back(new test_scale());

for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
test_cases.emplace_back(new test_silu_back());
}
test_cases.emplace_back(new test_silu_back());

for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
for (bool v : {false, true}) {
Expand Down