| 
 | 1 | +#include "conv2d.cuh"  | 
 | 2 | + | 
 | 3 | +struct conv_params {  | 
 | 4 | +    const int64_t IW, IH;  | 
 | 5 | +    const int64_t OW, OH;  | 
 | 6 | +    const int64_t KW, KH;  | 
 | 7 | +    const int64_t ST_X, ST_Y;  | 
 | 8 | +    const int64_t PD_X, PD_Y;  | 
 | 9 | +    const int64_t DL_X, DL_Y;  | 
 | 10 | +    const int64_t IC, OC;  | 
 | 11 | +    const int64_t B;  | 
 | 12 | +    const int64_t TOTAL;  | 
 | 13 | +};  | 
 | 14 | + | 
 | 15 | +struct kernel_bounds {  | 
 | 16 | +    int64_t y_min, y_max;  | 
 | 17 | +    int64_t x_min, x_max;  | 
 | 18 | +};  | 
 | 19 | + | 
 | 20 | +__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {  | 
 | 21 | +    return (a > b) ? a : b;  | 
 | 22 | +}  | 
 | 23 | + | 
 | 24 | +__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {  | 
 | 25 | +    return (a < b) ? a : b;  | 
 | 26 | +}  | 
 | 27 | + | 
 | 28 | +__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {  | 
 | 29 | +    kernel_bounds bounds;  | 
 | 30 | +    bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);  | 
 | 31 | +    bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);  | 
 | 32 | +    bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);  | 
 | 33 | +    bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);  | 
 | 34 | +    return bounds;  | 
 | 35 | +}  | 
 | 36 | + | 
 | 37 | +__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,  | 
 | 38 | +                                                     int64_t kern_coord,  | 
 | 39 | +                                                     int64_t stride,  | 
 | 40 | +                                                     int64_t dilation,  | 
 | 41 | +                                                     int64_t padding) {  | 
 | 42 | +    return out_coord * stride + kern_coord * dilation - padding;  | 
 | 43 | +}  | 
 | 44 | + | 
 | 45 | +struct whcn_layout {  | 
 | 46 | +    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {  | 
 | 47 | +        return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;  | 
 | 48 | +    }  | 
 | 49 | + | 
 | 50 | +    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {  | 
 | 51 | +        return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;  | 
 | 52 | +    }  | 
 | 53 | + | 
 | 54 | +    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {  | 
 | 55 | +        return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;  | 
 | 56 | +    }  | 
 | 57 | + | 
 | 58 | +    __device__ static void unpack_indices(int64_t             global_idx,  | 
 | 59 | +                                          const conv_params & P,  | 
 | 60 | +                                          int64_t &           n,  | 
 | 61 | +                                          int64_t &           c,  | 
 | 62 | +                                          int64_t &           out_y,  | 
 | 63 | +                                          int64_t &           out_x) {  | 
 | 64 | +        out_x = global_idx % P.OW;  | 
 | 65 | +        out_y = (global_idx / P.OW) % P.OH;  | 
 | 66 | +        c     = (global_idx / (P.OW * P.OH)) % P.OC;  | 
 | 67 | +        n     = global_idx / (P.OW * P.OH * P.OC);  | 
 | 68 | +    }  | 
 | 69 | +};  | 
 | 70 | + | 
 | 71 | +template <typename T, typename Layout>  | 
 | 72 | +static __global__ void conv2d_kernel(const float * __restrict__ input,  | 
 | 73 | +                                     const T * __restrict__ kernel,  | 
 | 74 | +                                     float * __restrict__ output,  | 
 | 75 | +                                     const conv_params P) {  | 
 | 76 | +    const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;  | 
 | 77 | + | 
 | 78 | +    if (global_idx >= P.TOTAL) {  | 
 | 79 | +        return;  | 
 | 80 | +    }  | 
 | 81 | + | 
 | 82 | +    int64_t n, c_out, out_y, out_x;  | 
 | 83 | +    Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);  | 
 | 84 | + | 
 | 85 | +    T acc = 0;  | 
 | 86 | + | 
 | 87 | +    for (int64_t c_in = 0; c_in < P.IC; ++c_in) {  | 
 | 88 | +        kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);  | 
 | 89 | + | 
 | 90 | +        for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {  | 
 | 91 | +            const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);  | 
 | 92 | + | 
 | 93 | +            for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {  | 
 | 94 | +                const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);  | 
 | 95 | + | 
 | 96 | +                T input_val;  | 
 | 97 | +                if (std::is_same<T, half>::value) {  | 
 | 98 | +                    input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]);  | 
 | 99 | +                } else {  | 
 | 100 | +                    input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];  | 
 | 101 | +                }  | 
 | 102 | + | 
 | 103 | +                T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];  | 
 | 104 | +                acc += (input_val * kernel_val);  | 
 | 105 | +            }  | 
 | 106 | +        }  | 
 | 107 | +    }  | 
 | 108 | + | 
 | 109 | +    // [N, OC, OH, OW]  | 
 | 110 | +    output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc;  | 
 | 111 | +}  | 
 | 112 | + | 
 | 113 | +template <typename T>  | 
 | 114 | +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {  | 
 | 115 | +    const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;  | 
 | 116 | +    conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);  | 
 | 117 | +}  | 
 | 118 | + | 
 | 119 | +static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {  | 
 | 120 | +    conv2d_cuda<half>(X_D, K_D, Y_D, P, st);  | 
 | 121 | +}  | 
 | 122 | + | 
 | 123 | +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {  | 
 | 124 | +    conv2d_cuda<float>(X_D, K_D, Y_D, P, st);  | 
 | 125 | +}  | 
 | 126 | + | 
 | 127 | +void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {  | 
 | 128 | +    const ggml_tensor * kernel = dst->src[0];  | 
 | 129 | +    const ggml_tensor * input  = dst->src[1];  | 
 | 130 | +    float *             K_D    = (float *) kernel->data;  | 
 | 131 | +    const float *       X_D    = (const float *) input->data;  | 
 | 132 | +    float *             Y_D    = (float *) dst->data;  | 
 | 133 | + | 
 | 134 | +    GGML_ASSERT(ggml_is_contiguous(kernel));  | 
 | 135 | +    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);  | 
 | 136 | + | 
 | 137 | +    // same number of input channels  | 
 | 138 | +    GGML_ASSERT(input->ne[2] == kernel->ne[2]);  | 
 | 139 | + | 
 | 140 | +    cudaStream_t st = ctx.stream();  | 
 | 141 | + | 
 | 142 | +    const int32_t * p    = (const int32_t *) dst->op_params;  | 
 | 143 | +    const int       ST_X = p[0];  // stride_x  | 
 | 144 | +    const int       ST_Y = p[1];  // stride_y  | 
 | 145 | +    const int       PD_X = p[2];  // padding_x  | 
 | 146 | +    const int       PD_Y = p[3];  // padding_y  | 
 | 147 | +    const int       DL_X = p[4];  // dilation_x  | 
 | 148 | +    const int       DL_Y = p[5];  // dilation_y  | 
 | 149 | + | 
 | 150 | +    // No cwhn  | 
 | 151 | +    GGML_ASSERT(p[6] == false);  | 
 | 152 | + | 
 | 153 | +    const int IW = input->ne[0];   // input_w  | 
 | 154 | +    const int IH = input->ne[1];   // input_h  | 
 | 155 | +    const int OW = dst->ne[0];     // output_w  | 
 | 156 | +    const int OH = dst->ne[1];     // output_h  | 
 | 157 | +    const int KW = kernel->ne[0];  // kernel_w  | 
 | 158 | +    const int KH = kernel->ne[1];  // kernel_h  | 
 | 159 | +    const int IC = input->ne[2];   // input_channels  | 
 | 160 | +    const int OC = kernel->ne[3];  // ouptut_chanles  | 
 | 161 | +    const int B  = input->ne[3];   // n_batches  | 
 | 162 | + | 
 | 163 | +    const int64_t total  = B * OC * OH * OW;  | 
 | 164 | +    conv_params   params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };  | 
 | 165 | + | 
 | 166 | +    if (kernel->type == GGML_TYPE_F16) {  | 
 | 167 | +        conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);  | 
 | 168 | +    } else {  | 
 | 169 | +        conv2d_cuda_f32(X_D, K_D, Y_D, params, st);  | 
 | 170 | +    }  | 
 | 171 | +}  | 
0 commit comments