@@ -365,4 +365,228 @@ function shmem_size(conf::GemmKernels.Config, ::typeof(matmul_pipelined))
     max(size_c, size_a + size_b, size_d)
 end
 
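+# Pipelined matmul kernel with double buffering: compared to matmul_pipelined, the
+# A and B tiles in shared memory get a third dimension of size 2, and the operator
+# fragments are kept in a two-stage register pipeline, so that global loads,
+# shared-memory staging, and MMA compute for consecutive K slices can overlap.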
+function matmul_pipelined_ng(conf::GemmKernels.Config, a, b, c, d,
+                             transf_gl2sh_a, transf_gl2sh_b, transf_gl2sh_c, transf_sh2gl_d,
+                             transf_sh2rf_a, transf_sh2rf_b, transf_sh2rf_c, transf_rf2sh_d,
+                             epilogue)
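+    # The transf_* arguments are transformations applied at each memory-space
+    # transition: gl2sh = global to shared, sh2rf = shared to register file,
+    # rf2sh = register file to shared, sh2gl = shared to global.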
+    # Calculate the number of fragments needed to fully cover a warp tile
+    num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
+    num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N
+
+    # Constants
+    block_i = (blockIdx().x - 1) * conf.block_shape.M
+    block_j = (blockIdx().y - 1) * conf.block_shape.N
+
+    warpId = (threadIdx().x - 1) ÷ 32 + 1
+    laneId = (threadIdx().x - 1) % 32 + 1
+
+    gemm_sz = Tile(conf.matmul_shape)
+    block_tile = Tile(conf.block_shape)
+
+    # (1) Cooperatively load a block_shape.M x block_shape.N tile of C from global to shared memory within one threadblock
+    shmem_c = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))
+
+    @loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
+        @loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
+            x = @inbounds Layout.load(conf.global_c_layout, c, translate_base(thread_tile, (M = block_i, N = block_j)))
+            x = transf_gl2sh_c(x, thread_tile)
+            @inbounds Layout.store!(conf.shared_c_layout, shmem_c, x, thread_tile)
+        end
+    end
+
+    sync_threads()
+
+    # (2) Load a compute_warp.M x compute_warp.N tile of C from shared memory into registers
+    warp_tile = @inbounds subdivide(block_tile.MN, Tile(conf.compute_warp).MN, warpId, conf.warps_per_block)
+
+    c_frags = LocalArray{Tuple{num_fragments_m, num_fragments_n}, Operator.fragtype_accum(conf.operator, conf.shared_c_layout)}(undef)
+
+    @loopinfo unroll for i = 1:num_fragments_m
+        @loopinfo unroll for j = 1:num_fragments_n
+            tile = translate_offset(warp_tile, (M = (i-1)*conf.compute_op_shape.M, N = (j-1)*conf.compute_op_shape.N))
+            @inbounds @immutable c_frags[i, j] = transf_sh2rf_c(Operator.load_c(conf.operator, conf.shared_c_layout, shmem_c, tile), tile)
+        end
+    end
+
+    sync_threads()
+
+    # (3) Compute a block_shape.M x block_shape.N x block_shape.K matrix product within one threadblock
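+    # A and B tiles are double-buffered: the third dimension of size 2 lets one
+    # buffer feed the current MMA iterations while the other is being filled with
+    # the next block. B is placed directly after A in dynamic shared memory via
+    # the offset argument.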
+    shmem_a = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_a_layout), (Layout.physical_size(conf.shared_a_layout, block_tile.MK.size)..., 2))
+    shmem_b = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_b_layout), (Layout.physical_size(conf.shared_b_layout, block_tile.KN.size)..., 2),
+                                             length(shmem_a) * sizeof(Layout.eltype(conf.shared_a_layout)))
+
+    # Sizes of a_fragment and b_fragment
+    a_frag_i = (block_tile.size.M * block_tile.size.K) ÷ (conf.mem_a_warp.M * conf.mem_a_warp.K * conf.warps_per_block)
+    a_frag_j = (conf.mem_a_warp.M * conf.mem_a_warp.K) ÷ (conf.mem_a_thread.M * conf.mem_a_thread.K * 32)
+    b_frag_i = (block_tile.size.K * block_tile.size.N) ÷ (conf.mem_b_warp.K * conf.mem_b_warp.N * conf.warps_per_block)
+    b_frag_j = (conf.mem_b_warp.K * conf.mem_b_warp.N) ÷ (conf.mem_b_thread.K * conf.mem_b_thread.N * 32)
+
+    # Fragments to buffer the loads from global memory for A and B.
+    a_fragment = LocalArray{Tuple{a_frag_i, a_frag_j}, Layout.fragtype(conf.global_a_layout, conf.mem_a_thread)}(undef)
+    b_fragment = LocalArray{Tuple{b_frag_i, b_frag_j}, Layout.fragtype(conf.global_b_layout, conf.mem_b_thread)}(undef)
+
+    # Fragments to buffer the loads from shared memory for A and B.
+    a_frags = LocalArray{Tuple{2, num_fragments_m}, Operator.fragtype_a(conf.operator, conf.shared_a_layout)}(undef)
+    b_frags = LocalArray{Tuple{2, num_fragments_n}, Operator.fragtype_b(conf.operator, conf.shared_b_layout)}(undef)
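+    # Each holds two stages, indexed by warp_mma_k % 2 + 1: one stage feeds the
+    # current MMA while the other is filled with the next k-slice from shared memory.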
+
+    warp_tile_mn = subdivide(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)
+
+    # Prologue.
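+    # Fill the pipeline before entering the main loop: load the first A/B block
+    # from global memory, stage it into shared-memory buffer 1, and load its first
+    # k-slice into the first register stage.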
+
+    # ld_global(main_loop_it = 0)
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
+            @inbounds @immutable a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = 0)))
+        end
+    end
+
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
+            @inbounds @immutable b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = 0, N = block_j)))
+        end
+    end
+
+    # st_shared(main_loop_it = 0)
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
+            x = transf_gl2sh_a(@inbounds(a_fragment[i, j]), thread_tile)
+            @inbounds Layout.store!(conf.shared_a_layout, view(shmem_a, :, :, 1), x, thread_tile)
+        end
+    end
+
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
+            x = transf_gl2sh_b(@inbounds(b_fragment[i, j]), thread_tile)
+            @inbounds Layout.store!(conf.shared_b_layout, view(shmem_b, :, :, 1), x, thread_tile)
+        end
+    end
+
+    sync_threads()
+
+    # ld_shared(main_loop_it = 0, warp_mma_k = 0)
+    warp_mma_k = 0
+    warp_k = warp_mma_k * conf.compute_op_shape.K
+    warp_tile = translate_offset(warp_tile_mn, (M = 0, N = 0, K = warp_k))
+    @loopinfo unroll for i = 1:num_fragments_m
+        a_tile = translate_offset(warp_tile.MK, (M = (i-1)*conf.compute_op_shape.M, K = 0))
+        @inbounds @immutable a_frags[warp_mma_k % 2 + 1, i] = transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, view(shmem_a, :, :, 1), a_tile), a_tile)
+    end
+
+    @loopinfo unroll for j = 1:num_fragments_n
+        b_tile = translate_offset(warp_tile.KN, (K = 0, N = (j-1)*conf.compute_op_shape.N))
+        @inbounds @immutable b_frags[warp_mma_k % 2 + 1, j] = transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, view(shmem_b, :, :, 1), b_tile), b_tile)
+    end
+
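+    # Main loop: one iteration per block_shape.K slice of the K dimension. The inner
+    # warp_mma_k loop is software-pipelined: while the current k-slice is multiplied,
+    # the next one is prefetched from shared memory into registers. Global loads for
+    # the next block are issued at warp_mma_k == 0, and the staged block is written
+    # to the other shared-memory buffer during the last warp_mma_k iteration.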
+    NUM_MAIN_LOOP_ITERS = gemm_sz.size.K ÷ block_tile.size.K
+    @loopinfo unrollcount=2 for main_loop_it = 0:NUM_MAIN_LOOP_ITERS-1
+        block_k = main_loop_it * block_tile.size.K
+
+        main_loop_it_next = (main_loop_it + 1) % NUM_MAIN_LOOP_ITERS
+        block_k_next = main_loop_it_next * block_tile.size.K
+
+        NUM_WARP_MMA_K_ITERS = conf.block_shape.K ÷ conf.compute_op_shape.K
+        @loopinfo unroll for warp_mma_k = 0:NUM_WARP_MMA_K_ITERS-1
+            warp_k = warp_mma_k * conf.compute_op_shape.K
+            warp_tile = translate_offset(warp_tile_mn, (M = 0, N = 0, K = warp_k))
+
+            warp_mma_k_next = (warp_mma_k + 1) % NUM_WARP_MMA_K_ITERS
+            warp_k_next = warp_mma_k_next * conf.compute_op_shape.K
+            warp_tile_next = translate_offset(warp_tile_mn, (M = 0, N = 0, K = warp_k_next))
+
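+            # Pick the shared-memory buffer holding the *next* k-slice: when the
+            # inner loop wraps around, it lives in the buffer filled for the next block.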
+            main_loop_it_next_warp_k = if warp_mma_k_next == 0
+                main_loop_it_next
+            else
+                main_loop_it
+            end
+
+            if warp_mma_k == NUM_WARP_MMA_K_ITERS - 1 # last iteration of inner warp loop
+                # st_shared(main_loop_it_next)
+                @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
+                    @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
+                        x = transf_gl2sh_a(@inbounds(a_fragment[i, j]), thread_tile)
+                        @inbounds Layout.store!(conf.shared_a_layout, view(shmem_a, :, :, main_loop_it_next % 2 + 1), x, thread_tile)
+                    end
+                end
+
+                @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
+                    @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
+                        x = transf_gl2sh_b(@inbounds(b_fragment[i, j]), thread_tile)
+                        @inbounds Layout.store!(conf.shared_b_layout, view(shmem_b, :, :, main_loop_it_next % 2 + 1), x, thread_tile)
+                    end
+                end
+
+                sync_threads()
+            end
+
+            # ld_shared(main_loop_it_next_warp_k, warp_mma_k_next)
+            @loopinfo unroll for i = 1:num_fragments_m
+                a_tile = translate_offset(warp_tile_next.MK, (M = (i-1)*conf.compute_op_shape.M, K = 0))
+                @inbounds @immutable a_frags[warp_mma_k_next % 2 + 1, i] = transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, view(shmem_a, :, :, main_loop_it_next_warp_k % 2 + 1), a_tile), a_tile)
+            end
+
+            @loopinfo unroll for j = 1:num_fragments_n
+                b_tile = translate_offset(warp_tile_next.KN, (K = 0, N = (j-1)*conf.compute_op_shape.N))
+                @inbounds @immutable b_frags[warp_mma_k_next % 2 + 1, j] = transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, view(shmem_b, :, :, main_loop_it_next_warp_k % 2 + 1), b_tile), b_tile)
+            end
+
+            if warp_mma_k == 0 # first iteration of inner warp loop
+                # ld_global(main_loop_it_next)
+                @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
+                    @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
+                        @inbounds @immutable a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_k_next)))
+                    end
+                end
+
+                @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
+                    @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
+                        @inbounds @immutable b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_k_next, N = block_j)))
+                    end
+                end
+            end
+
+            # mma(main_loop_it, warp_mma_k)
+            @loopinfo unroll for i = 1:num_fragments_m
+                @loopinfo unroll for j = 1:num_fragments_n
+                    @inbounds @immutable c_frags[i, j] = Operator.mma(conf.operator, a_frags[warp_mma_k % 2 + 1, i], b_frags[warp_mma_k % 2 + 1, j], c_frags[i, j])
+                end
+            end
+        end
+    end
+
+    # (4) Store the compute_warp.M x compute_warp.N tile of D from registers to shared memory
+    shmem_d = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_d_layout), Layout.physical_size(conf.shared_d_layout, block_tile.MN.size))
+
+    warp_tile = @inbounds subdivide(block_tile.MN, Tile(conf.compute_warp).MN, warpId, conf.warps_per_block)
+
+    @loopinfo unroll for i = 1:num_fragments_m
+        @loopinfo unroll for j = 1:num_fragments_n
+            tile = translate_offset(warp_tile, (M = (i-1)*conf.compute_op_shape.M, N = (j-1)*conf.compute_op_shape.N))
+            @inbounds Operator.store_d(conf.operator, conf.shared_d_layout, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
+        end
+    end
+
+    sync_threads()
+
+    # (5) Run the epilogue
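+    # The epilogue is expected to write the block's D tile from shared memory back
+    # to global memory, using transf_sh2gl_d for the final transformation.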
+    epilogue(conf, d, shmem_d, transf_sh2gl_d)
+
+    return
+end
+
+function shmem_size(conf::GemmKernels.Config, ::typeof(matmul_pipelined_ng))
+    size_a = sizeof(Layout.eltype(conf.shared_a_layout)) *
+             prod(Layout.physical_size(conf.shared_a_layout,
+                                       (; conf.block_shape.M, conf.block_shape.K)))
+    size_b = sizeof(Layout.eltype(conf.shared_b_layout)) *
+             prod(Layout.physical_size(conf.shared_b_layout,
+                                       (; conf.block_shape.K, conf.block_shape.N)))
+    size_c = sizeof(Layout.eltype(conf.shared_c_layout)) *
+             prod(Layout.physical_size(conf.shared_c_layout,
+                                       (; conf.block_shape.M, conf.block_shape.N)))
+    size_d = sizeof(Layout.eltype(conf.shared_d_layout)) *
+             prod(Layout.physical_size(conf.shared_d_layout,
+                                       (; conf.block_shape.M, conf.block_shape.N)))
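+    # Shared memory is reused across the kernel's phases: first the C tile, then the
+    # double-buffered A and B tiles (hence the factor 2), and finally the D tile.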
+    max(size_c, 2 * (size_a + size_b), size_d)
+end
+
 end