@@ -213,6 +213,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
+    vk_pipeline pipeline_pool2d_f32;
 
     std::unordered_map<std::string, vk_pipeline_ref> pipelines;
     std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -403,6 +404,17 @@ struct vk_op_timestep_embedding_push_constants {
     uint32_t max_period;
 };
 
+struct vk_op_pool2d_push_constants {
+    uint32_t IW; uint32_t IH;
+    uint32_t OW; uint32_t OH;
+    uint32_t OC;
+    uint32_t pelements;
+    uint32_t op;
+    int32_t k0; int32_t k1;
+    int32_t s0; int32_t s1;
+    int32_t p0; int32_t p1;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
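The struct added above is handed to Vulkan verbatim as the push-constant block, so its layout has to match the pool2d shader's push_constant block byte for byte. All thirteen members are 32-bit scalars, so there is no implicit padding; a compile-time guard along these lines would catch accidental changes (a sketch only, not part of the commit, assuming the struct definition above is in scope):

    #include <cstdint>
    // 7 uint32_t + 6 int32_t fields, all 4 bytes wide and tightly packed
    static_assert(sizeof(vk_op_pool2d_push_constants) == 13 * sizeof(uint32_t),
                  "vk_op_pool2d_push_constants must match the shader's push_constant layout");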
@@ -1803,6 +1815,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -4234,6 +4248,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_timestep_embedding_f32;
         }
         return nullptr;
+    case GGML_OP_POOL_2D:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_pool2d_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -4464,6 +4483,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             uint32_t half_ceil = (dim + 1) / 2;
             elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
         } break;
+    case GGML_OP_POOL_2D:
+        {
+            const uint32_t N = dst->ne[3];
+            const uint32_t OC = dst->ne[2];
+            const uint32_t OH = dst->ne[1];
+            const uint32_t OW = dst->ne[0];
+            elements = { N * OC * OH * OW, 1, 1 };
+        } break;
     case GGML_OP_ADD:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
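The dispatch for GGML_OP_POOL_2D is one-dimensional: one invocation per output element, grouped into the {512, 1, 1} workgroups chosen at pipeline creation above. A minimal sketch of the implied workgroup count (the helper name is hypothetical; the actual rounding happens in the backend's shared dispatch path):

    #include <cstdint>
    // one invocation per output element, rounded up to whole 512-wide workgroups
    static uint32_t pool2d_workgroup_count(uint32_t N, uint32_t OC, uint32_t OH, uint32_t OW) {
        const uint32_t pelements = N * OC * OH * OW;  // matches elements[0] above
        const uint32_t wg_size   = 512;               // local size from {512, 1, 1}
        return (pelements + wg_size - 1) / wg_size;   // ceil division so the tail is covered
    }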
@@ -4914,6 +4941,34 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
     }, dryrun);
 }
 
+static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
+    const int32_t k1 = dst->op_params[1];
+    const int32_t k0 = dst->op_params[2];
+    const int32_t s1 = dst->op_params[3];
+    const int32_t s0 = dst->op_params[4];
+    const int32_t p1 = dst->op_params[5];
+    const int32_t p0 = dst->op_params[6];
+
+    const uint32_t IH = src0->ne[1];
+    const uint32_t IW = src0->ne[0];
+
+    const uint32_t N = dst->ne[3];
+
+    const uint32_t OC = dst->ne[2];
+    const uint32_t OH = dst->ne[1];
+    const uint32_t OW = dst->ne[0];
+
+    const uint32_t parallel_elements = N * OC * OH * OW;
+
+    ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
+        IW, IH, OW, OH, OC,
+        parallel_elements,
+        op,
+        k0, k1, s0, s1, p0, p1,
+    }, dryrun);
+}
+
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const float * op_params = (const float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
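The op_params decoded in ggml_vk_pool_2d are the ones ggml_pool_2d packs as { op, k0, k1, s0, s1, p0, p1 }, the same order the CPU reference path further down reads them in. Note that the host code above assigns op_params[1] to k1 and op_params[2] to k0 (likewise for s and p); presumably that mirrors the axis convention of the pool2d shader, and it makes no difference whenever k0 == k1, s0 == s1 and p0 == p1. A minimal sketch of the graph-side call that produces these parameters (assumes the public ggml API and a valid context; "x" is a hypothetical F32 input of shape (W, H, C, N)):

    // not part of the commit; average-pool x with a 2x2 kernel, stride 2, no padding
    ggml_tensor * pooled = ggml_pool_2d(ggml_ctx, x,
        GGML_OP_POOL_AVG,    // -> op_params[0], read back as `op` above
        /*k0=*/2, /*k1=*/2,  // kernel  -> op_params[1], op_params[2]
        /*s0=*/2, /*s1=*/2,  // stride  -> op_params[3], op_params[4]
        /*p0=*/0, /*p1=*/0); // padding -> op_params[5], op_params[6]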
@@ -5792,6 +5847,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_POOL_2D:
     case GGML_OP_LEAKY_RELU:
         break;
     default:
@@ -5927,6 +5983,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_TIMESTEP_EMBEDDING:
         ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_POOL_2D:
+        ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_LEAKY_RELU:
         ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -6018,6 +6078,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_POOL_2D:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
         buf = tensor->buffer;
@@ -6821,6 +6882,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SUM_ROWS:
         case GGML_OP_IM2COL:
         case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_POOL_2D:
         case GGML_OP_LEAKY_RELU:
             return true;
         default:
@@ -7334,6 +7396,16 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t dim = tensor->op_params[0];
         const int32_t max_period = tensor->op_params[1];
         tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
+    } else if (tensor->op == GGML_OP_POOL_2D) {
+        enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
+        const int32_t k0 = tensor->op_params[1];
+        const int32_t k1 = tensor->op_params[2];
+        const int32_t s0 = tensor->op_params[3];
+        const int32_t s1 = tensor->op_params[4];
+        const int32_t p0 = tensor->op_params[5];
+        const int32_t p1 = tensor->op_params[6];
+
+        tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
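For the comparison above to be meaningful, the Vulkan dst and the ggml_pool_2d clone must share the same shape, which should follow the usual floor-based pooling size rule (a sketch under that assumption; the helper name is hypothetical):

    #include <cstdint>
    // assumed output-size rule: floor((in + 2*p - k) / s) + 1
    static int64_t pool_out_size(int64_t in, int32_t k, int32_t s, int32_t p) {
        return (in + 2 * p - k) / s + 1;
    }
    // expected: dst->ne[0] == pool_out_size(src0->ne[0], k0, s0, p0)   (OW)
    //           dst->ne[1] == pool_out_size(src0->ne[1], k1, s1, p1)   (OH)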