Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ struct StreamKKernel

struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t max_active_wgs)
: UniversalGemmKernelArgs{host_args.as_ptr,
host_args.bs_ptr,
host_args.ds_ptr,
Expand All @@ -135,7 +135,8 @@ struct StreamKKernel
// The workspace pointer is set to nullptr because we must first
// instantiate the TilePartitioner to get the necessary size
workspace_ptr{nullptr},
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
tile_partitioner{
TilePartitioner{host_args.M, host_args.N, host_args.K, max_active_wgs}}

{
}
Expand Down Expand Up @@ -206,9 +207,9 @@ struct StreamKKernel
int num_cu = NumCU(),
int occupancy = Occupancy())
{
const index_t grid = num_cu * occupancy;
const index_t max_active_wgs = num_cu * occupancy;

return StreamKKernelArgs{host_args, grid};
return StreamKKernelArgs{host_args, max_active_wgs};
}

template <bool UseDefaultScheduler = true>
Expand Down Expand Up @@ -790,7 +791,7 @@ struct StreamKKernel

// Data-parallel section
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
tile_idx += kargs.tile_partitioner.get_grid())
tile_idx += kargs.tile_partitioner.get_max_active_wgs())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct StreamKTilePartitionerBase
? memory_operation_enum::atomic_add
: memory_operation_enum::set;

StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t max_number_wgs);

/**
* @brief Calculates the total space needed for the partials buffer.
Expand Down Expand Up @@ -156,7 +156,7 @@ struct StreamKTilePartitionerBase
* @brief Returns the maximum number of active workgroups; this is assumed to be number of CUs *
* occupancy.
*/
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
CK_TILE_HOST_DEVICE index_t get_max_active_wgs() const noexcept;

/**
* @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP)
Expand Down Expand Up @@ -215,7 +215,7 @@ struct StreamKTilePartitionerBase

protected:
index_t num_tiles_;
index_t grid_;
index_t max_active_wgs_;
index_t dp_tiles_;

private:
Expand Down Expand Up @@ -270,7 +270,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
StreamKTilePartitioner(ck_tile::index_t m,
ck_tile::index_t n,
ck_tile::index_t k,
ck_tile::index_t grid);
ck_tile::index_t max_active_wgs);

public:
static constexpr bool PERSISTENT = true;
Expand All @@ -290,7 +290,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>

/**
* @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly
* divisible by `grid_`.
* divisible by `max_active_wgs_`.
*/
CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept;

Expand All @@ -317,7 +317,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>
StreamKTilePartitioner(ck_tile::index_t m,
ck_tile::index_t n,
ck_tile::index_t k,
ck_tile::index_t grid);
ck_tile::index_t max_number_wgs);

public:
static constexpr bool PERSISTENT = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,25 @@ namespace ck_tile {

template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
index_t m, index_t n, index_t k, index_t grid)
: grid_{grid}, n_{n}
index_t m, index_t n, index_t k, index_t max_active_wgs)
: max_active_wgs_{max_active_wgs}, n_{n}
{
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);

bool big_enough = num_tiles_ > grid_;
index_t remainder_tiles = num_tiles_ % grid_;
bool big_enough = num_tiles_ > max_active_wgs_;
index_t remainder_tiles = num_tiles_ % max_active_wgs_;

if(remainder_tiles)
{
sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
sk_tiles_ = min(num_tiles_, sk_tiles_);
sk_ctas_ = grid_;
sk_tiles_ = big_enough ? full_tiles_ * max_active_wgs_ + (num_tiles_ % max_active_wgs_)
: num_tiles_;
sk_tiles_ = min(num_tiles_, sk_tiles_);
sk_ctas_ = max_active_wgs_;
total_sk_iters_ = sk_tiles_ * iters_per_tile_;

// If there still isn't enough work to saturate all CUs, then just revert to DP only.
if(total_sk_iters_ < grid_)
if(total_sk_iters_ < max_active_wgs_)
{
sk_tiles_ = 0;
sk_ctas_ = 0;
Expand Down Expand Up @@ -175,9 +176,10 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_num_t

template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_grid() const noexcept
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_max_active_wgs()
const noexcept
{
return grid_;
return max_active_wgs_;
}

template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
Expand Down Expand Up @@ -287,11 +289,11 @@ struct StreamKTilePartitioner;
// child class for Persistent Tile Partitioner
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::StreamKTilePartitioner(
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
{ // inherit from base constructor
dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
dp_tiles_per_cta_ = this->dp_tiles_ / this->max_active_wgs_;
extra_dp_tiles_ = this->dp_tiles_ % this->max_active_wgs_;
}

template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
Expand All @@ -301,7 +303,7 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::grid_si
{
if(extra_dp_tiles_ == 0)
{
return dim3(this->grid_, 1, 1);
return dim3(this->max_active_wgs_, 1, 1);
}
else
{
Expand All @@ -328,8 +330,8 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::get_ext
// child class for Non-Persistent Tile Partitioner
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::StreamKTilePartitioner(
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
{ // inherit from base constructor
dp_ctas_ = this->dp_tiles_;
dp_start_block_idx_ = 0;
Expand Down
Loading
Loading