Support PermuteN for 2D block scale GEMM with block 128(N)#5040

Open
amd-khushbu wants to merge 29 commits into develop from ck/khuagarw/AICK-442

Conversation

@amd-khushbu (Contributor) commented Mar 3, 2026

Proposed changes

This PR enables preshuffleB with PermuteN for 2D block scale GEMM operations when the block quantization group size is 128 in the N dimension (BQuantGroupSize::kN == 128).

Motivation

  • Jira Ticket: AICK-442
  • PermuteN rearranges the matrix layout in memory so that loads are coalesced, improving performance
  • Previously, PermuteN was only supported for BQuantGroupSize::kN == 1 (per-element quantization)
  • This change extends support to BQuantGroupSize::kN == 128 (block-wise quantization)

Key Changes

1. Extended PermuteN Support in GEMM Pipeline (run_gemm_quant_example.inc)

  • Modified TiledPermuteN condition to enable PermuteN when BQuantGroupSize::kN == 1 || BQuantGroupSize::kN == 128
  • Updated shuffle_b_permuteN and bq_permuteN invocation conditions to support the new block size

2. Enhanced Tensor Shuffle Utilities (tensor_shuffle_utils.hpp)

  • Updated bq_permuteN function to handle both per-element (group_n == 1) and block-128 (group_n == 128) quantization
  • For group_n == 1: Uses full N-tile decomposition with NWarp, N_Warp_Tile, and NRepeat dimensions
  • For group_n == 128: Uses a simplified view where the entire block is treated as a single unit

3. Block GEMM Kernel Changes (block_universal_gemm_ar_flatbr_bquant_cr.hpp)

  • Added NPerBlock constant for proper dimension tracking
  • Modified scale register offset calculation in the BPreshuffleQuant path:
    • When BQuantGroupSize::kN > (NWarp * WG::kN) and NPerBlock == BQuantGroupSize::kN: Uses a single quant group per block (prefill scenario)
    • Otherwise: Uses nIter for decode or multiple groups per warp scenarios

Technical Details

The key insight is that when BQuantGroupSize::kN == 128 (matching the N block size), each thread block processes exactly one quantization group in the N dimension. This allows the same PermuteN optimization to be applied, as the scale values can be broadcast efficiently within the block.

Checklist

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation that helps maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

The implementation distinguishes between two scenarios:

  1. Per-element quantization (kN == 1): Each element has its own scale, full N-tile decomposition is used
  2. Block quantization (kN == 128): One scale per 128 elements in N, simplified view where scales are broadcast within the block

This approach maintains backward compatibility while enabling performance optimizations for the common 128-block quantization case used in modern quantized models.

@ThomasNing (Contributor) commented:

@amd-khushbu When transpose C is enabled, we need a new PermuteN algorithm in the C-shuffle epilogue that treats the N dimension, rather than the M dimension, as the outer loop.

@ThomasNing (Contributor) commented:

@amd-khushbu CI error.

@ThomasNing (Contributor) left a review comment:


Could we also add guards for, and document, the block scale sizes on the N dimension that do not support PermuteN, and explain why?

    using TypeConfig =
        decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
-   return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
+   return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t, false>,

Could we add a comment here explaining what the false boolean argument means?

@ThomasNing (Contributor) commented:

@amd-khushbu CI failed again. PTAL?
