Skip to content
Closed
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3929,13 +3929,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_SOFT_MAX:
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);

if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32;
}
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32_f16;
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (src1 == nullptr || src1->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32;
}
else if (src1->type == GGML_TYPE_F16) {
return ctx->device->pipeline_soft_max_f32_f16;
}
}
return nullptr;
case GGML_OP_ROPE:
Expand Down
Loading