Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions projects/composablekernel/include/ck_tile/host/device_prop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,23 @@ inline size_t get_num_cus()
return static_cast<size_t>(props.multiProcessorCount);
}

inline size_t get_num_xccs()
{
int device = 0;
int num_xccs = 1;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return 0;
}
status = hipDeviceGetAttribute(&num_xccs, hipDeviceAttributeNumberOfXccs, device);
if(status == hipSuccess)
{
return num_xccs;
}
return 1;
}

} // namespace ck_tile

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "streamk_gemm_coherency.hpp"

namespace ck_tile {
Expand Down Expand Up @@ -119,7 +120,7 @@ struct StreamKKernel

struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid, int num_xccs_ = 1)
: UniversalGemmKernelArgs{host_args.as_ptr,
host_args.bs_ptr,
host_args.ds_ptr,
Expand All @@ -135,7 +136,8 @@ struct StreamKKernel
// The workspace pointer is set to nullptr because we must first
// instantiate the TilePartitioner to get the necessary size
workspace_ptr{nullptr},
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}},
num_xccs{num_xccs_}

{
}
Expand All @@ -149,6 +151,11 @@ struct StreamKKernel
* the C tensor.
*/
TilePartitioner tile_partitioner;
/**
* @brief An int for the number of xcds available on a given device for remapping the block
* indices to be contiguous.
*/
int num_xccs;
};

using KernelArgs = StreamKKernelArgs;
Expand Down Expand Up @@ -207,8 +214,8 @@ struct StreamKKernel
int occupancy = Occupancy())
{
const index_t grid = num_cu * occupancy;

return StreamKKernelArgs{host_args, grid};
const int num_xccs = get_num_xccs();
return StreamKKernelArgs{host_args, grid, num_xccs};
}

template <bool UseDefaultScheduler = true>
Expand Down Expand Up @@ -753,10 +760,12 @@ struct StreamKKernel
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];

index_t block_idx = ck_tile::get_block_1d_id();
index_t grid_size = kargs.tile_partitioner.grid_size().x;
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();

block_idx = kargs.tile_partitioner.remap_xcd(block_idx, grid_size, kargs.num_xccs);
// Check if at the data parallel section
if(is_dp_ctas)
{
Expand Down Expand Up @@ -786,8 +795,10 @@ struct StreamKKernel
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];

index_t block_idx = ck_tile::get_block_1d_id();
index_t grid_size = kargs.tile_partitioner.grid_size().x;
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();

block_idx = kargs.tile_partitioner.remap_xcd(block_idx, grid_size, kargs.num_xccs);
// Data-parallel section
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
tile_idx += kargs.tile_partitioner.get_grid())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,18 @@ struct StreamKTilePartitionerBase
*/
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept;

/**
* @brief XCDs access ids in round robin format, this function remaps the 1D ids to continguous
* XCD segments
*
* @param block_1d_id grid 1D id
* @param total_num_tiles size of the 1D grid
* @param num_xcds number of XCDs
* @return index_t The id after XCD remap
*/
CK_TILE_HOST_DEVICE static index_t
remap_xcd(index_t block_1d_id, index_t total_num_tiles, index_t num_xcds = 8) noexcept;

protected:
index_t num_tiles_;
index_t grid_;
Expand Down Expand Up @@ -281,7 +293,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
*
* @return dim_3 The launching grid size for the kernel.
*/
CK_TILE_HOST auto grid_size() const noexcept -> dim3;
CK_TILE_HOST_DEVICE auto grid_size() const noexcept -> dim3;

/**
* @brief Returns the total number of DP tiles per workgroup.
Expand Down Expand Up @@ -328,7 +340,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>
*
* @return dim_3 The launching grid size for the kernel.
*/
CK_TILE_HOST auto grid_size() const noexcept -> dim3;
CK_TILE_HOST_DEVICE auto grid_size() const noexcept -> dim3;

/**
* @brief Returns the total number of DP workgroups.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,59 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_
return std::max(num_wgs_per_tile, 1);
}

/**
* @brief XCDs access ids in round robin format, this function remaps the 1D ids to continguous
* XCD segments
*
* @param block_1d_id grid 1D id
* @param total_num_tiles size of the 1D grid
* @param num_xcds number of XCDs
* @return index_t The id after XCD remap
*/
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST_DEVICE /* static */ index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::remap_xcd(
index_t block_1d_id, index_t total_num_tiles, index_t num_xcds) noexcept
{
if(num_xcds == 1)
{
return block_1d_id;
}
// Number of ids per XCD in the new arrangement
index_t ids_per_xcd = (total_num_tiles + num_xcds - 1) / num_xcds;

// When total_num_tiles cannot divide num_xcds, some xcds will have
// ids_per_xcd ids, the other will have ids_per_xcd - 1 ids.
// We calculate the number of xcds that have ids_per_xcd ids as tall_xcds
index_t tall_xcds = total_num_tiles % num_xcds;
tall_xcds = (tall_xcds == 0) ? num_xcds : tall_xcds;

// Compute current XCD and local id within the XCD
index_t xcd = block_1d_id % num_xcds;
index_t local_id = block_1d_id / num_xcds;

// Calculate new id based on the new grouping
if(xcd < tall_xcds)
{
block_1d_id = xcd * ids_per_xcd + local_id;
}
else
{
block_1d_id = tall_xcds * ids_per_xcd + (xcd - tall_xcds) * (ids_per_xcd - 1) + local_id;
}

/**
* original ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
* XCD 0 gets: [0, 8], XCD 1 gets: [1, 9], ...
*
* post-remap ids: [0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15]
* XCD 0 gets: [0, 1], XCD 1 gets: [2, 3], ...
*
* after remap the ids are continguous on each XCD
*/
return block_1d_id;
}

