@@ -6105,18 +6105,21 @@ static void ggml_call_mul_mat(
 
 // ggml_compute_forward_conv_2d
 
-static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
-                                             ggml_tensor * dst) {
-
-    const ggml_tensor * src    = dst->src[1];  // [W H C_in N]
-    const ggml_tensor * kernel = dst->src[0];  // [W H C_in C_out]
+static void ggml_compute_forward_conv_2d_f32(
+    const ggml_compute_params * params,
+    const ggml_tensor *         kernel,  // [KW, KH, IC, OC] - fp32
+    const ggml_tensor *         src,     // [W, H, C, N]
+    ggml_tensor *               dst) {   // [OW, OH, OC, N]
 
     GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel->type == GGML_TYPE_F32);
 
-    const int32_t stride_x = dst->op_params[0];
-    const int32_t stride_y = dst->op_params[1];
-    const int32_t pad_x    = dst->op_params[2];
-    const int32_t pad_y    = dst->op_params[3];
+    const int32_t stride_x   = dst->op_params[0];
+    const int32_t stride_y   = dst->op_params[1];
+    const int32_t pad_x      = dst->op_params[2];
+    const int32_t pad_y      = dst->op_params[3];
+    const int32_t dilation_x = dst->op_params[4];
+    const int32_t dilation_y = dst->op_params[5];
 
     const int64_t c_in  = src->ne[2];
     const int64_t c_out = kernel->ne[3];
@@ -6129,193 +6132,104 @@ static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
     const int64_t dst_w = dst->ne[0];
     const int64_t dst_h = dst->ne[1];
 
-
-    float * src_data = (float *) src->data;
-    float * knl_data = (float *) kernel->data;
-    float * dst_data = (float *) dst->data;
-
+    float * src_data = (float *) src->data;
+    float * knl_data = (float *) kernel->data;
+    float * dst_data = (float *) dst->data;
 
     const int64_t knl_n       = knl_w * knl_h * c_in;
     const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
-
-
-
-    const int64_t space_per_patch = knl_n * sizeof(float) + patch_total * c_out * sizeof(float);
 
-    const int64_t batch_size = params->wsize / space_per_patch;
+    const int64_t space_per_patch   = knl_n * sizeof(float) + c_out * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
     const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
-    const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
-
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
 
     GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
 
-    float * tmp = (float *) params->wdata;  // per-thread scratch
+    float * tmp = (float *) params->wdata;
 
     for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
 
         const int64_t patch_start_batch = batch_i * patches_per_batch;
         const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
                                                    patch_total);
-        const int64_t patch_n = patch_end_batch - patch_start_batch;
+        const int64_t patch_n           = patch_end_batch - patch_start_batch;
 
-        const int64_t patch_per_thread =
-            (patch_n + params->nth - 1) / params->nth;
-        const int64_t patch_start = patch_start_batch +
-            params->ith * patch_per_thread;
-        const int64_t patch_end = std::min(patch_start + patch_per_thread,
-                                           patch_end_batch);
+        const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
 
         // im2col for a patch
         for (int64_t p = patch_start; p < patch_end; ++p) {
-            const int64_t b  = p / (dst_w * dst_h);
-            const int64_t dy = (p / dst_w) % dst_h;
-            const int64_t dx = p % dst_w;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t src_x   = (p / dst_w) % dst_h;
+            const int64_t src_y   = p % dst_w;
 
-            const float * src_base = (const float *)((char *)src_data + b * src->nb[3]);
-            float * out_row = tmp + (p % patches_per_batch) * knl_n;
+            float * src_base = (float *)((char *)src_data + batch_n * src->nb[3]);
+            float * dst_row  = tmp + (p % patches_per_batch) * knl_n;
 
-            // Extract patch in IC,KH,KW order (same as im2col)
             for (int64_t ic = 0; ic < c_in; ++ic) {
                 for (int64_t ky = 0; ky < knl_h; ++ky) {
                     for (int64_t kx = 0; kx < knl_w; ++kx) {
-                        const int64_t sy = dy * stride_y + ky - pad_y;
-                        const int64_t sx = dx * stride_x + kx - pad_x;
-
+                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
                         int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
-
+
                         if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
-                            out_row[dst_idx] = 0.0f;
+                            dst_row[dst_idx] = 0.0f;
                         } else {
-                            float * src_ptr = (float *)((char *)src_base +
-                                sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
-                            out_row[dst_idx] = *src_ptr;
+                            float * src_ptr = (float *)((char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+                            dst_row[dst_idx] = *src_ptr;
                         }
                     }
                 }
             }
         }   // patches handled by this thread
 
-        ggml_barrier(params->threadpool);  // wait for all threads
+        ggml_barrier(params->threadpool);
 
-        // GEMM output is patch_n * cout
         float * gemm_output = tmp + patches_per_batch * knl_n;
-
+
         // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
         ggml_call_mul_mat(params, patch_n, c_out, knl_n,
                           tmp, knl_data, gemm_output);
-
-        // Barrier to ensure GEMM completes before permutation
+
         ggml_barrier(params->threadpool);
-
-        // Distribute permutation work across threads
+
+
+        // permute back [OC, N, OH, OW] to [N, OC, OH, OW]
         const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
         const int64_t permute_start = params->ith * permute_per_thread;
         const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
-
-        // Each thread handles part of the permutation from [patch_n, c_out] to WHCN layout
+
         for (int64_t i = permute_start; i < permute_end; ++i) {
-            const int64_t p  = patch_start_batch + i;
-            const int64_t b  = p / (dst_w * dst_h);   // batch index
-            const int64_t dy = (p / dst_w) % dst_h;   // height index
-            const int64_t dx = p % dst_w;             // width index
-
-            // Copy all channels for this spatial position
+            const int64_t p       = patch_start_batch + i;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t dst_y   = (p / dst_w) % dst_h;
+            const int64_t dst_x   = p % dst_w;
+
             for (int64_t oc = 0; oc < c_out; ++oc) {
                 const float value = gemm_output[i * c_out + oc];
                 // Write to WHCN layout: dst[w, h, c, n]
-                float * dst_ptr = (float *)((char *)dst_data +
-                    dx * dst->nb[0] + dy * dst->nb[1] + oc * dst->nb[2] + b * dst->nb[3]);
+                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
                 *dst_ptr = value;
             }
         }
     }
 }
 
-static void ggml_compute_forward_conv_2d_f16(
-    const ggml_compute_params * params,
-    const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
-    const ggml_tensor *         src,     // [W, H, C, N]
-    ggml_tensor *               dst) {   // [OW, OH, OC, N]
-
-    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
-    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
-    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
-    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
-    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
-    const int32_t d1 = ggml_get_op_params_i32(dst, 5);
-
-    const int64_t OW = dst->ne[0];
-    const int64_t OH = dst->ne[1];
-    const int64_t OC = dst->ne[2];
-    const int64_t N  = dst->ne[3];
-
-    const int64_t IW = src->ne[0];
-    const int64_t IH = src->ne[1];
-    const int64_t IC = src->ne[2];
-
-    const int64_t KW = kernel->ne[0];
-    const int64_t KH = kernel->ne[1];
-
-    const ggml_fp16_t * kernel_data = (const ggml_fp16_t *) kernel->data;
-    const ggml_fp16_t * src_data    = (const ggml_fp16_t *) src->data;
-    ggml_fp16_t *       dst_data    = (ggml_fp16_t *) dst->data;
-
-    const int64_t rows_total      = OH * N;
-    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
-    const int64_t row_start       = params->ith * rows_per_thread;
-    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);
-
-    for (int64_t row = row_start; row < row_end; ++row) {
-        const int64_t oh = row % OH;
-        const int64_t n  = row / OH;
-        const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;
-
-        for (int64_t ow = 0; ow < OW; ++ow) {
-            for (int64_t oc = 0; oc < OC; ++oc) {
-                float sum = 0.0f;
-                const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
-                for (int64_t kh = 0; kh < KH; ++kh) {
-                    const int64_t ih = oh * s1 - p1 + kh * d1;
-                    if (ih < 0 || ih >= IH) continue;
-
-                    for (int64_t kw = 0; kw < KW; ++kw) {
-                        const int64_t iw = ow * s0 - p0 + kw * d0;
-                        if (iw < 0 || iw >= IW) continue;
-
-                        for (int64_t ic = 0; ic < IC; ++ic) {
-                            const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
-                            const ggml_fp16_t * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
-                            sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
-                        }
-                    }
-                }
-
-                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
-            }
-        }
-    }
-}
-
 void ggml_compute_forward_conv_2d(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_2d_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
+    if (src0->type == GGML_TYPE_F16) {
+        GGML_ASSERT(false && "F16 not supported yet");
+    } else {
+        ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
     }
 }
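For context, a minimal standalone sketch (not part of the patch) of the arithmetic the new f32 path relies on: the output extent of a strided, padded, dilated convolution, and the per-patch scratch accounting behind space_per_patch above. The helper name conv2d_out_size and the example shapes are illustrative assumptions, not ggml API.

// sketch only: output size and scratch accounting for the im2col + GEMM path above
#include <stdint.h>
#include <stdio.h>

static int64_t conv2d_out_size(int64_t in, int64_t k, int64_t s, int64_t p, int64_t d) {
    // effective kernel extent with dilation is d*(k - 1) + 1
    return (in + 2*p - d*(k - 1) - 1) / s + 1;
}

int main() {
    // illustrative shapes only
    const int64_t src_w = 224, src_h = 224, c_in = 3, c_out = 64;
    const int64_t knl_w = 3,   knl_h = 3;
    const int64_t stride = 1,  pad = 1, dilation = 1;

    const int64_t dst_w = conv2d_out_size(src_w, knl_w, stride, pad, dilation);  // 224
    const int64_t dst_h = conv2d_out_size(src_h, knl_h, stride, pad, dilation);  // 224

    // one im2col row per output pixel (knl_n floats) plus c_out floats of GEMM output,
    // mirroring space_per_patch = knl_n*sizeof(float) + c_out*sizeof(float) in the patch
    const int64_t knl_n           = knl_w * knl_h * c_in;
    const int64_t space_per_patch = (knl_n + c_out) * (int64_t) sizeof(float);

    printf("dst: %lld x %lld, scratch per patch: %lld bytes\n",
           (long long) dst_w, (long long) dst_h, (long long) space_per_patch);
    return 0;
}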