33#include " convert.cuh"
44#include " mma.cuh"
55
6- #define CEIL_DIV (M, N ) (((M) + (N) - 1 ) / (N))
7-
8- static uint32_t ceil_div (uint32_t M, uint32_t N);
9- static int get_sm_count ();
10-
11- uint32_t ceil_div (uint32_t M, uint32_t N) {
12- return (M + N - 1 ) / N;
6+ constexpr static size_t ceil_div (const size_t m, const size_t n) {
7+ return (m + n - 1 ) / n;
138}
149
1510__align__ (16 ) struct Params {
@@ -25,23 +20,15 @@ __align__(16) struct Params {
2520 uint32_t IC_KH_KW, N_OH_OW;
2621 uint32_t IK_TOTAL, IN_TOTAL;
2722
28- uint32_t KWmp;
29- uint32_t KWL;
30- uint32_t KWKHmp;
31- uint32_t KWKHL;
32- uint32_t OWmp;
33- uint32_t OWL;
34- uint32_t OWOHmp;
35- uint32_t OWOHL;
23+ // fastdiv
24+ uint3 KW_fastdiv;
25+ uint3 KWKH_fastdiv;
26+ uint3 OW_fastdiv;
27+ uint3 OWOH_fastdiv;
3628};
3729
3830__constant__ __device__ Params P;
3931
40- // see init_fastdiv_values in ggml-vulkan.cpp
41- __inline__ __device__ uint fastdiv (uint n, uint mp, uint L) {
42- return (__umulhi (n, mp) + n) >> L;
43- }
44-
4532__device__ struct T_ICKHKW {
4633 const uint32_t ic, kh, kw;
4734};
@@ -82,20 +69,20 @@ struct whcn_layout {
8269
8370 __device__ __forceinline__ static T_ICKHKW unpack_ickhkw (const uint32_t & idx) {
8471 // const uint32_t ic = idx / (P.KW * P.KH);
85- const uint32_t ic = fastdiv (idx, P.KWKHmp , P. KWKHL );
72+ const uint32_t ic = fastdiv (idx, P.KWKH_fastdiv );
8673 const uint32_t r = idx - ic * (P.KW * P.KH );
8774 // const uint32_t kh = r / P.KW;
88- const uint32_t kh = fastdiv (r, P.KWmp , P. KWL );
75+ const uint32_t kh = fastdiv (r, P.KW_fastdiv );
8976 const uint32_t kw = r - kh * P.KW ;
9077 return T_ICKHKW{ ic, kh, kw };
9178 }
9279
9380 __device__ __forceinline__ static T_NOHOW unpack_nohow (const uint32_t & idx) {
9481 // const uint32_t n = idx / (P.OH * P.OW);
95- const uint32_t n = fastdiv (idx, P.OWOHmp , P. OWOHL );
82+ const uint32_t n = fastdiv (idx, P.OWOH_fastdiv );
9683 const uint32_t r = idx - n * (P.OH * P.OW );
9784 // const uint32_t oh = r / P.OW;
98- const uint32_t oh = fastdiv (r, P.OWmp , P. OWL );
85+ const uint32_t oh = fastdiv (r, P.OW_fastdiv );
9986 const uint32_t ow = r - oh * P.OW ;
10087 return T_NOHOW{ n, oh, ow };
10188 }
@@ -113,7 +100,6 @@ template <typename layout,
113100 const uint32_t BS_NOHOW,
114101 const uint32_t BS_ICKHKW,
115102 const uint32_t NUM_TILES_PER_WARP,
116- const uint32_t NUM_WARPS_NEED,
117103 const uint32_t NUM_WARPS_NOHOW,
118104 const uint32_t NUM_WARPS,
119105 const uint32_t WG_SIZE>
@@ -222,43 +208,18 @@ __global__ void __launch_bounds__(NUM_WARPS * WARP_SIZE) conv2d_tensor_cores_ker
222208 }
223209}
224210
225- // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
226- // Precompute mp (m' in the paper) and L such that division
227- // can be computed using a multiply (high 32b of 64b result)
228- // and a shift:
229- //
230- // n/d = (mulhi(n, mp) + n) >> L;
231- static void init_fastdiv_values (uint32_t d, uint32_t & mp, uint32_t & L) {
232- // compute L = ceil(log2(d));
233- L = 0 ;
234- while (L < 32 && (uint32_t { 1 } << L) < d) {
235- L++;
236- }
237-
238- mp = (uint32_t ) ((uint64_t { 1 } << 32 ) * ((uint64_t { 1 } << L) - d) / d + 1 );
239- }
240-
241211constexpr int conv_shapes[][NUM_VARIANTS] = {
242212 { 128 , 64 , 32 }, // BS_OC
243213 { 16 , 32 , 16 }, // BS_ICKHKW
244214 { 128 , 32 , 256 }, // BS_NOHOW
245215};
246216
247- int get_sm_count () {
248- int device;
249- cudaGetDevice (&device);
250-
251- int sm_count;
252- cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device);
253- return sm_count;
254- }
255-
256217template <uint CONV_SHAPE>
257- void conv_2d_tensor_core (const float * src0,
258- const half * src1,
259- float * dst,
260- const Params & p,
261- const cudaStream_t & st) {
218+ static void conv_2d_tensor_core (const float * src0,
219+ const half * src1,
220+ float * dst,
221+ const Params & p,
222+ const cudaStream_t & st) {
262223 constexpr uint32_t WG_SIZE = 256 ;
263224 static_assert (WG_SIZE % WARP_SIZE == 0 );
264225
@@ -270,24 +231,24 @@ void conv_2d_tensor_core(const float * src0,
270231
271232 static_assert (BS_OC % WMMA_M == 0 && BS_NOHOW % WMMA_N == 0 );
272233
273- constexpr uint32_t NUM_WARPS_NEED = (BS_OC * BS_NOHOW) / (WMMA_M * WMMA_N);
234+ constexpr uint32_t NUM_TILES_TOTAL = (BS_OC * BS_NOHOW) / (WMMA_M * WMMA_N);
274235 constexpr uint32_t NUM_WARPS_NOHOW = BS_NOHOW / WMMA_N;
275236
276- static_assert (NUM_WARPS_NEED % NUM_WARPS == 0 );
237+ static_assert (NUM_TILES_TOTAL % NUM_WARPS == 0 );
277238
278- constexpr uint32_t NUM_TILES_PER_WARP = NUM_WARPS_NEED / NUM_WARPS;
239+ constexpr uint32_t NUM_TILES_PER_WARP = NUM_TILES_TOTAL / NUM_WARPS;
279240
280241 const int64_t NOHOW = p.B * p.OW * p.OH ;
281- const uint32_t NB_OC = CEIL_DIV (p.Cout , BS_OC);
282- const uint32_t NB_NOHOW = CEIL_DIV (NOHOW, BS_NOHOW);
242+ const uint32_t NB_OC = ceil_div (p.Cout , BS_OC);
243+ const uint32_t NB_NOHOW = ceil_div (NOHOW, BS_NOHOW);
283244
284245 cudaMemcpyToSymbolAsync (P, &p, sizeof (Params), 0 , cudaMemcpyHostToDevice, st);
285246
286247 dim3 gridDim (NB_OC, NB_NOHOW);
287248 constexpr dim3 blockDim (WARP_SIZE, NUM_WARPS);
288249
289- conv2d_tensor_cores_kernel<whcn_layout, BS_OC, BS_NOHOW, BS_ICKHKW, NUM_TILES_PER_WARP, NUM_WARPS_NEED ,
290- NUM_WARPS_NOHOW, NUM_WARPS, WG_SIZE><<<gridDim , blockDim , 0 , st>>> (src0, src1, dst);
250+ conv2d_tensor_cores_kernel<whcn_layout, BS_OC, BS_NOHOW, BS_ICKHKW, NUM_TILES_PER_WARP, NUM_WARPS_NOHOW, NUM_WARPS ,
251+ WG_SIZE><<<gridDim , blockDim , 0 , st>>> (src0, src1, dst);
291252}
292253
293254void ggml_cuda_op_conv2d_tensor_core (const uint32_t & IW,
@@ -341,15 +302,15 @@ void ggml_cuda_op_conv2d_tensor_core(const uint32_t & IW,
341302 p.N_OH_OW = B * OH * OW;
342303 p.IN_TOTAL = B * IC * IH * IW;
343304
344- init_fastdiv_values (p. KW , p. KWmp , p. KWL );
345- init_fastdiv_values (p.KW * p.KH , p. KWKHmp , p. KWKHL );
346- init_fastdiv_values (p. OW , p. OWmp , p. OWL );
347- init_fastdiv_values (p.OW * p.OH , p. OWOHmp , p. OWOHL );
305+ p. KW_fastdiv = init_fastdiv_values (p. KW );
306+ p. KWKH_fastdiv = init_fastdiv_values (p.KW * p.KH );
307+ p. OW_fastdiv = init_fastdiv_values (p. OW );
308+ p. OWOH_fastdiv = init_fastdiv_values (p.OW * p.OH );
348309
349310 // Problem size (Cout x NOHOW)
350- std::array<uint32_t , 3 > elements = { p.Cout , p.B * p.OW * p.OH , 1 };
311+ std::array<uint32_t , 2 > elements = { p.Cout , p.B * p.OW * p.OH };
351312
352- const uint32_t sm_count = get_sm_count () ;
313+ const uint32_t sm_count = ggml_cuda_info (). devices [ ggml_cuda_get_device ()]. nsm ;
353314
354315 uint32_t variant_ntiles[NUM_VARIANTS];
355316
0 commit comments