From a42af893842c3f480c84e50b49a72d9a229726f6 Mon Sep 17 00:00:00 2001 From: Astha Date: Tue, 20 Jan 2026 04:01:48 -0500 Subject: [PATCH 1/3] Addition of code for XCD remapping This change adds in a function to remap block ids from their original round robin assignment to a contiguous layout across XCDs. This function is added to the StreamKTilePartitioner and called in the operator() functions. There are also unit tests to verify the correctness of the function on minimal arrays. These changes should improve locality and the cache hit rate, therefore improving performance overall. --- .../streamk_gemm/streamk_gemm_kernel.hpp | 11 ++-- .../streamk_gemm_tile_partitioner.hpp | 17 +++++- .../streamk_gemm_tile_partitioner_impl.hpp | 53 ++++++++++++++++++- .../test_streamk_tile_partitioner.cpp | 46 ++++++++++++++++ .../test_streamk_tile_partitioner_common.hpp | 17 ++++++ 5 files changed, 137 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp index 47e59c47047f..aee436fb15ec 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp @@ -673,10 +673,12 @@ struct StreamKKernel __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()]; index_t block_idx = ck_tile::get_block_1d_id(); + index_t grid_size = kargs.tile_partitioner.grid_size().x; index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile(); index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas(); bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas(); + block_idx = kargs.tile_partitioner.RemapXCD(block_idx, grid_size); // Check if at the data parallel section if(is_dp_ctas) { @@ -706,11 +708,14 @@ struct StreamKKernel __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()]; index_t block_idx = ck_tile::get_block_1d_id(); + index_t grid_size = kargs.tile_partitioner.grid_size().x; index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile(); - // Data-parallel section - for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles(); - tile_idx += kargs.tile_partitioner.get_grid()) + block_idx = + kargs.tile_partitioner.RemapXCD(block_idx, grid_size) + // Data-parallel section + for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles(); + tile_idx += kargs.tile_partitioner.get_grid()) { BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0); block_sync_lds(); diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index f028ba0c626f..19aac8367e95 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -213,6 +213,19 @@ struct StreamKTilePartitionerBase */ CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept; + /** + * @brief XCDs access ids in round robin format, this function remaps the 1D ids to continguous + * XCD segments + * + * @param block_1d_id grid 1D id + * @param total_num_tiles size of the 1D grid + * @param NUM_XCDS number of XCDs + * @return index_t The id after XCD remap + */ + CK_TILE_HOST_DEVICE index_t RemapXCD(index_t block_1d_id, + index_t total_num_tiles, + index_t NUM_XCDS = 8) const noexcept; + protected: index_t num_tiles_; index_t grid_; @@ -281,7 +294,7 @@ struct StreamKTilePartitioner * * @return dim_3 The launching grid size for the kernel. */ - CK_TILE_HOST auto grid_size() const noexcept -> dim3; + CK_TILE_HOST_DEVICE auto grid_size() const noexcept -> dim3; /** * @brief Returns the total number of DP tiles per workgroup. @@ -328,7 +341,7 @@ struct StreamKTilePartitioner * * @return dim_3 The launching grid size for the kernel. */ - CK_TILE_HOST auto grid_size() const noexcept -> dim3; + CK_TILE_HOST_DEVICE auto grid_size() const noexcept -> dim3; /** * @brief Returns the total number of DP workgroups. diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp index f80eec844ccb..2edfcb1437f2 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp @@ -279,6 +279,55 @@ StreamKTilePartitionerBase::estimate_ return std::max(num_wgs_per_tile, 1); } +/** + * @brief XCDs access ids in round robin format, this function remaps the 1D ids to continguous + * XCD segments + * + * @param block_1d_id grid 1D id + * @param total_num_tiles size of the 1D grid + * @param NUM_XCDS number of XCDs + * @return index_t The id after XCD remap + */ +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::RemapXCD( + index_t block_1d_id, index_t total_num_tiles, index_t NUM_XCDS) const noexcept +{ + // Number of ids per XCD in the new arrangement + index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS; + + // When total_num_tiles cannot divide NUM_XCDS, some xcds will have + // ids_per_xcd ids, the other will have ids_per_xcd - 1 ids. + // We calculate the number of xcds that have ids_per_xcd ids as tall_xcds + index_t tall_xcds = total_num_tiles % NUM_XCDS; + tall_xcds = (tall_xcds == 0) ? NUM_XCDS : tall_xcds; + + // Compute current XCD and local id within the XCD + index_t xcd = block_1d_id % NUM_XCDS; + index_t local_id = block_1d_id / NUM_XCDS; + + // Calculate new id based on the new grouping + if(xcd < tall_xcds) + { + block_1d_id = xcd * ids_per_xcd + local_id; + } + else + { + block_1d_id = tall_xcds * ids_per_xcd + (xcd - tall_xcds) * (ids_per_xcd - 1) + local_id; + } + + /** + * original ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + * XCD 0 gets: [0, 8], XCD 1 gets: [1, 9], ... + * + * post-remap ids: [0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15] + * XCD 0 gets: [0, 1], XCD 1 gets: [2, 3], ... + * + * after remap the ids are continguous on each XCD + */ + return block_1d_id; +} + template @@ -295,7 +344,7 @@ StreamKTilePartitioner::StreamK } template -CK_TILE_HOST auto +CK_TILE_HOST_DEVICE auto StreamKTilePartitioner::grid_size() const noexcept -> dim3 { @@ -337,7 +386,7 @@ StreamKTilePartitioner::Stream } template -CK_TILE_HOST auto +CK_TILE_HOST_DEVICE auto StreamKTilePartitioner::grid_size() const noexcept -> dim3 { diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index 30b1b878c5d2..a2e688fc6c70 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -407,6 +407,52 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings) } } +TEST(StreamKTilePartitionerBaseRemapXCD, SmallArray) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const std::vector initial_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const std::vector expected_values = { + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + + test_remap_xcd(initial_values, expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitionerBaseRemapXCD, MidArray) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const std::vector initial_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126}; + const std::vector expected_values = { + 0, 16, 32, 48, 64, 80, 96, 112, 1, 17, 33, 49, 65, 81, 97, 113, 2, 18, 34, 50, + 66, 82, 98, 114, 3, 19, 35, 51, 67, 83, 99, 115, 4, 20, 36, 52, 68, 84, 100, 116, + 5, 21, 37, 53, 69, 85, 101, 117, 6, 22, 38, 54, 70, 86, 102, 118, 7, 23, 39, 55, + 71, 87, 103, 119, 8, 24, 40, 56, 72, 88, 104, 120, 9, 25, 41, 57, 73, 89, 105, 121, + 10, 26, 42, 58, 74, 90, 106, 122, 11, 27, 43, 59, 75, 91, 107, 123, 12, 28, 44, 60, + 76, 92, 108, 124, 13, 29, 45, 61, 77, 93, 109, 125, 14, 30, 46, 62, 78, 94, 110, 126, + 15, 31, 47, 63, 79, 95, 111}; + test_remap_xcd(initial_values, expected_values, tile_partitioner); +} + TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK) { /* diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index 31217ba10149..214c123c4a77 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -421,6 +421,23 @@ void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start, EXPECT_EQ(tile_local_cta_idx, expected_tile_local_cta_idx); } +template +void test_remap_xcd( + const std::vector& initial_values, + const std::vector& expected_values, + ck_tile::StreamKTilePartitioner_v2& + tile_partitioner) +{ + // ck_tile::index_t grid_size = tile_partitioner.grid_size().x; + std::vector remapped_values(initial_values.size()); + for(std::size_t i = 0; i < initial_values.size(); ++i) + { + remapped_values[i] = tile_partitioner.RemapXCD(initial_values[i], initial_values.size()); + } + + EXPECT_EQ(remapped_values, expected_values); +} + // Configs for TilePartitioner Child structs struct StreamKTilePartitionerV2PersistentExpected { From 1d1c3a9db9261473d4947f88d0c6905a75074e64 Mon Sep 17 00:00:00 2001 From: Astha Rai Date: Wed, 11 Feb 2026 10:04:59 +0000 Subject: [PATCH 2/3] Using enum to assign number of XCDs based on architecture and limiting runs to the gfx942 architecture --- .../streamk_gemm/streamk_gemm_kernel.hpp | 22 +++- .../streamk_gemm_tile_partitioner.hpp | 8 +- .../streamk_gemm_tile_partitioner_impl.hpp | 18 +-- .../kernel/streamk_gemm/streamk_gemm_xcd.hpp | 30 +++++ .../test_streamk_tile_partitioner.cpp | 113 +++++++++++++----- .../test_streamk_tile_partitioner_common.hpp | 9 +- 6 files changed, 144 insertions(+), 56 deletions(-) create mode 100644 projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp index aee436fb15ec..1ee8158455bf 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/host/concat.hpp" #include "streamk_gemm_coherency.hpp" +#include "streamk_gemm_xcd.hpp" namespace ck_tile { @@ -678,7 +679,12 @@ struct StreamKKernel index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas(); bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas(); - block_idx = kargs.tile_partitioner.RemapXCD(block_idx, grid_size); + using CompilerTargetT = decltype(core::arch::get_compiler_target()); + if constexpr(CompilerTargetT::TARGET_ID == core::arch::amdgcn_target_id::GFX942) + { + constexpr int num_xcds = ck_tile::NumXCD::num_xcds; + block_idx = kargs.tile_partitioner.remap_xcd(block_idx, grid_size, num_xcds); + } // Check if at the data parallel section if(is_dp_ctas) { @@ -711,11 +717,15 @@ struct StreamKKernel index_t grid_size = kargs.tile_partitioner.grid_size().x; index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile(); - block_idx = - kargs.tile_partitioner.RemapXCD(block_idx, grid_size) - // Data-parallel section - for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles(); - tile_idx += kargs.tile_partitioner.get_grid()) + using CompilerTargetT = decltype(core::arch::get_compiler_target()); + if constexpr(CompilerTargetT::TARGET_ID == core::arch::amdgcn_target_id::GFX942) + { + constexpr int num_xcds = ck_tile::NumXCD::num_xcds; + block_idx = kargs.tile_partitioner.remap_xcd(block_idx, grid_size, num_xcds); + } + // Data-parallel section + for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles(); + tile_idx += kargs.tile_partitioner.get_grid()) { BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0); block_sync_lds(); diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index 19aac8367e95..944ae6e82135 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -219,12 +219,12 @@ struct StreamKTilePartitionerBase * * @param block_1d_id grid 1D id * @param total_num_tiles size of the 1D grid - * @param NUM_XCDS number of XCDs + * @param num_xcds number of XCDs * @return index_t The id after XCD remap */ - CK_TILE_HOST_DEVICE index_t RemapXCD(index_t block_1d_id, - index_t total_num_tiles, - index_t NUM_XCDS = 8) const noexcept; + CK_TILE_HOST_DEVICE index_t remap_xcd(index_t block_1d_id, + index_t total_num_tiles, + index_t num_xcds = 8) const noexcept; protected: index_t num_tiles_; diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp index 2edfcb1437f2..efcaf11c6708 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp @@ -285,26 +285,26 @@ StreamKTilePartitionerBase::estimate_ * * @param block_1d_id grid 1D id * @param total_num_tiles size of the 1D grid - * @param NUM_XCDS number of XCDs + * @param num_xcds number of XCDs * @return index_t The id after XCD remap */ template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::RemapXCD( - index_t block_1d_id, index_t total_num_tiles, index_t NUM_XCDS) const noexcept +StreamKTilePartitionerBase::remap_xcd( + index_t block_1d_id, index_t total_num_tiles, index_t num_xcds) const noexcept { // Number of ids per XCD in the new arrangement - index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS; + index_t ids_per_xcd = (total_num_tiles + num_xcds - 1) / num_xcds; - // When total_num_tiles cannot divide NUM_XCDS, some xcds will have + // When total_num_tiles cannot divide num_xcds, some xcds will have // ids_per_xcd ids, the other will have ids_per_xcd - 1 ids. // We calculate the number of xcds that have ids_per_xcd ids as tall_xcds - index_t tall_xcds = total_num_tiles % NUM_XCDS; - tall_xcds = (tall_xcds == 0) ? NUM_XCDS : tall_xcds; + index_t tall_xcds = total_num_tiles % num_xcds; + tall_xcds = (tall_xcds == 0) ? num_xcds : tall_xcds; // Compute current XCD and local id within the XCD - index_t xcd = block_1d_id % NUM_XCDS; - index_t local_id = block_1d_id / NUM_XCDS; + index_t xcd = block_1d_id % num_xcds; + index_t local_id = block_1d_id / num_xcds; // Calculate new id based on the new grouping if(xcd < tall_xcds) diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp new file mode 100644 index 000000000000..c2606db0e0ac --- /dev/null +++ b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp @@ -0,0 +1,30 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once +#include "ck_tile/core/arch/arch.hpp" +namespace ck_tile { + +template +struct NumXCD +{ + static constexpr int num_xcds = 1; +}; + +/**template +struct NumXCD> +{ + static constexpr int num_xcds = 6; +};**/ + +template +struct NumXCD< + CompilerTarget, + core::arch::enable_if_target_id_t> +{ + static constexpr int num_xcds = 8; +}; + +} // namespace ck_tile diff --git a/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index a2e688fc6c70..30f511633aa0 100644 --- a/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include "test_streamk_tile_partitioner_common.hpp" +#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp" TEST(StreamKTilePartitionerBaseConstructor, SKOnly) { @@ -409,48 +410,94 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings) TEST(StreamKTilePartitionerBaseRemapXCD, SmallArray) { - using Config = StreamKTilePartitionerBaseConfigSKOnly; + using CompilerTargetT = decltype(ck_tile::core::arch::get_compiler_target()); + if constexpr(CompilerTargetT::TARGET_ID == ck_tile::core::arch::amdgcn_target_id::GFX942) + { + constexpr int num_xcds = ck_tile::NumXCD::num_xcds; + using Config = StreamKTilePartitionerBaseConfigSKOnly; - ck_tile::StreamKTilePartitioner_v2 - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + ck_tile::StreamKTilePartitioner + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; - const std::vector initial_values = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const std::vector expected_values = { - 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const std::vector initial_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const std::vector expected_values = { + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - test_remap_xcd(initial_values, expected_values, tile_partitioner); + test_remap_xcd( + initial_values, expected_values, tile_partitioner, num_xcds); + } + else + { + GTEST_SKIP() << "Skipping: not gfx942 build"; + } } TEST(StreamKTilePartitionerBaseRemapXCD, MidArray) { - using Config = StreamKTilePartitionerBaseConfigSKOnly; + using CompilerTargetT = decltype(ck_tile::core::arch::get_compiler_target()); + if constexpr(CompilerTargetT::TARGET_ID == ck_tile::core::arch::amdgcn_target_id::GFX942) + { + constexpr int num_xcds = ck_tile::NumXCD::num_xcds; + using Config = StreamKTilePartitionerBaseConfigSKOnly; - ck_tile::StreamKTilePartitioner_v2 - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + ck_tile::StreamKTilePartitioner + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; - const std::vector initial_values = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, - 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126}; - const std::vector expected_values = { - 0, 16, 32, 48, 64, 80, 96, 112, 1, 17, 33, 49, 65, 81, 97, 113, 2, 18, 34, 50, - 66, 82, 98, 114, 3, 19, 35, 51, 67, 83, 99, 115, 4, 20, 36, 52, 68, 84, 100, 116, - 5, 21, 37, 53, 69, 85, 101, 117, 6, 22, 38, 54, 70, 86, 102, 118, 7, 23, 39, 55, - 71, 87, 103, 119, 8, 24, 40, 56, 72, 88, 104, 120, 9, 25, 41, 57, 73, 89, 105, 121, - 10, 26, 42, 58, 74, 90, 106, 122, 11, 27, 43, 59, 75, 91, 107, 123, 12, 28, 44, 60, - 76, 92, 108, 124, 13, 29, 45, 61, 77, 93, 109, 125, 14, 30, 46, 62, 78, 94, 110, 126, - 15, 31, 47, 63, 79, 95, 111}; - test_remap_xcd(initial_values, expected_values, tile_partitioner); + const std::vector initial_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126}; + const std::vector expected_values = { + 0, 16, 32, 48, 64, 80, 96, 112, 1, 17, 33, 49, 65, 81, 97, 113, + 2, 18, 34, 50, 66, 82, 98, 114, 3, 19, 35, 51, 67, 83, 99, 115, + 4, 20, 36, 52, 68, 84, 100, 116, 5, 21, 37, 53, 69, 85, 101, 117, + 6, 22, 38, 54, 70, 86, 102, 118, 7, 23, 39, 55, 71, 87, 103, 119, + 8, 24, 40, 56, 72, 88, 104, 120, 9, 25, 41, 57, 73, 89, 105, 121, + 10, 26, 42, 58, 74, 90, 106, 122, 11, 27, 43, 59, 75, 91, 107, 123, + 12, 28, 44, 60, 76, 92, 108, 124, 13, 29, 45, 61, 77, 93, 109, 125, + 14, 30, 46, 62, 78, 94, 110, 126, 15, 31, 47, 63, 79, 95, 111}; + test_remap_xcd( + initial_values, expected_values, tile_partitioner, num_xcds); + } + else + { + GTEST_SKIP() << "Skipping: not gfx942 build"; + } +} + +TEST(StreamKTilePartitionerBaseRemapXCD, UnevenXCD) +{ + using CompilerTargetT = decltype(ck_tile::core::arch::get_compiler_target()); + if constexpr(CompilerTargetT::TARGET_ID == ck_tile::core::arch::amdgcn_target_id::GFX942) + { + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitioner + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const std::vector initial_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const std::vector expected_values = { + 0, 4, 7, 10, 13, 1, 5, 8, 11, 14, 2, 6, 9, 12, 15, 3}; + + test_remap_xcd(initial_values, expected_values, tile_partitioner, 5); + } + else + { + GTEST_SKIP() << "Skipping: not gfx942 build"; + } } TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK) diff --git a/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index 214c123c4a77..5911270b0f38 100644 --- a/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -425,14 +425,15 @@ template void test_remap_xcd( const std::vector& initial_values, const std::vector& expected_values, - ck_tile::StreamKTilePartitioner_v2& - tile_partitioner) + ck_tile::StreamKTilePartitioner& + tile_partitioner, + const int num_xcds = 8) { - // ck_tile::index_t grid_size = tile_partitioner.grid_size().x; std::vector remapped_values(initial_values.size()); for(std::size_t i = 0; i < initial_values.size(); ++i) { - remapped_values[i] = tile_partitioner.RemapXCD(initial_values[i], initial_values.size()); + remapped_values[i] = + tile_partitioner.remap_xcd(initial_values[i], initial_values.size(), num_xcds); } EXPECT_EQ(remapped_values, expected_values); From 23361e2ac107d0cc8dff3c92678737de45f1da81 Mon Sep 17 00:00:00 2001 From: Astha Rai Date: Wed, 11 Mar 2026 07:24:56 +0000 Subject: [PATCH 3/3] Querying the device for number of XCDs This commit removes the use of an enum to map XCD values by architecture and switches to querying the number of XCDs from the device through the hip API. The unit tests have been changed to hardcode XCD values to simplify them. --- .../include/ck_tile/host/device_prop.hpp | 17 +++ .../streamk_gemm/streamk_gemm_kernel.hpp | 30 ++-- .../streamk_gemm_tile_partitioner.hpp | 5 +- .../streamk_gemm_tile_partitioner_impl.hpp | 8 +- .../kernel/streamk_gemm/streamk_gemm_xcd.hpp | 30 ---- .../test_streamk_tile_partitioner.cpp | 128 ++++++++---------- 6 files changed, 96 insertions(+), 122 deletions(-) delete mode 100644 projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp diff --git a/projects/composablekernel/include/ck_tile/host/device_prop.hpp b/projects/composablekernel/include/ck_tile/host/device_prop.hpp index 5f021d7bc5c1..ea4064ca9c11 100644 --- a/projects/composablekernel/include/ck_tile/host/device_prop.hpp +++ b/projects/composablekernel/include/ck_tile/host/device_prop.hpp @@ -84,6 +84,23 @@ inline size_t get_num_cus() return static_cast(props.multiProcessorCount); } +inline size_t get_num_xccs() +{ + int device = 0; + int num_xccs = 1; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return 0; + } + status = hipDeviceGetAttribute(&num_xccs, hipDeviceAttributeNumberOfXccs, device); + if(status == hipSuccess) + { + return num_xccs; + } + return 1; +} + } // namespace ck_tile #endif diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp index ccdfbdcc8abd..74b616b92a19 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp @@ -6,8 +6,8 @@ #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/host/concat.hpp" +#include "ck_tile/host/device_prop.hpp" #include "streamk_gemm_coherency.hpp" -#include "streamk_gemm_xcd.hpp" namespace ck_tile { @@ -120,7 +120,7 @@ struct StreamKKernel struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<> { - StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid) + StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid, int num_xccs_ = 1) : UniversalGemmKernelArgs{host_args.as_ptr, host_args.bs_ptr, host_args.ds_ptr, @@ -136,7 +136,8 @@ struct StreamKKernel // The workspace pointer is set to nullptr because we must first // instantiate the TilePartitioner to get the necessary size workspace_ptr{nullptr}, - tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}} + tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}, + num_xccs{num_xccs_} { } @@ -150,6 +151,11 @@ struct StreamKKernel * the C tensor. */ TilePartitioner tile_partitioner; + /** + * @brief An int for the number of xcds available on a given device for remapping the block + * indices to be contiguous. + */ + int num_xccs; }; using KernelArgs = StreamKKernelArgs; @@ -208,8 +214,8 @@ struct StreamKKernel int occupancy = Occupancy()) { const index_t grid = num_cu * occupancy; - - return StreamKKernelArgs{host_args, grid}; + const int num_xccs = get_num_xccs(); + return StreamKKernelArgs{host_args, grid, num_xccs}; } template @@ -759,12 +765,7 @@ struct StreamKKernel index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas(); bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas(); - using CompilerTargetT = decltype(core::arch::get_compiler_target()); - if constexpr(CompilerTargetT::TARGET_ID == core::arch::amdgcn_target_id::GFX942) - { - constexpr int num_xcds = ck_tile::NumXCD::num_xcds; - block_idx = kargs.tile_partitioner.remap_xcd(block_idx, grid_size, num_xcds); - } + block_idx = kargs.tile_partitioner.remap_xcd(block_idx, grid_size, kargs.num_xccs); // Check if at the data parallel section if(is_dp_ctas) { @@ -797,12 +798,7 @@ struct StreamKKernel index_t grid_size = kargs.tile_partitioner.grid_size().x; index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile(); - using CompilerTargetT = decltype(core::arch::get_compiler_target()); - if constexpr(CompilerTargetT::TARGET_ID == core::arch::amdgcn_target_id::GFX942) - { - constexpr int num_xcds = ck_tile::NumXCD::num_xcds; - block_idx = kargs.tile_partitioner.remap_xcd(block_idx, grid_size, num_xcds); - } + block_idx = kargs.tile_partitioner.remap_xcd(block_idx, grid_size, kargs.num_xccs); // Data-parallel section for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles(); tile_idx += kargs.tile_partitioner.get_grid()) diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index 944ae6e82135..cb079aaccee1 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -222,9 +222,8 @@ struct StreamKTilePartitionerBase * @param num_xcds number of XCDs * @return index_t The id after XCD remap */ - CK_TILE_HOST_DEVICE index_t remap_xcd(index_t block_1d_id, - index_t total_num_tiles, - index_t num_xcds = 8) const noexcept; + CK_TILE_HOST_DEVICE static index_t + remap_xcd(index_t block_1d_id, index_t total_num_tiles, index_t num_xcds = 8) noexcept; protected: index_t num_tiles_; diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp index c8ba06b4b415..e329c8978660 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp @@ -289,10 +289,14 @@ StreamKTilePartitionerBase::estimate_ * @return index_t The id after XCD remap */ template -CK_TILE_HOST_DEVICE index_t +CK_TILE_HOST_DEVICE /* static */ index_t StreamKTilePartitionerBase::remap_xcd( - index_t block_1d_id, index_t total_num_tiles, index_t num_xcds) const noexcept + index_t block_1d_id, index_t total_num_tiles, index_t num_xcds) noexcept { + if(num_xcds == 1) + { + return block_1d_id; + } // Number of ids per XCD in the new arrangement index_t ids_per_xcd = (total_num_tiles + num_xcds - 1) / num_xcds; diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp deleted file mode 100644 index c2606db0e0ac..000000000000 --- a/projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once -#include "ck_tile/core/arch/arch.hpp" -namespace ck_tile { - -template -struct NumXCD -{ - static constexpr int num_xcds = 1; -}; - -/**template -struct NumXCD> -{ - static constexpr int num_xcds = 6; -};**/ - -template -struct NumXCD< - CompilerTarget, - core::arch::enable_if_target_id_t> -{ - static constexpr int num_xcds = 8; -}; - -} // namespace ck_tile diff --git a/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index d2e368d596b5..b541182fe60d 100644 --- a/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include "test_streamk_tile_partitioner_common.hpp" -#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_xcd.hpp" +#include "ck_tile/host/device_prop.hpp" TEST(StreamKTilePartitionerBaseConstructor, SKOnly) { @@ -410,94 +410,82 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings) TEST(StreamKTilePartitionerBaseRemapXCD, SmallArray) { - using CompilerTargetT = decltype(ck_tile::core::arch::get_compiler_target()); - if constexpr(CompilerTargetT::TARGET_ID == ck_tile::core::arch::amdgcn_target_id::GFX942) - { - constexpr int num_xcds = ck_tile::NumXCD::num_xcds; - using Config = StreamKTilePartitionerBaseConfigSKOnly; + int num_xcds = 8; + using Config = StreamKTilePartitionerBaseConfigSKOnly; - ck_tile::StreamKTilePartitioner + ck_tile:: + StreamKTilePartitioner tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; - const std::vector initial_values = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const std::vector expected_values = { - 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const std::vector initial_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const std::vector expected_values = { + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - test_remap_xcd( - initial_values, expected_values, tile_partitioner, num_xcds); - } - else - { - GTEST_SKIP() << "Skipping: not gfx942 build"; - } + test_remap_xcd(initial_values, expected_values, tile_partitioner, num_xcds); } TEST(StreamKTilePartitionerBaseRemapXCD, MidArray) { - using CompilerTargetT = decltype(ck_tile::core::arch::get_compiler_target()); - if constexpr(CompilerTargetT::TARGET_ID == ck_tile::core::arch::amdgcn_target_id::GFX942) - { - constexpr int num_xcds = ck_tile::NumXCD::num_xcds; - using Config = StreamKTilePartitionerBaseConfigSKOnly; + int num_xcds = 8; + using Config = StreamKTilePartitionerBaseConfigSKOnly; - ck_tile::StreamKTilePartitioner + ck_tile:: + StreamKTilePartitioner tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; - const std::vector initial_values = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, - 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126}; - const std::vector expected_values = { - 0, 16, 32, 48, 64, 80, 96, 112, 1, 17, 33, 49, 65, 81, 97, 113, - 2, 18, 34, 50, 66, 82, 98, 114, 3, 19, 35, 51, 67, 83, 99, 115, - 4, 20, 36, 52, 68, 84, 100, 116, 5, 21, 37, 53, 69, 85, 101, 117, - 6, 22, 38, 54, 70, 86, 102, 118, 7, 23, 39, 55, 71, 87, 103, 119, - 8, 24, 40, 56, 72, 88, 104, 120, 9, 25, 41, 57, 73, 89, 105, 121, - 10, 26, 42, 58, 74, 90, 106, 122, 11, 27, 43, 59, 75, 91, 107, 123, - 12, 28, 44, 60, 76, 92, 108, 124, 13, 29, 45, 61, 77, 93, 109, 125, - 14, 30, 46, 62, 78, 94, 110, 126, 15, 31, 47, 63, 79, 95, 111}; - test_remap_xcd( - initial_values, expected_values, tile_partitioner, num_xcds); - } - else - { - GTEST_SKIP() << "Skipping: not gfx942 build"; - } + const std::vector initial_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126}; + const std::vector expected_values = { + 0, 16, 32, 48, 64, 80, 96, 112, 1, 17, 33, 49, 65, 81, 97, 113, 2, 18, 34, 50, + 66, 82, 98, 114, 3, 19, 35, 51, 67, 83, 99, 115, 4, 20, 36, 52, 68, 84, 100, 116, + 5, 21, 37, 53, 69, 85, 101, 117, 6, 22, 38, 54, 70, 86, 102, 118, 7, 23, 39, 55, + 71, 87, 103, 119, 8, 24, 40, 56, 72, 88, 104, 120, 9, 25, 41, 57, 73, 89, 105, 121, + 10, 26, 42, 58, 74, 90, 106, 122, 11, 27, 43, 59, 75, 91, 107, 123, 12, 28, 44, 60, + 76, 92, 108, 124, 13, 29, 45, 61, 77, 93, 109, 125, 14, 30, 46, 62, 78, 94, 110, 126, + 15, 31, 47, 63, 79, 95, 111}; + test_remap_xcd(initial_values, expected_values, tile_partitioner, num_xcds); } TEST(StreamKTilePartitionerBaseRemapXCD, UnevenXCD) { - using CompilerTargetT = decltype(ck_tile::core::arch::get_compiler_target()); - if constexpr(CompilerTargetT::TARGET_ID == ck_tile::core::arch::amdgcn_target_id::GFX942) - { - using Config = StreamKTilePartitionerBaseConfigSKOnly; + constexpr int num_xcds = 5; + using Config = StreamKTilePartitionerBaseConfigSKOnly; - ck_tile::StreamKTilePartitioner + ck_tile:: + StreamKTilePartitioner tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; - const std::vector initial_values = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const std::vector expected_values = { - 0, 4, 7, 10, 13, 1, 5, 8, 11, 14, 2, 6, 9, 12, 15, 3}; + const std::vector initial_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const std::vector expected_values = { + 0, 4, 7, 10, 13, 1, 5, 8, 11, 14, 2, 6, 9, 12, 15, 3}; - test_remap_xcd(initial_values, expected_values, tile_partitioner, 5); - } - else - { - GTEST_SKIP() << "Skipping: not gfx942 build"; - } + test_remap_xcd(initial_values, expected_values, tile_partitioner, num_xcds); +} + +TEST(StreamKTilePartitionerBaseRemapXCD, SingleXCD) +{ + constexpr int num_xcds = 1; + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile:: + StreamKTilePartitioner + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const std::vector initial_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const std::vector expected_values = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + test_remap_xcd(initial_values, expected_values, tile_partitioner, num_xcds); } TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK)