@@ -6058,6 +6058,163 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
+// ggml_compute_forward_conv_2d
+
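+// direct 2D convolution, computed without an im2col staging buffer and
+// parallelized over output rows.
+//
+// graph-level usage sketch, assuming a builder along the lines of
+// ggml_conv_2d_direct that stores s0,s1,p0,p1,d0,d1 as op params
+// (name hypothetical here; check ggml.h for the exact entry point):
+//
+//   ggml_tensor * out = ggml_conv_2d_direct(ctx, kernel, src, s0, s1, p0, p1, d0, d1);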
+static void ggml_compute_forward_conv_2d_f32(
+        const ggml_compute_params * params,
+        const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
+        const ggml_tensor *         src,     // [W, H, C, N]
+        ggml_tensor *               dst) {   // [OW, OH, OC, N]
+
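+    // op params: stride (s0, s1), padding (p0, p1), dilation (d0, d1)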
+    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
+    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
+    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
+    const int32_t d1 = ggml_get_op_params_i32(dst, 5);
+
+    const int64_t OW = dst->ne[0];
+    const int64_t OH = dst->ne[1];
+    const int64_t OC = dst->ne[2];
+    const int64_t N  = dst->ne[3];
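+    // OW/OH follow the usual conv size formula, e.g. OW = (IW + 2*p0 - d0*(KW - 1) - 1) / s0 + 1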
+
+    const int64_t IW = src->ne[0];
+    const int64_t IH = src->ne[1];
+    const int64_t IC = src->ne[2];
+
+    const int64_t KW = kernel->ne[0];
+    const int64_t KH = kernel->ne[1];
+
+    const float * kernel_data = (const float *) kernel->data;
+    const float * src_data    = (const float *) src->data;
+    float *       dst_data    = (float *) dst->data;
+
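+    // split the OH*N output rows evenly across threads; this thread (ith of nth)
+    // processes the half-open range [row_start, row_end)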
+    const int64_t rows_total      = OH * N;
+    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+    const int64_t row_start       = params->ith * rows_per_thread;
+    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);
+
+    for (int64_t row = row_start; row < row_end; ++row) {
+        const int64_t oh = row % OH;
+        const int64_t n  = row / OH;
+        const float * src_batch = src_data + n * IW * IH * IC;
+
+        for (int64_t ow = 0; ow < OW; ++ow) {
+            for (int64_t oc = 0; oc < OC; ++oc) {
+                float sum = 0.0f;
+                const float * kernel_channel = kernel_data + oc * KW * KH * IC;
+
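+                // taps that fall outside the input are zero padding and are skipped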
+                for (int64_t kh = 0; kh < KH; ++kh) {
+                    const int64_t ih = oh * s1 - p1 + kh * d1;
+                    if (ih < 0 || ih >= IH) continue;
+
+                    for (int64_t kw = 0; kw < KW; ++kw) {
+                        const int64_t iw = ow * s0 - p0 + kw * d0;
+                        if (iw < 0 || iw >= IW) continue;
+
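+                        // innermost reduction over input channels; each ic step
+                        // strides by one plane (KW*KH in the kernel, IW*IH in src)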
+                        #pragma omp simd
+                        for (int64_t ic = 0; ic < IC; ++ic) {
+                            const float * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
+                            const float * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
+                            sum += (*kernel_ptr) * (*src_ptr);
+                        }
+                    }
+                }
+
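+                // write dst[ow, oh, oc, n] in the contiguous [OW, OH, OC, N] layout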
+                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = sum;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_2d_f16(
+        const ggml_compute_params * params,
+        const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
+        const ggml_tensor *         src,     // [W, H, C, N]
+        ggml_tensor *               dst) {   // [OW, OH, OC, N]
+
+    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
+    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
+    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
+    const int32_t d1 = ggml_get_op_params_i32(dst, 5);
+
+    const int64_t OW = dst->ne[0];
+    const int64_t OH = dst->ne[1];
+    const int64_t OC = dst->ne[2];
+    const int64_t N  = dst->ne[3];
+
+    const int64_t IW = src->ne[0];
+    const int64_t IH = src->ne[1];
+    const int64_t IC = src->ne[2];
+
+    const int64_t KW = kernel->ne[0];
+    const int64_t KH = kernel->ne[1];
+
+    const ggml_fp16_t * kernel_data = (const ggml_fp16_t *) kernel->data;
+    const ggml_fp16_t * src_data    = (const ggml_fp16_t *) src->data;
+    ggml_fp16_t *       dst_data    = (ggml_fp16_t *) dst->data;
+
+    const int64_t rows_total      = OH * N;
+    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+    const int64_t row_start       = params->ith * rows_per_thread;
+    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);
+
+    for (int64_t row = row_start; row < row_end; ++row) {
+        const int64_t oh = row % OH;
+        const int64_t n  = row / OH;
+        const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;
+
+        for (int64_t ow = 0; ow < OW; ++ow) {
+            for (int64_t oc = 0; oc < OC; ++oc) {
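+                // accumulate in F32 and convert once on store to limit rounding error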
+                float sum = 0.0f;
+                const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
+
+                for (int64_t kh = 0; kh < KH; ++kh) {
+                    const int64_t ih = oh * s1 - p1 + kh * d1;
+                    if (ih < 0 || ih >= IH) continue;
+
+                    for (int64_t kw = 0; kw < KW; ++kw) {
+                        const int64_t iw = ow * s0 - p0 + kw * d0;
+                        if (iw < 0 || iw >= IW) continue;
+
+                        for (int64_t ic = 0; ic < IC; ++ic) {
+                            const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
+                            const ggml_fp16_t * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
+                            sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
+                        }
+                    }
+                }
+
+                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_2d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
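+    // src0 is the kernel, src1 the input; both paths cast src1 to the kernel's
+    // type, so the input is expected to match it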
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(