@@ -8687,41 +8687,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
86878687 GGML_UNUSED(src2);
86888688}
86898689
8690- static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
8691- switch (op) {
8692- case GGML_OP_CPY:
8693- case GGML_OP_GET_ROWS:
8694- case GGML_OP_ADD:
8695- case GGML_OP_SUB:
8696- case GGML_OP_MUL:
8697- case GGML_OP_DIV:
8698- case GGML_OP_ADD_ID:
8699- case GGML_OP_CONCAT:
8700- case GGML_OP_UPSCALE:
8701- case GGML_OP_SQR:
8702- case GGML_OP_SQRT:
8703- case GGML_OP_SIN:
8704- case GGML_OP_COS:
8705- case GGML_OP_LOG:
8706- case GGML_OP_CLAMP:
8707- case GGML_OP_PAD:
8708- case GGML_OP_REPEAT:
8709- case GGML_OP_REPEAT_BACK:
8710- case GGML_OP_ROPE:
8711- case GGML_OP_RMS_NORM:
8712- case GGML_OP_CONV_2D_DW:
8713- case GGML_OP_IM2COL:
8714- case GGML_OP_IM2COL_3D:
8715- case GGML_OP_SET_ROWS:
8716- case GGML_OP_SUM:
8717- case GGML_OP_SUM_ROWS:
8718- case GGML_OP_MEAN:
8719- return true;
8720- default:
8721- return false;
8722- }
8723- }
8724-
87258690template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
87268691 const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
87278692 const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8806,7 +8771,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88068771 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
88078772 std::cerr << "), " << ggml_op_name(op) << ")");
88088773 GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
8809- GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
88108774 GGML_ASSERT(dst->buffer != nullptr);
88118775 const uint64_t ne00 = src0->ne[0];
88128776 const uint64_t ne01 = src0->ne[1];
@@ -8837,22 +8801,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88378801
88388802 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
88398803
8840- const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
8841-
8842- vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
8843- vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
8844- vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
8845- vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
8846- vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
8804+ vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);
8805+ vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};
8806+ vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};
8807+ vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};
8808+ vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
88478809
88488810 // Compute misalignment offset for descriptors and store it in in push constants.
88498811 init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
88508812
88518813 std::array<uint32_t, 3> elements;
88528814
8853- // Single call if dimension 2 is contiguous
8854- GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
8855-
88568815 switch (op) {
88578816 case GGML_OP_NORM:
88588817 case GGML_OP_RMS_NORM_BACK:
@@ -13876,15 +13835,17 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1387613835 op->type == GGML_TYPE_F32;
1387713836 case GGML_OP_SILU_BACK:
1387813837 case GGML_OP_RMS_NORM_BACK:
13838+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1387913839 case GGML_OP_SQR:
1388013840 case GGML_OP_SQRT:
1388113841 case GGML_OP_SIN:
1388213842 case GGML_OP_COS:
1388313843 case GGML_OP_CLAMP:
13844+ return op->src[0]->type == GGML_TYPE_F32;
1388413845 case GGML_OP_LEAKY_RELU:
1388513846 case GGML_OP_OPT_STEP_ADAMW:
1388613847 case GGML_OP_OPT_STEP_SGD:
13887- return op->src[0]->type == GGML_TYPE_F32;
13848+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1388813849 case GGML_OP_LOG:
1388913850 return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
1389013851 case GGML_OP_ARGSORT:
@@ -13919,17 +13880,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1391913880 return true;
1392013881 case GGML_OP_UPSCALE:
1392113882 case GGML_OP_ACC:
13883+ return op->src[0]->type == GGML_TYPE_F32;
1392213884 case GGML_OP_CONCAT:
13885+ return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
1392313886 case GGML_OP_ADD1:
13887+ return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)
13888+ || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
13889+ || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
1392413890 case GGML_OP_ARANGE:
1392513891 case GGML_OP_FILL:
13892+ return op->type == GGML_TYPE_F32;
1392613893 case GGML_OP_SCALE:
13894+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1392713895 case GGML_OP_PAD:
1392813896 case GGML_OP_ROLL:
13897+ return op->src[0]->type == GGML_TYPE_F32;
1392913898 case GGML_OP_DIAG_MASK_INF:
13899+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1393013900 case GGML_OP_SOFT_MAX:
13901+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
13902+ && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));
1393113903 case GGML_OP_SOFT_MAX_BACK:
13932- return true;
13904+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
13905+ && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;
1393313906 case GGML_OP_SUM:
1393413907 case GGML_OP_SUM_ROWS:
1393513908 case GGML_OP_MEAN:
@@ -13944,15 +13917,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1394413917 return false;
1394513918 }
1394613919 case GGML_OP_ARGMAX:
13920+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1394713921 case GGML_OP_COUNT_EQUAL:
13922+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32
13923+ && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;
1394813924 case GGML_OP_IM2COL:
13925+ return ggml_is_contiguous(op->src[1])
13926+ && op->src[1]->type == GGML_TYPE_F32
13927+ && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
1394913928 case GGML_OP_IM2COL_3D:
13929+ return op->src[1]->type == GGML_TYPE_F32
13930+ && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
1395013931 case GGML_OP_TIMESTEP_EMBEDDING:
13932+ return op->src[0]->type == GGML_TYPE_F32;
1395113933 case GGML_OP_CONV_2D_DW:
13934+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)
13935+ && op->src[1]->type == GGML_TYPE_F32;
1395213936 case GGML_OP_POOL_2D:
13937+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1395313938 case GGML_OP_RWKV_WKV6:
1395413939 case GGML_OP_RWKV_WKV7:
13955- return true;
13940+ return true; // all inputs are contiguous, see ggml.c
1395613941 case GGML_OP_SSM_SCAN:
1395713942 {
1395813943 for (int i = 0; i < 6; i++) {
@@ -13993,7 +13978,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1399313978 return true;
1399413979 }
1399513980 case GGML_OP_SSM_CONV:
13996- return true ;
13981+            return op->src[0]->type == GGML_TYPE_F32;
1399713982 case GGML_OP_CONV_TRANSPOSE_1D:
1399813983 return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
1399913984 case GGML_OP_CONV_2D:
0 commit comments