Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b7d119f
initial commit for debugging
amd-khushbu Feb 11, 2026
33e4890
debugging with prints
amd-khushbu Feb 26, 2026
02e5193
merge with develop
amd-khushbu Feb 26, 2026
d630492
working case for group_n:128
amd-khushbu Feb 26, 2026
f7d8393
code clean up
amd-khushbu Feb 26, 2026
b0d94ff
working preshuffleQ and preshuffleQuant for bquant
amd-khushbu Mar 2, 2026
2a7bd1b
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 2, 2026
812f24f
removing debug changes and code clean up
amd-khushbu Mar 3, 2026
5ed440a
working permuteN for abquant
amd-khushbu Mar 3, 2026
30dacb3
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 5, 2026
b89394e
disabling transposeC with permuteN
amd-khushbu Mar 6, 2026
844352e
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 6, 2026
9030336
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 7, 2026
f032947
Merge branch 'develop' into ck/khuagarw/AICK-442
ThomasNing Mar 9, 2026
67acab0
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 10, 2026
57cb0b2
resolving merge conflicts
amd-khushbu Mar 11, 2026
2f37f37
resolving merge conflicts
amd-khushbu Mar 11, 2026
860befe
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 11, 2026
227a3bf
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 11, 2026
9751bbf
addressing review comments
amd-khushbu Mar 12, 2026
cc973e8
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 12, 2026
ecc711a
fix the compilation error to make it support TransposeC
ThomasNing Mar 12, 2026
1a6cb2f
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 13, 2026
a66d0d6
Merge branch 'develop' into ck/khuagarw/AICK-442
ThomasNing Mar 15, 2026
1af05fd
resolving merge conflicts
amd-khushbu Mar 16, 2026
b182f49
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 17, 2026
d01728d
fixing CI issue
amd-khushbu Mar 17, 2026
0630786
fixing CI issue
amd-khushbu Mar 17, 2026
42887c5
Merge branch 'develop' into ck/khuagarw/AICK-442
amd-khushbu Mar 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t, false>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -31,7 +31,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t, false>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a comment here telling the user what this `false` boolean argument means?

TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t, false>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -30,7 +30,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t, false>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ABQuantPipeline,
BQuantPipeline>>>;

constexpr bool TiledPermuteN =
(BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
constexpr bool TiledPermuteN = (BQuantGroupSize::kN == 1 || BQuantGroupSize::kN == 128)
? GemmConfig::TiledMMAPermuteN
: false;
if(s.log_level_ > 0)
{
printf(
Expand Down Expand Up @@ -748,7 +749,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PreshuffleB)
{
if constexpr(GemmConfig::TiledMMAPermuteN && BQuantGroupSize::kN == 1)
if constexpr(GemmConfig::TiledMMAPermuteN &&
(BQuantGroupSize::kN == 1 || BQuantGroupSize::kN == 128))
{
printf("PreshuffleB with TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
Expand All @@ -775,7 +777,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN &&
BQuantGroupSize::kN == 1)
(BQuantGroupSize::kN == 1 || BQuantGroupSize::kN == 128))
{
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,15 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
int n_ = t.get_lengths()[1];
int bqk_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
int dim = (group_n == 1) ? n_ / GemmConfig::N_Tile : n_;

ck_tile::HostTensor<T> t_view =
(group_n == 1) ? ck_tile::HostTensor<T>(
{dim, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_})
: ck_tile::HostTensor<T>({dim, 1, 1, 1, bqk_});

ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile / group_n,
NRepeat,
bqk_});
std::copy(t.begin(), t.end(), t_view.begin());

return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,26 +259,72 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};

index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(Traits::BPreshuffleQuant)
{
constexpr index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN) &&
Traits::NPerBlock == BQuantGroupSize::kN)
{
return kQScale;
}
else
{
return nIter;
}
}();

auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;

if constexpr(std::is_same_v<BQDataType, float>)
{
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
kQScale;
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
return nIter * KPerBlockBQ + kQScale;
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f = Base::cvt_scale_to_fp32<BQDataType>(scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
});

// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));

float b_scale_reg_f =
Base::cvt_scale_to_fp32<BQDataType>(gathered_scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
});
}
else
{
index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
KPerBlockBQ +
kQScale;
}
else
{
return nIter * KPerBlockBQ + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f = Base::cvt_scale_to_fp32<BQDataType>(scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
});
}
});
});
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr index_t NWarp = config.template at<2>();

static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;

static constexpr index_t kBlockSize = Problem::kBlockSize;
Expand Down Expand Up @@ -165,6 +166,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
}
});
});

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
constexpr auto tbuf_offset =
Expand All @@ -175,7 +177,18 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg

if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
constexpr index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN > (NWarp * WG::kN) &&
NPerBlock == BQuantGroupSize::kN)
{
return kQScale; // prefill: one quant group per block
}
else
{
return nIter; // decode or multiple groups per warp
}
}();

auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
Expand Down
Loading