Skip to content

Commit b78db3b

Browse files
authored
vulkan : move contiguous checks to device_supports_op (#17490)
* vulkan : remove op_supports_incontiguous and add missing constraints in device_supports_op * im2col: remove contraints on src0 (kernel input)
1 parent 142df17 commit b78db3b

File tree

1 file changed

+35
-50
lines changed

1 file changed

+35
-50
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 35 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -8687,41 +8687,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
86878687
GGML_UNUSED(src2);
86888688
}
86898689

8690-
static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
8691-
switch (op) {
8692-
case GGML_OP_CPY:
8693-
case GGML_OP_GET_ROWS:
8694-
case GGML_OP_ADD:
8695-
case GGML_OP_SUB:
8696-
case GGML_OP_MUL:
8697-
case GGML_OP_DIV:
8698-
case GGML_OP_ADD_ID:
8699-
case GGML_OP_CONCAT:
8700-
case GGML_OP_UPSCALE:
8701-
case GGML_OP_SQR:
8702-
case GGML_OP_SQRT:
8703-
case GGML_OP_SIN:
8704-
case GGML_OP_COS:
8705-
case GGML_OP_LOG:
8706-
case GGML_OP_CLAMP:
8707-
case GGML_OP_PAD:
8708-
case GGML_OP_REPEAT:
8709-
case GGML_OP_REPEAT_BACK:
8710-
case GGML_OP_ROPE:
8711-
case GGML_OP_RMS_NORM:
8712-
case GGML_OP_CONV_2D_DW:
8713-
case GGML_OP_IM2COL:
8714-
case GGML_OP_IM2COL_3D:
8715-
case GGML_OP_SET_ROWS:
8716-
case GGML_OP_SUM:
8717-
case GGML_OP_SUM_ROWS:
8718-
case GGML_OP_MEAN:
8719-
return true;
8720-
default:
8721-
return false;
8722-
}
8723-
}
8724-
87258690
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
87268691
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
87278692
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8806,7 +8771,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88068771
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
88078772
std::cerr << "), " << ggml_op_name(op) << ")");
88088773
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
8809-
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
88108774
GGML_ASSERT(dst->buffer != nullptr);
88118775
const uint64_t ne00 = src0->ne[0];
88128776
const uint64_t ne01 = src0->ne[1];
@@ -8837,22 +8801,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88378801

88388802
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
88398803

8840-
const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
8841-
8842-
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
8843-
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
8844-
vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
8845-
vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
8846-
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
8804+
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);
8805+
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};
8806+
vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};
8807+
vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};
8808+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
88478809

88488810
// Compute misalignment offset for descriptors and store it in in push constants.
88498811
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
88508812

88518813
std::array<uint32_t, 3> elements;
88528814

8853-
// Single call if dimension 2 is contiguous
8854-
GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
8855-
88568815
switch (op) {
88578816
case GGML_OP_NORM:
88588817
case GGML_OP_RMS_NORM_BACK:
@@ -13876,15 +13835,17 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1387613835
op->type == GGML_TYPE_F32;
1387713836
case GGML_OP_SILU_BACK:
1387813837
case GGML_OP_RMS_NORM_BACK:
13838+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1387913839
case GGML_OP_SQR:
1388013840
case GGML_OP_SQRT:
1388113841
case GGML_OP_SIN:
1388213842
case GGML_OP_COS:
1388313843
case GGML_OP_CLAMP:
13844+
return op->src[0]->type == GGML_TYPE_F32;
1388413845
case GGML_OP_LEAKY_RELU:
1388513846
case GGML_OP_OPT_STEP_ADAMW:
1388613847
case GGML_OP_OPT_STEP_SGD:
13887-
return op->src[0]->type == GGML_TYPE_F32;
13848+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1388813849
case GGML_OP_LOG:
1388913850
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
1389013851
case GGML_OP_ARGSORT:
@@ -13919,17 +13880,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1391913880
return true;
1392013881
case GGML_OP_UPSCALE:
1392113882
case GGML_OP_ACC:
13883+
return op->src[0]->type == GGML_TYPE_F32;
1392213884
case GGML_OP_CONCAT:
13885+
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
1392313886
case GGML_OP_ADD1:
13887+
return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)
13888+
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
13889+
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
1392413890
case GGML_OP_ARANGE:
1392513891
case GGML_OP_FILL:
13892+
return op->type == GGML_TYPE_F32;
1392613893
case GGML_OP_SCALE:
13894+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1392713895
case GGML_OP_PAD:
1392813896
case GGML_OP_ROLL:
13897+
return op->src[0]->type == GGML_TYPE_F32;
1392913898
case GGML_OP_DIAG_MASK_INF:
13899+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1393013900
case GGML_OP_SOFT_MAX:
13901+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
13902+
&& (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));
1393113903
case GGML_OP_SOFT_MAX_BACK:
13932-
return true;
13904+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
13905+
&& ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;
1393313906
case GGML_OP_SUM:
1393413907
case GGML_OP_SUM_ROWS:
1393513908
case GGML_OP_MEAN:
@@ -13944,15 +13917,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1394413917
return false;
1394513918
}
1394613919
case GGML_OP_ARGMAX:
13920+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1394713921
case GGML_OP_COUNT_EQUAL:
13922+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32
13923+
&& ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;
1394813924
case GGML_OP_IM2COL:
13925+
return ggml_is_contiguous(op->src[1])
13926+
&& op->src[1]->type == GGML_TYPE_F32
13927+
&& (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
1394913928
case GGML_OP_IM2COL_3D:
13929+
return op->src[1]->type == GGML_TYPE_F32
13930+
&& (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
1395013931
case GGML_OP_TIMESTEP_EMBEDDING:
13932+
return op->src[0]->type == GGML_TYPE_F32;
1395113933
case GGML_OP_CONV_2D_DW:
13934+
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)
13935+
&& op->src[1]->type == GGML_TYPE_F32;
1395213936
case GGML_OP_POOL_2D:
13937+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1395313938
case GGML_OP_RWKV_WKV6:
1395413939
case GGML_OP_RWKV_WKV7:
13955-
return true;
13940+
return true; // all inputs are contiguous, see ggml.c
1395613941
case GGML_OP_SSM_SCAN:
1395713942
{
1395813943
for (int i = 0; i < 6; i++) {
@@ -13993,7 +13978,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1399313978
return true;
1399413979
}
1399513980
case GGML_OP_SSM_CONV:
13996-
return true;
13981+
return op->src[0]->type == GGML_TYPE_F32;
1399713982
case GGML_OP_CONV_TRANSPOSE_1D:
1399813983
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
1399913984
case GGML_OP_CONV_2D:

0 commit comments

Comments
 (0)