@@ -6545,6 +6545,163 @@ void ggml_compute_forward_im2col_back_f32(
65456545 }
65466546}
65476547
6548+ // ggml_compute_forward_conv_2d
6549+
6550+ static void ggml_compute_forward_conv_2d_f32 (
6551+ const ggml_compute_params * params,
6552+ const ggml_tensor * kernel, // [KW, KH, IC, OC]
6553+ const ggml_tensor * src, // [W, H, C, N]
6554+ ggml_tensor * dst) { // [OW, OH, OC, N]
6555+
6556+ const int32_t s0 = ggml_get_op_params_i32 (dst, 0 );
6557+ const int32_t s1 = ggml_get_op_params_i32 (dst, 1 );
6558+ const int32_t p0 = ggml_get_op_params_i32 (dst, 2 );
6559+ const int32_t p1 = ggml_get_op_params_i32 (dst, 3 );
6560+ const int32_t d0 = ggml_get_op_params_i32 (dst, 4 );
6561+ const int32_t d1 = ggml_get_op_params_i32 (dst, 5 );
6562+
6563+ const int64_t OW = dst->ne [0 ];
6564+ const int64_t OH = dst->ne [1 ];
6565+ const int64_t OC = dst->ne [2 ];
6566+ const int64_t N = dst->ne [3 ];
6567+
6568+ const int64_t IW = src->ne [0 ];
6569+ const int64_t IH = src->ne [1 ];
6570+ const int64_t IC = src->ne [2 ];
6571+
6572+ const int64_t KW = kernel->ne [0 ];
6573+ const int64_t KH = kernel->ne [1 ];
6574+
6575+ const float * kernel_data = (const float *)kernel->data ;
6576+ const float * src_data = (const float *)src->data ;
6577+ float * dst_data = (float *)dst->data ;
6578+
6579+ const int64_t rows_total = OH * N;
6580+ const int64_t rows_per_thread = (rows_total + params->nth - 1 ) / params->nth ;
6581+ const int64_t row_start = params->ith * rows_per_thread;
6582+ const int64_t row_end = MIN (row_start + rows_per_thread, rows_total);
6583+
6584+ for (int64_t row = row_start; row < row_end; ++row) {
6585+ const int64_t oh = row % OH;
6586+ const int64_t n = row / OH;
6587+ const float * src_batch = src_data + n * IW * IH * IC;
6588+
6589+ for (int64_t ow = 0 ; ow < OW; ++ow) {
6590+ for (int64_t oc = 0 ; oc < OC; ++oc) {
6591+ float sum = 0 .0f ;
6592+ const float * kernel_channel = kernel_data + oc * KW * KH * IC;
6593+
6594+ for (int64_t kh = 0 ; kh < KH; ++kh) {
6595+ const int64_t ih = oh * s1 - p1 + kh * d1;
6596+ if (ih < 0 || ih >= IH) continue ;
6597+
6598+ for (int64_t kw = 0 ; kw < KW; ++kw) {
6599+ const int64_t iw = ow * s0 - p0 + kw * d0;
6600+ if (iw < 0 || iw >= IW) continue ;
6601+
6602+ #pragma omp simd
6603+ for (int64_t ic = 0 ; ic < IC; ++ic) {
6604+ const float * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
6605+ const float * src_ptr = src_batch + (ih * IW + iw) + ic * IW * IH;
6606+ sum += (*kernel_ptr) * (*src_ptr);
6607+ }
6608+ }
6609+ }
6610+
6611+ dst_data[((n * OC + oc) * OH + oh) * OW + ow] = sum;
6612+ }
6613+ }
6614+ }
6615+ }
6616+
6617+ static void ggml_compute_forward_conv_2d_f16 (
6618+ const ggml_compute_params * params,
6619+ const ggml_tensor * kernel, // [KW, KH, IC, OC]
6620+ const ggml_tensor * src, // [W, H, C, N]
6621+ ggml_tensor * dst) { // [OW, OH, OC, N]
6622+
6623+ const int32_t s0 = ggml_get_op_params_i32 (dst, 0 );
6624+ const int32_t s1 = ggml_get_op_params_i32 (dst, 1 );
6625+ const int32_t p0 = ggml_get_op_params_i32 (dst, 2 );
6626+ const int32_t p1 = ggml_get_op_params_i32 (dst, 3 );
6627+ const int32_t d0 = ggml_get_op_params_i32 (dst, 4 );
6628+ const int32_t d1 = ggml_get_op_params_i32 (dst, 5 );
6629+
6630+ const int64_t OW = dst->ne [0 ];
6631+ const int64_t OH = dst->ne [1 ];
6632+ const int64_t OC = dst->ne [2 ];
6633+ const int64_t N = dst->ne [3 ];
6634+
6635+ const int64_t IW = src->ne [0 ];
6636+ const int64_t IH = src->ne [1 ];
6637+ const int64_t IC = src->ne [2 ];
6638+
6639+ const int64_t KW = kernel->ne [0 ];
6640+ const int64_t KH = kernel->ne [1 ];
6641+
6642+ const ggml_fp16_t * kernel_data = (const ggml_fp16_t *)kernel->data ;
6643+ const ggml_fp16_t * src_data = (const ggml_fp16_t *)src->data ;
6644+ ggml_fp16_t * dst_data = (ggml_fp16_t *)dst->data ;
6645+
6646+ const int64_t rows_total = OH * N;
6647+ const int64_t rows_per_thread = (rows_total + params->nth - 1 ) / params->nth ;
6648+ const int64_t row_start = params->ith * rows_per_thread;
6649+ const int64_t row_end = MIN (row_start + rows_per_thread, rows_total);
6650+
6651+ for (int64_t row = row_start; row < row_end; ++row) {
6652+ const int64_t oh = row % OH;
6653+ const int64_t n = row / OH;
6654+ const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;
6655+
6656+ for (int64_t ow = 0 ; ow < OW; ++ow) {
6657+ for (int64_t oc = 0 ; oc < OC; ++oc) {
6658+ float sum = 0 .0f ;
6659+ const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
6660+ for (int64_t kh = 0 ; kh < KH; ++kh) {
6661+ const int64_t ih = oh * s1 - p1 + kh * d1;
6662+ if (ih < 0 || ih >= IH) continue ;
6663+
6664+ for (int64_t kw = 0 ; kw < KW; ++kw) {
6665+ const int64_t iw = ow * s0 - p0 + kw * d0;
6666+ if (iw < 0 || iw >= IW) continue ;
6667+
6668+ for (int64_t ic = 0 ; ic < IC; ++ic) {
6669+ const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
6670+ const ggml_fp16_t * src_ptr = src_batch + (ih * IW + iw) + ic * IW * IH;
6671+ sum += GGML_FP16_TO_FP32 (*kernel_ptr) * GGML_FP16_TO_FP32 (*src_ptr);
6672+ }
6673+ }
6674+ }
6675+
6676+ dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16 (sum);
6677+ }
6678+ }
6679+ }
6680+ }
6681+
6682+ void ggml_compute_forward_conv_2d (
6683+ const ggml_compute_params * params,
6684+ ggml_tensor * dst) {
6685+
6686+ const ggml_tensor * src0 = dst->src [0 ];
6687+ const ggml_tensor * src1 = dst->src [1 ];
6688+
6689+ switch (src0->type ) {
6690+ case GGML_TYPE_F16:
6691+ {
6692+ ggml_compute_forward_conv_2d_f16 (params, src0, src1, dst);
6693+ } break ;
6694+ case GGML_TYPE_F32:
6695+ {
6696+ ggml_compute_forward_conv_2d_f32 (params, src0, src1, dst);
6697+ } break ;
6698+ default :
6699+ {
6700+ GGML_ABORT (" fatal error" );
6701+ }
6702+ }
6703+ }
6704+
65486705// ggml_compute_forward_conv_transpose_2d
65496706
65506707void ggml_compute_forward_conv_transpose_2d (
0 commit comments