Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b7d119f
initial commit for debugging
amd-khushbu Feb 11, 2026
33e4890
debugging with prints
amd-khushbu Feb 26, 2026
02e5193
merge with develop
amd-khushbu Feb 26, 2026
d630492
working case for group_n:128
amd-khushbu Feb 26, 2026
f7d8393
code clean up
amd-khushbu Feb 26, 2026
b0d94ff
working preshuffleQ and preshuffleQuant for bquant
amd-khushbu Mar 2, 2026
2a7bd1b
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 2, 2026
812f24f
removing debug changes and code clean up
amd-khushbu Mar 3, 2026
5ed440a
working permuteN for abquant
amd-khushbu Mar 3, 2026
30dacb3
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 5, 2026
b89394e
disabling transposeC with permuteN
amd-khushbu Mar 6, 2026
844352e
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 6, 2026
9030336
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 7, 2026
f032947
Merge branch 'develop' into ck/khuagarw/AICK-442
ThomasNing Mar 9, 2026
67acab0
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 10, 2026
57cb0b2
resolving merge conflicts
amd-khushbu Mar 11, 2026
2f37f37
resolving merge conflicts
amd-khushbu Mar 11, 2026
860befe
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 11, 2026
227a3bf
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 11, 2026
9751bbf
addressing review comments
amd-khushbu Mar 12, 2026
cc973e8
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 12, 2026
ecc711a
fix the compilation error to make it support TransposeC
ThomasNing Mar 12, 2026
1a6cb2f
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 13, 2026
a66d0d6
Merge branch 'develop' into ck/khuagarw/AICK-442
ThomasNing Mar 15, 2026
1af05fd
resolving merge conflicts
amd-khushbu Mar 16, 2026
b182f49
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 17, 2026
d01728d
fixing CI issue
amd-khushbu Mar 17, 2026
0630786
fixing CI issue
amd-khushbu Mar 17, 2026
42887c5
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ static auto _ = []() {
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<
GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t, false>,
GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t, false>, // make the TransposeC
// false
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -34,7 +35,8 @@ static auto _ = []() {
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<
GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t, true>,
GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t, false>, // make the TransposeC
// false
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>) &&
(std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>);
constexpr bool transpose_c = GemmConfig::TransposeC;
constexpr bool transpose_c = ((QuantMode == ck_tile::QuantType::ABQuantGrouped || QuantMode == ck_tile::QuantType::BQuantGrouped) && GemmConfig::TiledMMAPermuteN) ? false : GemmConfig::TransposeC;
constexpr bool eight_waves =
#ifdef CK_GFX950_SUPPORT
IS_FP8BLOCKSCALE && (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
Expand Down Expand Up @@ -200,8 +200,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ABQuantPipeline,
BQuantPipeline>>>;

constexpr bool TiledPermuteN =
(BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
constexpr bool TiledPermuteN = (BQuantGroupSize::kN == 1 || BQuantGroupSize::kN == 128)
? GemmConfig::TiledMMAPermuteN
: false;
if(s.log_level_ > 0)
{
printf(
Expand Down Expand Up @@ -748,7 +749,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PreshuffleB)
{
if constexpr(GemmConfig::TiledMMAPermuteN && BQuantGroupSize::kN == 1)
if constexpr(GemmConfig::TiledMMAPermuteN &&
(BQuantGroupSize::kN == 1 || BQuantGroupSize::kN == 128))
{
printf("PreshuffleB with TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
Expand All @@ -775,7 +777,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN &&
BQuantGroupSize::kN == 1)
(BQuantGroupSize::kN == 1 || BQuantGroupSize::kN == 128))
{
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,35 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
return shuffle_b(t, GemmConfig{});
}

/*
For the permuteN feature, the BQ (scale) tensor must be shuffled so its layout matches the MFMA
result. The required shuffle depends on group_N (the N dimension of the quant group). The BQ
tensor is treated as an N×K matrix with N = 128 / group_N:
    1×8×128   → N = 16 (BQ 16×1)
    1×16×128  → N = 8
    1×32×128  → N = 4
    1×64×128  → N = 2
    1×128×128 → N = 1 (one value per full block tile)
The shuffle is given as a 4-tuple whose meaning is tied to n, N_warp, N_warp_tile, and NRepeat:
    group_N = 8:   Shuffle (1, 4, 2, 2) — i.e. (1, N_warp, N_warp_tile/group_N, NRepeat) with
                   N_warp_tile/8 = 2 (e.g. 16/8).
    group_N = 16:  Shuffle (1, 4, 1, 2) — 1 value per N_warp_tile (N_warp_tile/16 = 1).
    group_N = 32:  Shuffle (1, 2, 1, 2) — (1, N_warp/2, 1, NRepeat); 2 N_warp_tiles share the
                   same scale.
    group_N = 64:  Shuffle (1, 1, 1, 2) — (1, N_warp/4, 1, NRepeat); 4 N_warp_tiles share the
                   same value.
    group_N = 128: Shuffle (1, 1, 1, 1) — 1 value for the full block tile.

The alignment problem:
When the BQ tensor is shuffled according to these rules (the 4-tuples above), its layout no longer
matches what the block pipeline expects after block-level MFMA. So even with the "correct" shuffle
for each group_N, BQ is misaligned with the MFMA result at the block level. That is why the code
today only enables TiledPermuteN for BQuantGroupSize::kN equal to 1 or 128.

Options to fix alignment:
1) Update tile_distribution_encoding for permuteN:
   Adjust the BQ tile distribution encoding so the encoded distribution matches the shuffle layout
   and aligns with how the block consumes C and BQ after MFMA.
2) Update the tile window when reading from DRAM:
   Keep the shuffle as defined above and change how the BQ tile window is built when reading from
   device memory so that the window layout matches the post-MFMA layout.
3) Update indexes for BQ reads:
   Keep the shuffle and tile window; change the indexing used when reading BQ in the block so that
   each thread/warp loads the scale that corresponds to its part of the MFMA result.
*/
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
Expand All @@ -129,13 +158,15 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
int n_ = t.get_lengths()[1];
int bqk_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
int dim = (group_n == 1) ? n_ / GemmConfig::N_Tile : n_;

ck_tile::HostTensor<T> t_view =
(group_n == 1) ? ck_tile::HostTensor<T>(
{dim, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_})
: ck_tile::HostTensor<T>({dim, 1, 1, 1, bqk_});

ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile / group_n,
NRepeat,
bqk_});
std::copy(t.begin(), t.end(), t_view.begin());

return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,26 +259,72 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};

index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(Traits::BPreshuffleQuant)
{
constexpr index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN) &&
Traits::NPerBlock == BQuantGroupSize::kN)
{
return kQScale;
}
else
{
return nIter;
}
}();

auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;

if constexpr(std::is_same_v<BQDataType, float>)
{
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
kQScale;
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
return nIter * KPerBlockBQ + kQScale;
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f = Base::cvt_scale_to_fp32<BQDataType>(scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
});

// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));

float b_scale_reg_f =
Base::cvt_scale_to_fp32<BQDataType>(gathered_scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
});
}
else
{
index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
KPerBlockBQ +
kQScale;
}
else
{
return nIter * KPerBlockBQ + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f = Base::cvt_scale_to_fp32<BQDataType>(scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
});
}
});
});
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr index_t NWarp = config.template at<2>();

static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;

static constexpr index_t kBlockSize = Problem::kBlockSize;
Expand Down Expand Up @@ -165,6 +166,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
}
});
});

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
constexpr auto tbuf_offset =
Expand All @@ -175,7 +177,18 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg

if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
constexpr index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN > (NWarp * WG::kN) &&
NPerBlock == BQuantGroupSize::kN)
{
return kQScale; // prefill: one quant group per block
}
else
{
return nIter; // decode or multiple groups per warp
}
}();

auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,18 +286,46 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase

if constexpr(Traits::NQPerBlock / NWarp == 1)
{
constexpr auto cw_spans = CWarpTensor::get_distributed_spans();
static_assert(cw_spans[I0{}].impl_.size() == 0);
sweep_tile_span(cw_spans[I1{}], [&](auto in) {
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
constexpr auto block_idx_n = detail::make_tile_distributed_index(
constexpr auto cw_spans = CWarpTensor::get_distributed_spans();
constexpr auto empty_idx = tile_distributed_index<>{};

auto accumulate_c = [&](auto im, auto in) {
constexpr auto c_block_idx_m = detail::make_tile_distributed_index(
merge_sequences(sequence<mIter>{}, im.impl_));
constexpr auto c_block_idx_n = detail::make_tile_distributed_index(
merge_sequences(sequence<nIter>{}, in.impl_));
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
constexpr auto empty_idx = tile_distributed_index<>{};
c_block_tensor(make_tuple(block_idx_m, block_idx_n)) +=
c_warp_tensor(make_tuple(empty_idx, in)) *
q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
});
// q_block_tensor's M distribution only has mIter
// as its Y index (no warp-internal M indices),
// so use a separate M index for it.
constexpr auto q_block_idx_m = tile_distributed_index<mIter>{};
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
c_block_tensor(make_tuple(c_block_idx_m, c_block_idx_n)) +=
c_warp_tensor(make_tuple(im, in)) *
q_block_tensor(make_tuple(q_block_idx_m, block_idx_kq));
};

// Handle both transposed C (M span empty, N span
// non-empty) and non-transposed C (M span non-empty,
// N span empty). sweep_tile_span cannot handle empty
// spans, so dispatch based on span sizes.
if constexpr(cw_spans[I0{}].impl_.size() > 0 &&
cw_spans[I1{}].impl_.size() > 0)
{
sweep_tile_span(cw_spans[I0{}], [&](auto im) {
sweep_tile_span(cw_spans[I1{}],
[&](auto in) { accumulate_c(im, in); });
});
}
else if constexpr(cw_spans[I0{}].impl_.size() > 0)
{
sweep_tile_span(cw_spans[I0{}],
[&](auto im) { accumulate_c(im, empty_idx); });
}
else if constexpr(cw_spans[I1{}].impl_.size() > 0)
{
sweep_tile_span(cw_spans[I1{}],
[&](auto in) { accumulate_c(empty_idx, in); });
}
}
else
{
Expand Down
Loading