| 
3 | 3 | #include "ggml-cpu.h"  | 
4 | 4 | #include "ggml-impl.h"  | 
5 | 5 | #include "binary-ops.h"  | 
 | 6 | +#include "ggml.h"  | 
6 | 7 | #include "unary-ops.h"  | 
7 | 8 | #include "vec.h"  | 
8 | 9 | 
 
  | 
@@ -6545,6 +6546,186 @@ void ggml_compute_forward_im2col_back_f32(  | 
6545 | 6546 |     }  | 
6546 | 6547 | }  | 
6547 | 6548 | 
 
  | 
 | 6549 | +static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,  | 
 | 6550 | +                              void * a, void * b, float * c) {  | 
 | 6551 | +    const ggml_type_traits * traits = ggml_get_type_traits(type);  | 
 | 6552 | +    struct ggml_tensor src1 = {};  | 
 | 6553 | +    src1.type  = type;  | 
 | 6554 | +    src1.ne[0] = k;  | 
 | 6555 | +    src1.ne[1] = m;  | 
 | 6556 | +    src1.ne[2] = 1;  | 
 | 6557 | +    src1.ne[3] = 1;  | 
 | 6558 | +    src1.nb[0] = traits->type_size;  | 
 | 6559 | +    src1.nb[1] = k * traits->type_size;  | 
 | 6560 | +    src1.nb[2] = src1.nb[1];  | 
 | 6561 | +    src1.nb[3] = src1.nb[2];  | 
 | 6562 | +    src1.data  = a;  | 
 | 6563 | + | 
 | 6564 | +    struct ggml_tensor src0 = {};  | 
 | 6565 | +    src0.type  = type;  | 
 | 6566 | +    src0.ne[0] = k;  | 
 | 6567 | +    src0.ne[1] = n;  | 
 | 6568 | +    src0.ne[2] = 1;  | 
 | 6569 | +    src0.ne[3] = 1;  | 
 | 6570 | +    src0.nb[0] = traits->type_size;  | 
 | 6571 | +    src0.nb[1] = k * traits->type_size;  | 
 | 6572 | +    src0.nb[2] = src0.nb[1];  | 
 | 6573 | +    src0.nb[3] = src0.nb[2];  | 
 | 6574 | +    src0.data  = b;  | 
 | 6575 | + | 
 | 6576 | +    struct ggml_tensor dst = {};  | 
 | 6577 | +    dst.ne[0] = n;  | 
 | 6578 | +    dst.ne[1] = m;  | 
 | 6579 | +    dst.ne[2] = 1;  | 
 | 6580 | +    dst.ne[3] = 1;  | 
 | 6581 | +    dst.nb[0] = sizeof(float);  | 
 | 6582 | +    dst.nb[1] = n * sizeof(float);  | 
 | 6583 | +    dst.nb[2] = dst.nb[1];  | 
 | 6584 | +    dst.nb[3] = dst.nb[2];  | 
 | 6585 | +    dst.data  = c;  | 
 | 6586 | +    dst.src[0] = &src0;  | 
 | 6587 | +    dst.src[1] = &src1;  | 
 | 6588 | + | 
 | 6589 | +    ggml_compute_forward_mul_mat(params, &dst);  | 
 | 6590 | +}  | 
 | 6591 | + | 
 | 6592 | +// ggml_compute_forward_conv_2d  | 
 | 6593 | + | 
 | 6594 | +static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,  | 
 | 6595 | +                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]  | 
 | 6596 | +                                              const ggml_tensor *         src,     // [W, H, C, N]  | 
 | 6597 | +                                              ggml_tensor *               dst,     // [OW, OH, OC, N]  | 
 | 6598 | +                                              ggml_type                   kernel_type) {  | 
 | 6599 | + | 
 | 6600 | +    GGML_ASSERT(ggml_is_contiguous(kernel));  | 
 | 6601 | +    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);  | 
 | 6602 | +    GGML_ASSERT(kernel->type == kernel_type);  | 
 | 6603 | + | 
 | 6604 | +    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);  | 
 | 6605 | + | 
 | 6606 | +    const int32_t stride_x   = dst->op_params[0];  | 
 | 6607 | +    const int32_t stride_y   = dst->op_params[1];  | 
 | 6608 | +    const int32_t pad_x      = dst->op_params[2];  | 
 | 6609 | +    const int32_t pad_y      = dst->op_params[3];  | 
 | 6610 | +    const int32_t dilation_x = dst->op_params[4];  | 
 | 6611 | +    const int32_t dilation_y = dst->op_params[5];  | 
 | 6612 | + | 
 | 6613 | +    const int64_t c_in  = src->ne[2];  | 
 | 6614 | +    const int64_t c_out = kernel->ne[3];  | 
 | 6615 | +    GGML_ASSERT(c_in == kernel->ne[2]);  | 
 | 6616 | + | 
 | 6617 | +    const int64_t src_w = src->ne[0];  | 
 | 6618 | +    const int64_t src_h = src->ne[1];  | 
 | 6619 | +    const int64_t knl_w = kernel->ne[0];  | 
 | 6620 | +    const int64_t knl_h = kernel->ne[1];  | 
 | 6621 | +    const int64_t dst_w = dst->ne[0];  | 
 | 6622 | +    const int64_t dst_h = dst->ne[1];  | 
 | 6623 | + | 
 | 6624 | +    const float * src_data = (float *) src->data;  | 
 | 6625 | +    void  * knl_data       = kernel->data;  | 
 | 6626 | +    float * dst_data       = (float *) dst->data;  | 
 | 6627 | + | 
 | 6628 | +    const int64_t knl_n           = knl_w * knl_h * c_in;  | 
 | 6629 | +    const int64_t patch_total     = dst->ne[3] * dst_w * dst_h;  | 
 | 6630 | + | 
 | 6631 | +    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);  | 
 | 6632 | +    const int64_t batch_size        = params->wsize / space_per_patch;  | 
 | 6633 | +    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;  | 
 | 6634 | +    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;  | 
 | 6635 | + | 
 | 6636 | +    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);  | 
 | 6637 | + | 
 | 6638 | +    void * tmp = params->wdata;  | 
 | 6639 | + | 
 | 6640 | +    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {  | 
 | 6641 | + | 
 | 6642 | +        const int64_t patch_start_batch = batch_i * patches_per_batch;  | 
 | 6643 | +        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,  | 
 | 6644 | +                                              patch_total);  | 
 | 6645 | +        const int64_t patch_n           = patch_end_batch - patch_start_batch;  | 
 | 6646 | + | 
 | 6647 | +        const int64_t patch_per_thread  = (patch_n + params->nth - 1) / params->nth;  | 
 | 6648 | +        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;  | 
 | 6649 | +        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);  | 
 | 6650 | + | 
 | 6651 | +        //im2col for a patch  | 
 | 6652 | +        for (int64_t p = patch_start; p < patch_end; ++p) {  | 
 | 6653 | +            const int64_t  batch_n     =  p / (dst_w * dst_h);  | 
 | 6654 | +            const int64_t  src_x       = (p / dst_w) % dst_h;  | 
 | 6655 | +            const int64_t  src_y       =  p % dst_w;  | 
 | 6656 | + | 
 | 6657 | +            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);  | 
 | 6658 | +            char *        dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;  | 
 | 6659 | + | 
 | 6660 | +            for (int64_t ic = 0; ic < c_in; ++ic) {  | 
 | 6661 | +                for (int64_t ky = 0; ky < knl_h; ++ky) {  | 
 | 6662 | +                    for (int64_t kx = 0; kx < knl_w; ++kx) {  | 
 | 6663 | +                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;  | 
 | 6664 | +                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;  | 
 | 6665 | + | 
 | 6666 | +                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;  | 
 | 6667 | + | 
 | 6668 | +                        float src_val;  | 
 | 6669 | +                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {  | 
 | 6670 | +                            src_val = 0.0f;  | 
 | 6671 | +                        } else {  | 
 | 6672 | +                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);  | 
 | 6673 | +                            src_val               = *src_ptr;  | 
 | 6674 | +                        }  | 
 | 6675 | + | 
 | 6676 | +                        char * element_ptr = dst_row + dst_idx * traits->type_size;  | 
 | 6677 | +                        if (kernel_type == GGML_TYPE_F32) {  | 
 | 6678 | +                            *(float *) element_ptr = src_val;  | 
 | 6679 | +                        } else if (kernel_type == GGML_TYPE_F16) {  | 
 | 6680 | +                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);  | 
 | 6681 | +                        }  | 
 | 6682 | +                    }  | 
 | 6683 | +                }  | 
 | 6684 | +            }  | 
 | 6685 | +        }   // patches handled by this thread  | 
 | 6686 | + | 
 | 6687 | +        ggml_barrier(params->threadpool);  | 
 | 6688 | + | 
 | 6689 | +        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);  | 
 | 6690 | + | 
 | 6691 | +        GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);  | 
 | 6692 | + | 
 | 6693 | +        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]  | 
 | 6694 | +        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);  | 
 | 6695 | + | 
 | 6696 | +        ggml_barrier(params->threadpool);  | 
 | 6697 | + | 
 | 6698 | + | 
 | 6699 | +        //permute back [OC, N, OH, OW] to [N, OC, OH, OW]  | 
 | 6700 | +        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;  | 
 | 6701 | +        const int64_t permute_start = params->ith * permute_per_thread;  | 
 | 6702 | +        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);  | 
 | 6703 | + | 
 | 6704 | +        for (int64_t i = permute_start; i < permute_end; ++i) {  | 
 | 6705 | +            const int64_t p       = patch_start_batch + i;  | 
 | 6706 | +            const int64_t batch_n = p / (dst_w * dst_h);  | 
 | 6707 | +            const int64_t dst_y   = (p / dst_w) % dst_h;  | 
 | 6708 | +            const int64_t dst_x   = p % dst_w;  | 
 | 6709 | + | 
 | 6710 | +            for (int64_t oc = 0; oc < c_out; ++oc) {  | 
 | 6711 | +                const float value = gemm_output[i * c_out + oc];  | 
 | 6712 | +                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);  | 
 | 6713 | +                *dst_ptr = value;  | 
 | 6714 | +            }  | 
 | 6715 | +        }  | 
 | 6716 | +    }  | 
 | 6717 | +}  | 
 | 6718 | + | 
 | 6719 | +void ggml_compute_forward_conv_2d(  | 
 | 6720 | +        const ggml_compute_params * params,  | 
 | 6721 | +        ggml_tensor * dst) {  | 
 | 6722 | + | 
 | 6723 | +    const ggml_tensor * src0 = dst->src[0];  | 
 | 6724 | +    const ggml_tensor * src1 = dst->src[1];  | 
 | 6725 | + | 
 | 6726 | +    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);  | 
 | 6727 | +}  | 
 | 6728 | + | 
6548 | 6729 | // ggml_compute_forward_conv_transpose_2d  | 
6549 | 6730 | 
 
  | 
6550 | 6731 | void ggml_compute_forward_conv_transpose_2d(  | 
 | 
0 commit comments