@@ -6116,6 +6116,163 @@ void ggml_compute_forward_im2col_back_f32(
    }
}

// ggml_compute_forward_conv_2d
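//
// naive direct 2D convolution (no im2col): for each output element,
//   dst[n, oc, oh, ow] = sum over (ic, kh, kw) of
//                        kernel[oc, ic, kh, kw] * src[n, ic, oh*s1 - p1 + kh*d1, ow*s0 - p0 + kw*d0]
// input positions outside [0, IW) x [0, IH) are skipped, i.e. implicit zero padding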

static void ggml_compute_forward_conv_2d_f32(
        const ggml_compute_params * params,
        const ggml_tensor * kernel,  // [KW, KH, IC, OC]
        const ggml_tensor * src,     // [W, H, C, N]
        ggml_tensor * dst) {         // [OW, OH, OC, N]

    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
    const int32_t d1 = ggml_get_op_params_i32(dst, 5);

    const int64_t OW = dst->ne[0];
    const int64_t OH = dst->ne[1];
    const int64_t OC = dst->ne[2];
    const int64_t N  = dst->ne[3];

    const int64_t IW = src->ne[0];
    const int64_t IH = src->ne[1];
    const int64_t IC = src->ne[2];

    const int64_t KW = kernel->ne[0];
    const int64_t KH = kernel->ne[1];

    const float * kernel_data = (const float *)kernel->data;
    const float * src_data    = (const float *)src->data;
    float       * dst_data    = (float *)dst->data;

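    // split the OH*N output rows, one per (batch, output-row) pair, evenly across threads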
    const int64_t rows_total      = OH * N;
    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
    const int64_t row_start       = params->ith * rows_per_thread;
    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);

    for (int64_t row = row_start; row < row_end; ++row) {
        const int64_t oh = row % OH;
        const int64_t n  = row / OH;
        const float * src_batch = src_data + n * IW * IH * IC;

        for (int64_t ow = 0; ow < OW; ++ow) {
            for (int64_t oc = 0; oc < OC; ++oc) {
                float sum = 0.0f;
                const float * kernel_channel = kernel_data + oc * KW * KH * IC;

                for (int64_t kh = 0; kh < KH; ++kh) {
                    const int64_t ih = oh * s1 - p1 + kh * d1;
                    if (ih < 0 || ih >= IH) continue;

                    for (int64_t kw = 0; kw < KW; ++kw) {
                        const int64_t iw = ow * s0 - p0 + kw * d0;
                        if (iw < 0 || iw >= IW) continue;

                        #pragma omp simd
                        for (int64_t ic = 0; ic < IC; ++ic) {
                            const float * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
                            const float * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
                            sum += (*kernel_ptr) * (*src_ptr);
                        }
                    }
                }

                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = sum;
            }
        }
    }
}

static void ggml_compute_forward_conv_2d_f16(
        const ggml_compute_params * params,
        const ggml_tensor * kernel,  // [KW, KH, IC, OC]
        const ggml_tensor * src,     // [W, H, C, N]
        ggml_tensor * dst) {         // [OW, OH, OC, N]

    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
    const int32_t d1 = ggml_get_op_params_i32(dst, 5);

    const int64_t OW = dst->ne[0];
    const int64_t OH = dst->ne[1];
    const int64_t OC = dst->ne[2];
    const int64_t N  = dst->ne[3];

    const int64_t IW = src->ne[0];
    const int64_t IH = src->ne[1];
    const int64_t IC = src->ne[2];

    const int64_t KW = kernel->ne[0];
    const int64_t KH = kernel->ne[1];

    const ggml_fp16_t * kernel_data = (const ggml_fp16_t *)kernel->data;
    const ggml_fp16_t * src_data    = (const ggml_fp16_t *)src->data;
    ggml_fp16_t       * dst_data    = (ggml_fp16_t *)dst->data;

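    // same traversal as the F32 path; products are accumulated in F32 and
    // converted back to F16 only at the final store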
    const int64_t rows_total      = OH * N;
    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
    const int64_t row_start       = params->ith * rows_per_thread;
    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);

    for (int64_t row = row_start; row < row_end; ++row) {
        const int64_t oh = row % OH;
        const int64_t n  = row / OH;
        const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;

        for (int64_t ow = 0; ow < OW; ++ow) {
            for (int64_t oc = 0; oc < OC; ++oc) {
                float sum = 0.0f;
                const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;

                for (int64_t kh = 0; kh < KH; ++kh) {
                    const int64_t ih = oh * s1 - p1 + kh * d1;
                    if (ih < 0 || ih >= IH) continue;

                    for (int64_t kw = 0; kw < KW; ++kw) {
                        const int64_t iw = ow * s0 - p0 + kw * d0;
                        if (iw < 0 || iw >= IW) continue;

                        for (int64_t ic = 0; ic < IC; ++ic) {
                            const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
                            const ggml_fp16_t * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
                            sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
                        }
                    }
                }

                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
            }
        }
    }
}

void ggml_compute_forward_conv_2d(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

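    // src0 is the kernel [KW, KH, IC, OC], src1 the input image [W, H, C, N]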
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    switch (src0->type) {
        case GGML_TYPE_F16:
            {
                ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
            } break;
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
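
// not part of this patch: a hypothetical helper sketching the geometry the
// kernels above assume. The out-of-bounds checks implement implicit zero
// padding, and the conventional output-size relation for a strided, padded,
// dilated convolution is OW = (IW + 2*p0 - d0*(KW - 1) - 1) / s0 + 1 (and
// likewise for OH). The graph builder is expected to size dst this way and to
// pack the six i32 op params (s0, s1, p0, p1, d0, d1) into dst->op_params
// before this op runs; the helper name conv_2d_out_size is an assumption.
static int64_t conv_2d_out_size(int64_t in, int64_t k, int32_t s, int32_t p, int32_t d) {
    return (in + 2*p - d*(k - 1) - 1) / s + 1;  // e.g. in = 224, k = 3, s = 1, p = 1, d = 1 -> 224
}
// expected dst extents: { conv_2d_out_size(IW, KW, s0, p0, d0),
//                         conv_2d_out_size(IH, KH, s1, p1, d1), OC, N }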

// ggml_compute_forward_conv_transpose_2d

void ggml_compute_forward_conv_transpose_2d(