 #include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "binary-ops.h"
+#include "ggml.h"
 #include "unary-ops.h"
 #include "vec.h"
 
@@ -6545,6 +6546,186 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
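+// helper: wrap the raw buffers a, b, c in temporary ggml_tensor structs and reuse
+// ggml_compute_forward_mul_mat for the GEMM; in ggml's convention this computes
+// c[m,n] = a[m,k] * b[n,k]^T (rows of a dotted with rows of b)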
+static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                              void * a, void * b, float * c) {
+    const ggml_type_traits * traits = ggml_get_type_traits(type);
+    struct ggml_tensor src1 = {};
+    src1.type = type;
+    src1.ne[0] = k;
+    src1.ne[1] = m;
+    src1.ne[2] = 1;
+    src1.ne[3] = 1;
+    src1.nb[0] = traits->type_size;
+    src1.nb[1] = k * traits->type_size;
+    src1.nb[2] = src1.nb[1];
+    src1.nb[3] = src1.nb[2];
+    src1.data = a;
+
+    struct ggml_tensor src0 = {};
+    src0.type = type;
+    src0.ne[0] = k;
+    src0.ne[1] = n;
+    src0.ne[2] = 1;
+    src0.ne[3] = 1;
+    src0.nb[0] = traits->type_size;
+    src0.nb[1] = k * traits->type_size;
+    src0.nb[2] = src0.nb[1];
+    src0.nb[3] = src0.nb[2];
+    src0.data = b;
+
+    struct ggml_tensor dst = {};
+    dst.ne[0] = n;
+    dst.ne[1] = m;
+    dst.ne[2] = 1;
+    dst.ne[3] = 1;
+    dst.nb[0] = sizeof(float);
+    dst.nb[1] = n * sizeof(float);
+    dst.nb[2] = dst.nb[1];
+    dst.nb[3] = dst.nb[2];
+    dst.data = c;
+    dst.src[0] = &src0;
+    dst.src[1] = &src1;
+
+    ggml_compute_forward_mul_mat(params, &dst);
+}
+
+// ggml_compute_forward_conv_2d
+
+static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
+                                              const ggml_tensor *         src,     // [W, H, C, N]
+                                              ggml_tensor *               dst,     // [OW, OH, OC, N]
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t stride_x   = dst->op_params[0];
+    const int32_t stride_y   = dst->op_params[1];
+    const int32_t pad_x      = dst->op_params[2];
+    const int32_t pad_y      = dst->op_params[3];
+    const int32_t dilation_x = dst->op_params[4];
+    const int32_t dilation_y = dst->op_params[5];
+
+    const int64_t c_in  = src->ne[2];
+    const int64_t c_out = kernel->ne[3];
+    GGML_ASSERT(c_in == kernel->ne[2]);
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+
+    const float * src_data = (float *) src->data;
+    void        * knl_data = kernel->data;
+    float       * dst_data = (float *) dst->data;
+
+    const int64_t knl_n       = knl_w * knl_h * c_in;
+    const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
+
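+    // the shared work buffer (params->wdata, params->wsize bytes) is carved up per batch of patches:
+    // first patches_per_batch im2col rows of knl_n elements in the kernel's type, then the GEMM output
+    // of patches_per_batch * c_out floats; patches_per_batch is rounded down to a multiple of 8 when
+    // more than 8 patches fit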
+    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
+                                                   patch_total);
+        const int64_t patch_n           = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        // im2col for a patch
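+        // patch index p encodes (image n, output row oh, output column ow) as p = n*OH*OW + oh*OW + ow;
+        // each thread gathers the receptive field of its own patches into rows of the im2col matrix,
+        // converting to the kernel's type so the GEMM below runs on a single type (F32 or F16)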
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t src_x   = (p / dst_w) % dst_h;
+            const int64_t src_y   = p % dst_w;
+
+            const float * src_base = (const float *) ((const char *) src_data + batch_n * src->nb[3]);
+            char *        dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
+
+            for (int64_t ic = 0; ic < c_in; ++ic) {
+                for (int64_t ky = 0; ky < knl_h; ++ky) {
+                    for (int64_t kx = 0; kx < knl_w; ++kx) {
+                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
+                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
+
+                        float src_val;
+                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                            src_val = 0.0f;
+                        } else {
+                            const float * src_ptr = (const float *) ((const char *) src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+                            src_val = *src_ptr;
+                        }
+
+                        char * element_ptr = dst_row + dst_idx * traits->type_size;
+                        if (kernel_type == GGML_TYPE_F32) {
+                            *(float *) element_ptr = src_val;
+                        } else if (kernel_type == GGML_TYPE_F16) {
+                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                        }
+                    }
+                }
+            }
+        }   // patches handled by this thread
+
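+        // every thread must finish writing its im2col rows before the shared GEMM reads the full patch matrix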
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
+
+        GGML_ASSERT((char *) (gemm_output + patch_n * c_out) <= (char *) tmp + params->wsize);
+
+        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
+        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
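+        // each thread makes the same call; ggml_compute_forward_mul_mat divides the output rows across the thread pool internally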
+
+        ggml_barrier(params->threadpool);
+
+        // permute back [OC, N, OH, OW] to [N, OC, OH, OW]
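+        // gemm_output is patch-major (one row of c_out floats per patch); scatter each value to its (ow, oh, oc, n) slot in dst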
+        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t permute_start      = params->ith * permute_per_thread;
+        const int64_t permute_end        = std::min(permute_start + permute_per_thread, patch_n);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p       = patch_start_batch + i;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t dst_y   = (p / dst_w) % dst_h;
+            const int64_t dst_x   = p % dst_w;
+
+            for (int64_t oc = 0; oc < c_out; ++oc) {
+                const float value = gemm_output[i * c_out + oc];
+                float * dst_ptr = (float *) ((char *) dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
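+// op entry point: src0 holds the kernel [KW, KH, IC, OC] and src1 the input [W, H, C, N];
+// the kernel's type (F32 or F16) selects the type used for the im2col buffer and the GEMM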
+void ggml_compute_forward_conv_2d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
@@ -7095,12 +7276,13 @@ static void ggml_compute_forward_upscale_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    const float sf0 = (float)ne0/src0->ne[0];
-    const float sf1 = (float)ne1/src0->ne[1];
-    const float sf2 = (float)ne2/src0->ne[2];
-    const float sf3 = (float)ne3/src0->ne[3];
+    float sf0 = (float)ne0/src0->ne[0];
+    float sf1 = (float)ne1/src0->ne[1];
+    float sf2 = (float)ne2/src0->ne[2];
+    float sf3 = (float)ne3/src0->ne[3];
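+    // sf0/sf1 are no longer const: the align_corners path below overrides them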
 
-    const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
+    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
+    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
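+    // op_params[0] packs the scale mode in the low byte; the bits above it carry flags such as GGML_SCALE_FLAG_ALIGN_CORNERS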
 
     if (mode == GGML_SCALE_MODE_NEAREST) {
         for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -7121,8 +7303,12 @@ static void ggml_compute_forward_upscale_f32(
             }
         }
     } else if (mode == GGML_SCALE_MODE_BILINEAR) {
-        // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
-        const float pixel_offset = 0.5f;
+        float pixel_offset = 0.5f;
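+        // align_corners: a zero pixel offset together with (ne - 1)/(src_ne - 1) scale factors maps the
+        // corner pixels of src and dst onto each other, matching PyTorch interpolate with align_corners=True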
+        if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+            pixel_offset = 0.0f;
+            sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
+            sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
+        }
 
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;