@@ -1,4 +1,5 @@
 #include "conv2d-mm.cuh"
+#include "convert.cuh"
 
 #include <cuda_runtime.h>
 
@@ -13,6 +14,8 @@
 
 #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
 
+typedef uint32_t uint;
+
 uint32_t ceil_div(uint32_t M, uint32_t N);
 int get_sm_count();
 
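Note on `fastdiv`, visible in the hunk header just below: this is the classic magic-number division, replacing a runtime `/` by a multiply-high and a shift, which is much cheaper on GPUs for divisors known at launch time. The body is not part of this diff; a standard formulation consistent with the signature, where the host precomputes the multiplier `mp` and shift `L` per divisor, is sketched here (an illustration, not the file's actual implementation):

// Host side (hypothetical helper): precompute mp and L for a fixed divisor d
// so that, for 32-bit n, n / d == (umulhi(n, mp) + n) >> L.
static void init_fastdiv(uint32_t d, uint32_t & mp, uint32_t & L) {
    L = 0;
    while (L < 32 && (uint32_t(1) << L) < d) {
        ++L;
    }
    mp = (uint32_t) (((uint64_t(1) << 32) * ((uint64_t(1) << L) - d)) / d + 1);
}

// Device side, matching the signature in the hunk header below.
__inline__ __device__ uint32_t fastdiv_sketch(uint32_t n, uint32_t mp, uint32_t L) {
    return (__umulhi(n, mp) + n) >> L;
}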
@@ -69,11 +72,11 @@ __inline__ __device__ uint fastdiv(uint n, uint mp, uint L) {
 }
 
 // --> conv_2d kernel modified to function as a matmul
-template <uint BS_K, uint BS_NPQ, uint BS_CRS, uint TS_K, uint TS_NPQ, uint WG_SIZE, uint VEC_SIZE>
+template <typename T, uint BS_K, uint BS_NPQ, uint BS_CRS, uint TS_K, uint TS_NPQ, uint WG_SIZE, uint VEC_SIZE>
 __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
                                                  uint NPQ,
                                                  uint CRS,
-                                                 const float * knl_data,
+                                                 const T * knl_data,
                                                  const float * src_data,
                                                  float * dst_data) {
     // Each block computes a tile of the result of size BS_K*BS_NPQ
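The kernel is now templated on the filter element type `T`, so one implementation serves both F32 and F16 filters while the input (`src_data`) and output stay F32. Following the usual implicit-GEMM naming suggested by the launch arguments at the bottom of this diff, A is the K x CRS filter matrix (K = Cout, CRS = Cin * KH * KW), B is the CRS x NPQ im2col view of the input, and the result is K x NPQ. An instantiation looks like the following; the tile sizes and the `knl_f16`/`src_f32`/`dst_f32` names here are made up for illustration, the real values live elsewhere in this file:

// Illustrative only: placeholder tile sizes and argument names.
mm<half, /*BS_K=*/64, /*BS_NPQ=*/128, /*BS_CRS=*/16,
   /*TS_K=*/4, /*TS_NPQ=*/8, /*WG_SIZE=*/256, /*VEC_SIZE=*/4>
    <<<gridDim, blockDim, 0, stream>>>(K, NPQ, CRS, knl_f16, src_f32, dst_f32);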
@@ -98,7 +101,8 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
     const uint T_y = threadIdx.x / NT_x;
     const uint T_x = threadIdx.x % NT_x;
 
-    __shared__ float Ash[BS_K * BS_CRS];
+    // __shared__ float Ash[BS_K * BS_CRS];
+    __shared__ T Ash[BS_K * BS_CRS];
     __shared__ float Bsh[BS_CRS * BS_NPQ];
 
     const uint Ar = threadIdx.x / BS_CRS;
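Keeping the A tile in the filter's own type, instead of widening to F32 at load time, halves the A tile's shared-memory footprint for F16 filters, which can help occupancy; the widening is deferred to the register load further down. Back-of-envelope arithmetic with hypothetical tile sizes (not the ones used in this file):

#include <cstddef>
#include <cuda_fp16.h>

// Hypothetical tile sizes, for illustration only.
constexpr unsigned BS_K_EX = 64, BS_CRS_EX = 16;

constexpr size_t ash_f32_bytes = BS_K_EX * BS_CRS_EX * sizeof(float); // 4096 B
constexpr size_t ash_f16_bytes = BS_K_EX * BS_CRS_EX * sizeof(half);  // 2048 B
static_assert(ash_f16_bytes * 2 == ash_f32_bytes, "F16 halves the A tile");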
@@ -148,9 +152,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
         // General addressing (does not assume contiguity)
         // const uint32_t knl_idx = KW_idx_a + KH_idx_a*dp.nb01 + Cin_idx_a*dp.nb02 + K_idx_a*dp.nb03;
         // Contiguous addressing
-        float val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)];
+        T val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)];
         if (CRS_idx_a >= CRS || K_idx_a >= K) {
-            val = 0.0;
+            val = (T) 0.0;
         }
 
 #ifdef A_TRANS
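The load above uses a clamp-then-zero pattern: every thread reads a valid address (`min(idx, K * CRS - 1)`), and threads whose logical index was out of range overwrite the loaded value with zero afterwards, so the branch only patches a register value rather than guarding the global load itself. The general shape of the pattern as a standalone sketch (my helper name, not from this file):

// Clamp the index so the load is always in bounds, then discard the
// value for lanes whose logical index was past the end.
template <typename T>
__device__ __forceinline__ T load_clamped(const T * buf, uint32_t idx, uint32_t n) {
    T v = buf[min(idx, n - 1)];      // always a valid address
    return idx < n ? v : (T) 0.0f;   // zero for out-of-range lanes
}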
@@ -235,9 +239,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
 # else
             uint32_t col_offset = (T_y * TS_K + T_ly);
 # endif
-            regA[T_ly] = Ash[CRS_lidx * BS_K + col_offset];
+            regA[T_ly] = ggml_cuda_cast<float>(Ash[CRS_lidx * BS_K + col_offset]);
 #else
-            regA[T_ly] = Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx];
+            regA[T_ly] = ggml_cuda_cast<float>(Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx]);
 #endif
         }
         for (uint32_t T_lx = 0; T_lx < TS_NPQ; ++T_lx) {
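`ggml_cuda_cast` is why `convert.cuh` is now included at the top of the file: it widens the stored `T` back to `float` as A values are staged into registers, so the inner-product accumulation is always performed in F32 regardless of the filter type. Conceptually it boils down to something like this sketch (the real helper in convert.cuh is more general; the name `cast_widen` is mine):

#include <cuda_fp16.h>
#include <type_traits>

// Widening-cast sketch: explicit f16 -> f32 conversion, plain cast otherwise.
template <typename dst_t, typename src_t>
__device__ __forceinline__ dst_t cast_widen(src_t x) {
    if constexpr (std::is_same_v<src_t, half>) {
        return __half2float(x);
    } else {
        return (dst_t) x;
    }
}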
@@ -343,16 +347,22 @@ void ggml_cuda_op_conv_2d_variant(ggml_backend_cuda_context & ctx,
     cudaMemcpyToSymbol(dp, &p, sizeof(Params));
 
     // Kernel arguments
-    float * src0_data = (float *) src0->data;
     float * src1_data = (float *) src1->data;
     float * dst_data = (float *) dst->data;
 
     dim3 gridDim(NB_K, NB_NPQ);
     dim3 blockDim(WG_SIZE);
     cudaStream_t stream = ctx.stream();
-
-    mm<BS_K, BS_NPQ, BS_CRS, TS_K, TS_NPQ, WG_SIZE, VEC_SIZE>
+    if (src0->type == GGML_TYPE_F16) {
+        half * src0_data = (half *) src0->data;
+        mm<half, BS_K, BS_NPQ, BS_CRS, TS_K, TS_NPQ, WG_SIZE, VEC_SIZE>
+            <<<gridDim, blockDim, 0, stream>>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data);
+    } else {
+        float * src0_data = (float *) src0->data;
+        mm<float, BS_K, BS_NPQ, BS_CRS, TS_K, TS_NPQ, WG_SIZE, VEC_SIZE>
         <<<gridDim, blockDim, 0, stream>>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data);
+    }
+
 }
 
 void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
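Host-side dispatch now selects the kernel instantiation from the ggml tensor type. If further filter types were supported later (BF16, say), the if/else would naturally grow into a switch; a hypothetical extension, not part of this change, where `launch_variant` is a made-up wrapper around the templated `mm<...>` launch above:

// Hypothetical: launch_variant is not a real helper in this file;
// GGML_ABORT is the real ggml error macro.
switch (src0->type) {
    case GGML_TYPE_F16:
        launch_variant<half>(/* ...same kernel arguments... */);
        break;
    case GGML_TYPE_F32:
        launch_variant<float>(/* ...same kernel arguments... */);
        break;
    default:
        GGML_ABORT("conv2d-mm: unsupported filter type");
}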
@@ -372,13 +382,13 @@ void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     ggml_tensor * src0 = dst->src[0];
     ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
     GGML_ASSERT(nb10 == sizeof(float));
     GGML_ASSERT(nb0 == sizeof(float));