template <typename BlockGemmShapeType,
StreamKReductionStrategy ReductionStrategyType,
bool Persistent>
Expand All @@ -295,7 +348,7 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::StreamK
}

template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST auto
CK_TILE_HOST_DEVICE auto
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::grid_size() const noexcept
-> dim3
{
Expand Down Expand Up @@ -337,7 +390,7 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::Stream
}

template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST auto
CK_TILE_HOST_DEVICE auto
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::grid_size() const noexcept
-> dim3
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT

#include "test_streamk_tile_partitioner_common.hpp"
#include "ck_tile/host/device_prop.hpp"

TEST(StreamKTilePartitionerBaseConstructor, SKOnly)
{
Expand Down Expand Up @@ -407,6 +408,86 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
}
}

TEST(StreamKTilePartitionerBaseRemapXCD, SmallArray)
{
int num_xcds = 8;
using Config = StreamKTilePartitionerBaseConfigSKOnly;

ck_tile::
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};

const std::vector<ck_tile::index_t> initial_values = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
const std::vector<ck_tile::index_t> expected_values = {
0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};

test_remap_xcd<Config::GemmShape>(initial_values, expected_values, tile_partitioner, num_xcds);
}

TEST(StreamKTilePartitionerBaseRemapXCD, MidArray)
{
int num_xcds = 8;
using Config = StreamKTilePartitionerBaseConfigSKOnly;

ck_tile::
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};

const std::vector<ck_tile::index_t> initial_values = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126};
const std::vector<ck_tile::index_t> expected_values = {
0, 16, 32, 48, 64, 80, 96, 112, 1, 17, 33, 49, 65, 81, 97, 113, 2, 18, 34, 50,
66, 82, 98, 114, 3, 19, 35, 51, 67, 83, 99, 115, 4, 20, 36, 52, 68, 84, 100, 116,
5, 21, 37, 53, 69, 85, 101, 117, 6, 22, 38, 54, 70, 86, 102, 118, 7, 23, 39, 55,
71, 87, 103, 119, 8, 24, 40, 56, 72, 88, 104, 120, 9, 25, 41, 57, 73, 89, 105, 121,
10, 26, 42, 58, 74, 90, 106, 122, 11, 27, 43, 59, 75, 91, 107, 123, 12, 28, 44, 60,
76, 92, 108, 124, 13, 29, 45, 61, 77, 93, 109, 125, 14, 30, 46, 62, 78, 94, 110, 126,
15, 31, 47, 63, 79, 95, 111};
test_remap_xcd<Config::GemmShape>(initial_values, expected_values, tile_partitioner, num_xcds);
}

TEST(StreamKTilePartitionerBaseRemapXCD, UnevenXCD)
{
constexpr int num_xcds = 5;
using Config = StreamKTilePartitionerBaseConfigSKOnly;

ck_tile::
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};

const std::vector<ck_tile::index_t> initial_values = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
const std::vector<ck_tile::index_t> expected_values = {
0, 4, 7, 10, 13, 1, 5, 8, 11, 14, 2, 6, 9, 12, 15, 3};

test_remap_xcd<Config::GemmShape>(initial_values, expected_values, tile_partitioner, num_xcds);
}

TEST(StreamKTilePartitionerBaseRemapXCD, SingleXCD)
{
constexpr int num_xcds = 1;
using Config = StreamKTilePartitionerBaseConfigSKOnly;

ck_tile::
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};

const std::vector<ck_tile::index_t> initial_values = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
const std::vector<ck_tile::index_t> expected_values = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};

test_remap_xcd<Config::GemmShape>(initial_values, expected_values, tile_partitioner, num_xcds);
}

TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK)
{
/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,24 @@ void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start,
EXPECT_EQ(tile_local_cta_idx, expected_tile_local_cta_idx);
}

template <typename GemmShape>
void test_remap_xcd(
const std::vector<ck_tile::index_t>& initial_values,
const std::vector<ck_tile::index_t>& expected_values,
ck_tile::StreamKTilePartitioner<GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>&
tile_partitioner,
const int num_xcds = 8)
{
std::vector<ck_tile::index_t> remapped_values(initial_values.size());
for(std::size_t i = 0; i < initial_values.size(); ++i)
{
remapped_values[i] =
tile_partitioner.remap_xcd(initial_values[i], initial_values.size(), num_xcds);
}

EXPECT_EQ(remapped_values, expected_values);
}

// Configs for TilePartitioner Child structs
struct StreamKTilePartitionerV2PersistentExpected
{
Expand Down
Loading