@@ -211,6 +211,7 @@ struct vk_device_struct {
211211 vk_pipeline pipeline_sum_rows_f32;
212212 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
213213 vk_pipeline pipeline_timestep_embedding_f32;
214+ vk_pipeline pipeline_pool2d_f32;
214215
215216 std::unordered_map<std::string, vk_pipeline_ref> pipelines;
216217 std::unordered_map<std::string, uint64_t > pipeline_descriptor_set_requirements;
@@ -401,6 +402,17 @@ struct vk_op_timestep_embedding_push_constants {
401402 uint32_t max_period;
402403};
403404
405+ struct vk_op_pool2d_push_constants {
406+ uint32_t IW; uint32_t IH;
407+ uint32_t OW; uint32_t OH;
408+ uint32_t OC;
409+ uint32_t pelements;
410+ uint32_t op;
411+ int32_t k0; int32_t k1;
412+ int32_t s0; int32_t s1;
413+ int32_t p0; int32_t p1;
414+ };
415+
404416// Allow pre-recording command buffers
405417struct vk_staging_memcpy {
406418 vk_staging_memcpy (void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1743,6 +1755,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
17431755 ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_len, im2col_f32_f16_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {256 , 1 , 1 }, {}, 1 );
17441756
17451757 ggml_vk_create_pipeline (device, device->pipeline_timestep_embedding_f32 , " timestep_embedding_f32" , timestep_embedding_f32_len, timestep_embedding_f32_data, " main" , 2 , sizeof (vk_op_timestep_embedding_push_constants), {256 , 1 , 1 }, {}, 1 );
1758+ ggml_vk_create_pipeline (device, device->pipeline_pool2d_f32 , " pool2d_f32" , pool2d_f32_len, pool2d_f32_data, " main" , 2 , sizeof (vk_op_pool2d_push_constants), {512 , 1 , 1 }, {}, 1 );
17461759}
17471760
17481761static vk_device ggml_vk_get_device (size_t idx) {
@@ -4170,6 +4183,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
41704183 return ctx->device ->pipeline_timestep_embedding_f32 ;
41714184 }
41724185 return nullptr ;
4186+ case GGML_OP_POOL_2D:
4187+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4188+ return ctx->device ->pipeline_pool2d_f32 ;
4189+ }
4190+ return nullptr ;
41734191 case GGML_OP_LEAKY_RELU:
41744192 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
41754193 return ctx->device ->pipeline_leaky_relu_f32 ;
@@ -4400,6 +4418,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
44004418 uint32_t half_ceil = (dim + 1 ) / 2 ;
44014419 elements = { half_ceil, (uint32_t )src0->ne [0 ], 1 };
44024420 } break ;
4421+ case GGML_OP_POOL_2D:
4422+ {
4423+ const uint32_t N = dst->ne [3 ];
4424+ const uint32_t OC = dst->ne [2 ];
4425+ const uint32_t OH = dst->ne [1 ];
4426+ const uint32_t OW = dst->ne [0 ];
4427+ elements = { N * OC * OH * OW, 1 , 1 };
4428+ } break ;
44034429 case GGML_OP_ADD:
44044430 case GGML_OP_DIV:
44054431 case GGML_OP_MUL:
@@ -4852,6 +4878,34 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
48524878 }, dryrun);
48534879}
48544880
4881+ static void ggml_vk_pool_2d (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
4882+ uint32_t op = static_cast <uint32_t >(dst->op_params [0 ]);
4883+ const int32_t k0 = dst->op_params [1 ];
4884+ const int32_t k1 = dst->op_params [2 ];
4885+ const int32_t s0 = dst->op_params [3 ];
4886+ const int32_t s1 = dst->op_params [4 ];
4887+ const int32_t p0 = dst->op_params [5 ];
4888+ const int32_t p1 = dst->op_params [6 ];
4889+
4890+ const uint32_t IH = src0->ne [1 ];
4891+ const uint32_t IW = src0->ne [0 ];
4892+
4893+ const uint32_t N = dst->ne [3 ];
4894+
4895+ const uint32_t OC = dst->ne [2 ];
4896+ const uint32_t OH = dst->ne [1 ];
4897+ const uint32_t OW = dst->ne [0 ];
4898+
4899+ const uint32_t parallel_elements = N * OC * OH * OW;
4900+
4901+ ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_POOL_2D, {
4902+ IW, IH, OW, OH, OC,
4903+ parallel_elements,
4904+ op,
4905+ k0, k1, s0, s1, p0, p1,
4906+ }, dryrun);
4907+ }
4908+
48554909static void ggml_vk_leaky_relu (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
48564910 const float * op_params = (const float *)dst->op_params ;
48574911 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_LEAKY_RELU, { (uint32_t )ggml_nelements (src0), 0 , op_params[0 ], 0 .0f }, dryrun);
@@ -5733,6 +5787,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
57335787 case GGML_OP_SUM_ROWS:
57345788 case GGML_OP_IM2COL:
57355789 case GGML_OP_TIMESTEP_EMBEDDING:
5790+ case GGML_OP_POOL_2D:
57365791 case GGML_OP_LEAKY_RELU:
57375792 break ;
57385793 default :
@@ -5868,6 +5923,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
58685923 case GGML_OP_TIMESTEP_EMBEDDING:
58695924 ggml_vk_timestep_embedding (ctx, compute_ctx, src0, node, dryrun);
58705925
5926+ break ;
5927+ case GGML_OP_POOL_2D:
5928+ ggml_vk_pool_2d (ctx, compute_ctx, src0, node, dryrun);
5929+
58715930 break ;
58725931 case GGML_OP_LEAKY_RELU:
58735932 ggml_vk_leaky_relu (ctx, compute_ctx, src0, node, dryrun);
@@ -5959,6 +6018,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
59596018 case GGML_OP_SUM_ROWS:
59606019 case GGML_OP_IM2COL:
59616020 case GGML_OP_TIMESTEP_EMBEDDING:
6021+ case GGML_OP_POOL_2D:
59626022 case GGML_OP_LEAKY_RELU:
59636023 case GGML_OP_REPEAT:
59646024 extra = (ggml_tensor_extra_gpu *) tensor->extra ;
@@ -6685,6 +6745,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
66856745 case GGML_OP_SUM_ROWS:
66866746 case GGML_OP_IM2COL:
66876747 case GGML_OP_TIMESTEP_EMBEDDING:
6748+ case GGML_OP_POOL_2D:
66886749 case GGML_OP_LEAKY_RELU:
66896750 return true ;
66906751 default :
@@ -7217,6 +7278,16 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
72177278 const int32_t dim = tensor->op_params [0 ];
72187279 const int32_t max_period = tensor->op_params [1 ];
72197280 tensor_clone = ggml_timestep_embedding (ggml_ctx, src0_clone, dim, max_period);
7281+ } else if (tensor->op == GGML_OP_POOL_2D) {
7282+ enum ggml_op_pool op = static_cast <ggml_op_pool>(dst->op_params [0 ]);
7283+ const int32_t k0 = tensor->op_params [1 ];
7284+ const int32_t k1 = tensor->op_params [2 ];
7285+ const int32_t s0 = tensor->op_params [3 ];
7286+ const int32_t s1 = tensor->op_params [4 ];
7287+ const int32_t p0 = tensor->op_params [5 ];
7288+ const int32_t p1 = tensor->op_params [6 ];
7289+
7290+ tensor_clone = ggml_pool_2d (ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
72207291 } else if (tensor->op == GGML_OP_LEAKY_RELU) {
72217292 const float * op_params = (const float *)tensor->op_params ;
72227293 tensor_clone = ggml_leaky_relu (ggml_ctx, src0_clone, op_params[0 ], false );
0 commit comments