Commit 8a2e443

Add new pipelined kernel (#196)
Add an alternative pipelining kernel. Compared to the old pipelining kernel, the loads and stores are reordered somewhat, and shared memory is split into two stages. This reduces the number of necessary bar.syncs to one third, but requires halving the BLOCK_K tile size.
1 parent 8c894fd commit 8a2e443
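To illustrate the two-stage scheme, the following standalone sketch (illustrative only, not part of this commit) prints which shared-memory stage each main-loop iteration computes from and which one it prefetches into, using the same stage indexing (it % 2 + 1) as the kernel:

    # Toy example: stage bookkeeping for a double-buffered (two-stage) pipeline.
    num_iters = 4
    for it in 0:num_iters-1
        read_stage  = it % 2 + 1        # stage consumed by this iteration's MMAs
        write_stage = (it + 1) % 2 + 1  # stage filled with the next iteration's tiles
        println("iteration $it: compute from stage $read_stage, prefetch into stage $write_stage")
    end

Because iteration it always writes the stage that iteration it+1 reads, a single bar.sync per iteration is enough to order the stores into a stage before the loads from it.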

File tree: 1 file changed (+224, -0 lines)

src/kernel.jl

Lines changed: 224 additions & 0 deletions
@@ -365,4 +365,228 @@ function shmem_size(conf::GemmKernels.Config, ::typeof(matmul_pipelined))
    max(size_c, size_a + size_b, size_d)
end

function matmul_pipelined_ng(conf::GemmKernels.Config, a, b, c, d,
                             transf_gl2sh_a, transf_gl2sh_b, transf_gl2sh_c, transf_sh2gl_d,
                             transf_sh2rf_a, transf_sh2rf_b, transf_sh2rf_c, transf_rf2sh_d,
                             epilogue)
    # Calculate the number of fragments needed to fully cover a warp tile
    num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
    num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N

    # Constants
    block_i = (blockIdx().x - 1) * conf.block_shape.M
    block_j = (blockIdx().y - 1) * conf.block_shape.N

    warpId = (threadIdx().x - 1) ÷ 32 + 1
    laneId = (threadIdx().x - 1) % 32 + 1

    gemm_sz = Tile(conf.matmul_shape)
    block_tile = Tile(conf.block_shape)

    # (1) Cooperatively load a block_shape.M x block_shape.N tile of C from global to shared memory within one threadblock
    shmem_c = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))

    @loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
        @loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
            x = @inbounds Layout.load(conf.global_c_layout, c, translate_base(thread_tile, (M = block_i, N = block_j)))
            x = transf_gl2sh_c(x, thread_tile)
            @inbounds Layout.store!(conf.shared_c_layout, shmem_c, x, thread_tile)
        end
    end

    sync_threads()

    # (2) Load a compute_warp.M x compute_warp.N tile of C from shared memory into registers
    warp_tile = @inbounds subdivide(block_tile.MN, Tile(conf.compute_warp).MN, warpId, conf.warps_per_block)

    c_frags = LocalArray{Tuple{num_fragments_m, num_fragments_n}, Operator.fragtype_accum(conf.operator, conf.shared_c_layout)}(undef)

    @loopinfo unroll for i = 1 : num_fragments_m
        @loopinfo unroll for j = 1 : num_fragments_n
            tile = translate_offset(warp_tile, (M = (i-1)*conf.compute_op_shape.M, N = (j-1)*conf.compute_op_shape.N))
            @inbounds @immutable c_frags[i, j] = transf_sh2rf_c(Operator.load_c(conf.operator, conf.shared_c_layout, shmem_c, tile), tile)
        end
    end

    sync_threads()

    # (3) Compute a block_shape.M x block_shape.N x block_shape.K matrix product within one threadblock
    shmem_a = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_a_layout), (Layout.physical_size(conf.shared_a_layout, block_tile.MK.size)..., 2))
    shmem_b = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_b_layout), (Layout.physical_size(conf.shared_b_layout, block_tile.KN.size)..., 2),
                                             length(shmem_a) * sizeof(Layout.eltype(conf.shared_a_layout)))
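    # The trailing dimension of size 2 gives A and B two shared-memory stages
    # (double buffering): one stage is consumed by the MMAs while the other is
    # filled with the next main-loop iteration's tiles. This doubles the shared
    # memory used for A and B, which is why BLOCK_K has to be halved compared to
    # the single-stage pipelined kernel.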

    # Sizes of a_fragment and b_fragment
    a_frag_i = (block_tile.size.M * block_tile.size.K) ÷ (conf.mem_a_warp.M * conf.mem_a_warp.K * conf.warps_per_block)
    a_frag_j = (conf.mem_a_warp.M * conf.mem_a_warp.K) ÷ (conf.mem_a_thread.M * conf.mem_a_thread.K * 32)
    b_frag_i = (block_tile.size.K * block_tile.size.N) ÷ (conf.mem_b_warp.K * conf.mem_b_warp.N * conf.warps_per_block)
    b_frag_j = (conf.mem_b_warp.K * conf.mem_b_warp.N) ÷ (conf.mem_b_thread.K * conf.mem_b_thread.N * 32)

    # Fragments to buffer the loads from global memory for A and B.
    a_fragment = LocalArray{Tuple{a_frag_i, a_frag_j}, Layout.fragtype(conf.global_a_layout, conf.mem_a_thread)}(undef)
    b_fragment = LocalArray{Tuple{b_frag_i, b_frag_j}, Layout.fragtype(conf.global_b_layout, conf.mem_b_thread)}(undef)

    # Fragments to buffer the loads from shared memory for A and B.
    a_frags = LocalArray{Tuple{2, num_fragments_m}, Operator.fragtype_a(conf.operator, conf.shared_a_layout)}(undef)
    b_frags = LocalArray{Tuple{2, num_fragments_n}, Operator.fragtype_b(conf.operator, conf.shared_b_layout)}(undef)
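    # The leading dimension of size 2 double-buffers the register fragments across
    # consecutive warp_mma_k iterations (indexed by warp_mma_k % 2 + 1), so the
    # shared-memory loads for the next step can overlap the current step's MMAs.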

    warp_tile_mn = subdivide(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)

    # Prologue.
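    # Fill the first shared-memory stage and the register fragments for the first
    # warp_mma_k iteration before entering the main loop.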

    # ld_global(main_loop_it = 0)
    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
            @inbounds @immutable a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = 0)))
        end
    end

    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
            @inbounds @immutable b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = 0, N = block_j)))
        end
    end

    # st_shared(main_loop_it = 0)
    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
            x = transf_gl2sh_a(@inbounds(a_fragment[i, j]), thread_tile)
            @inbounds Layout.store!(conf.shared_a_layout, view(shmem_a, :, :, 1), x, thread_tile)
        end
    end

    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
            x = transf_gl2sh_b(@inbounds(b_fragment[i, j]), thread_tile)
            @inbounds Layout.store!(conf.shared_b_layout, view(shmem_b, :, :, 1), x, thread_tile)
        end
    end

    sync_threads()

    # ld_shared(main_loop_it = 0, warp_mma_k = 0)
    warp_mma_k = 0
    warp_k = warp_mma_k * conf.compute_op_shape.K
    warp_tile = translate_offset(warp_tile_mn, (M = 0, N = 0, K = warp_k))
    @loopinfo unroll for i = 1 : num_fragments_m
        a_tile = translate_offset(warp_tile.MK, (M = (i-1)*conf.compute_op_shape.M, K = 0))
        @inbounds @immutable a_frags[warp_mma_k % 2 + 1, i] = transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, view(shmem_a, :, :, 1), a_tile), a_tile)
    end

    @loopinfo unroll for j = 1 : num_fragments_n
        b_tile = translate_offset(warp_tile.KN, (K = 0, N = (j-1)*conf.compute_op_shape.N))
        @inbounds @immutable b_frags[warp_mma_k % 2 + 1, j] = transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, view(shmem_b, :, :, 1), b_tile), b_tile)
    end

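    # Main loop over the K dimension, software-pipelined across the two shared-memory
    # stages: each iteration computes from one stage while the tiles for the next
    # iteration are fetched from global memory (on the first warp_mma_k step) and
    # stored into the other stage (on the last warp_mma_k step, followed by the
    # loop's only sync_threads()). The unrollcount=2 presumably lets the compiler
    # resolve the alternating % 2 stage indices at compile time.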
    NUM_MAIN_LOOP_ITERS = gemm_sz.size.K ÷ block_tile.size.K
    @loopinfo unrollcount=2 for main_loop_it = 0 : NUM_MAIN_LOOP_ITERS - 1
        block_k = main_loop_it * block_tile.size.K

        main_loop_it_next = (main_loop_it + 1) % NUM_MAIN_LOOP_ITERS
        block_k_next = main_loop_it_next * block_tile.size.K

        NUM_WARP_MMA_K_ITERS = conf.block_shape.K ÷ conf.compute_op_shape.K
        @loopinfo unroll for warp_mma_k = 0 : NUM_WARP_MMA_K_ITERS - 1
            warp_k = warp_mma_k * conf.compute_op_shape.K
            warp_tile = translate_offset(warp_tile_mn, (M = 0, N = 0, K = warp_k))

            warp_mma_k_next = (warp_mma_k + 1) % NUM_WARP_MMA_K_ITERS
            warp_k_next = warp_mma_k_next * conf.compute_op_shape.K
            warp_tile_next = translate_offset(warp_tile_mn, (M = 0, N = 0, K = warp_k_next))

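            # Shared-memory stage that the next warp_mma_k iteration will read from:
            # only advance to the next main-loop iteration's stage when warp_mma_k_next
            # wraps around to 0.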
            main_loop_it_next_warp_k = if warp_mma_k_next == 0
                main_loop_it_next
            else
                main_loop_it
            end

            if warp_mma_k == conf.block_shape.K ÷ conf.compute_op_shape.K - 1 # last iteration of inner warp loop.
                # st.shared(main_loop_it_next)
                @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
                    @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
                        x = transf_gl2sh_a(@inbounds(a_fragment[i, j]), thread_tile)
                        @inbounds Layout.store!(conf.shared_a_layout, view(shmem_a, :, :, main_loop_it_next % 2 + 1), x, thread_tile)
                    end
                end

                @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
                    @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
                        x = transf_gl2sh_b(@inbounds(b_fragment[i, j]), thread_tile)
                        @inbounds Layout.store!(conf.shared_b_layout, view(shmem_b, :, :, main_loop_it_next % 2 + 1), x, thread_tile)
                    end
                end

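                # The only block-wide barrier in the main loop: it orders the stores
                # into the next shared-memory stage before any reads from that stage.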
                sync_threads()
            end

            # ld.shared(main_loop_it, warp_mma_k_next)
            @loopinfo unroll for i = 1 : num_fragments_m
                a_tile = translate_offset(warp_tile_next.MK, (M = (i-1)*conf.compute_op_shape.M, K = 0))
                @inbounds @immutable a_frags[warp_mma_k_next % 2 + 1, i] = transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, view(shmem_a, :, :, main_loop_it_next_warp_k % 2 + 1), a_tile), a_tile)
            end

            @loopinfo unroll for j = 1 : num_fragments_n
                b_tile = translate_offset(warp_tile_next.KN, (K = 0, N = (j-1)*conf.compute_op_shape.N))
                @inbounds @immutable b_frags[warp_mma_k_next % 2 + 1, j] = transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, view(shmem_b, :, :, main_loop_it_next_warp_k % 2 + 1), b_tile), b_tile)
            end

            if warp_mma_k == 0 # first iteration of inner warp loop.
                # ld.global(main_loop_it_next)
                @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
                    @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
                        @inbounds @immutable a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_k_next)))
                    end
                end

                @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
                    @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
                        @inbounds @immutable b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_k_next, N = block_j)))
                    end
                end
            end

            # mma(main_loop_it, warp_mma_k)
            @loopinfo unroll for i = 1 : num_fragments_m
                @loopinfo unroll for j = 1 : num_fragments_n
                    @inbounds @immutable c_frags[i, j] = Operator.mma(conf.operator, a_frags[warp_mma_k % 2 + 1, i], b_frags[warp_mma_k % 2 + 1, j], c_frags[i, j])
                end
            end
        end
    end

    # (4) Store the compute_warp.M x compute_warp.N tile of D from registers to shared memory
    shmem_d = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_d_layout), Layout.physical_size(conf.shared_d_layout, block_tile.MN.size))

    warp_tile = @inbounds subdivide(block_tile.MN, Tile(conf.compute_warp).MN, warpId, conf.warps_per_block)

    @loopinfo unroll for i = 1 : num_fragments_m
        @loopinfo unroll for j = 1 : num_fragments_n
            tile = translate_offset(warp_tile, (M = (i-1)*conf.compute_op_shape.M, N = (j-1)*conf.compute_op_shape.N))
            @inbounds Operator.store_d(conf.operator, conf.shared_d_layout, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
        end
    end

    sync_threads()

    # (5) Run the epilogue
    epilogue(conf, d, shmem_d, transf_sh2gl_d)

    return
end

function shmem_size(conf::GemmKernels.Config, ::typeof(matmul_pipelined_ng))
    size_a = sizeof(Layout.eltype(conf.shared_a_layout)) *
        prod(Layout.physical_size(conf.shared_a_layout,
                                  (; conf.block_shape.M, conf.block_shape.K)))
    size_b = sizeof(Layout.eltype(conf.shared_b_layout)) *
        prod(Layout.physical_size(conf.shared_b_layout,
                                  (; conf.block_shape.K, conf.block_shape.N)))
    size_c = sizeof(Layout.eltype(conf.shared_c_layout)) *
        prod(Layout.physical_size(conf.shared_c_layout,
                                  (; conf.block_shape.M, conf.block_shape.N)))
    size_d = sizeof(Layout.eltype(conf.shared_d_layout)) *
        prod(Layout.physical_size(conf.shared_d_layout,
                                  (; conf.block_shape.M, conf.block_shape.N)))
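    # A and B are double-buffered (two shared-memory stages each), so they count twice.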
    max(size_c, 2 * (size_a + size_b), size_d)
end

end