Skip to content

Commit dc14f61

Browse files
committed
Renaming the get_grid function and grid_ parameter to max_active_wgs to avoid confusion with grid_size
1 parent 36ca523 commit dc14f61

File tree

4 files changed

+21
-19
lines changed

4 files changed

+21
-19
lines changed

projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ struct StreamKKernel
790790

791791
// Data-parallel section
792792
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
793-
tile_idx += kargs.tile_partitioner.get_grid())
793+
tile_idx += kargs.tile_partitioner.get_max_active_wgs())
794794
{
795795
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
796796
block_sync_lds();

projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ struct StreamKTilePartitionerBase
156156
* @brief Returns the maximum number of active workgroups; this is assumed to be number of CUs *
157157
* occupancy.
158158
*/
159-
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
159+
CK_TILE_HOST_DEVICE index_t get_max_active_wgs() const noexcept;
160160

161161
/**
162162
* @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP)
@@ -215,7 +215,7 @@ struct StreamKTilePartitionerBase
215215

216216
protected:
217217
index_t num_tiles_;
218-
index_t grid_;
218+
index_t max_active_wgs_;
219219
index_t dp_tiles_;
220220

221221
private:
@@ -290,7 +290,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
290290

291291
/**
292292
* @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly
293-
* divisible by `grid_`.
293+
* divisible by `max_active_wgs_`.
294294
*/
295295
CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept;
296296

projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,24 @@ namespace ck_tile {
88
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
99
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
1010
index_t m, index_t n, index_t k, index_t grid)
11-
: grid_{grid}, n_{n}
11+
: max_active_wgs_{grid}, n_{n}
1212
{
1313
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
1414
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);
1515

16-
bool big_enough = num_tiles_ > grid_;
17-
index_t remainder_tiles = num_tiles_ % grid_;
16+
bool big_enough = num_tiles_ > max_active_wgs_;
17+
index_t remainder_tiles = num_tiles_ % max_active_wgs_;
1818

1919
if(remainder_tiles)
2020
{
21-
sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
22-
sk_tiles_ = min(num_tiles_, sk_tiles_);
23-
sk_ctas_ = grid_;
21+
sk_tiles_ = big_enough ? full_tiles_ * max_active_wgs_ + (num_tiles_ % max_active_wgs_)
22+
: num_tiles_;
23+
sk_tiles_ = min(num_tiles_, sk_tiles_);
24+
sk_ctas_ = max_active_wgs_;
2425
total_sk_iters_ = sk_tiles_ * iters_per_tile_;
2526

2627
// If there still isn't enough work to saturate all CUs, then just revert to DP only.
27-
if(total_sk_iters_ < grid_)
28+
if(total_sk_iters_ < max_active_wgs_)
2829
{
2930
sk_tiles_ = 0;
3031
sk_ctas_ = 0;
@@ -175,9 +176,10 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_num_t
175176

176177
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
177178
CK_TILE_HOST_DEVICE index_t
178-
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_grid() const noexcept
179+
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_max_active_wgs()
180+
const noexcept
179181
{
180-
return grid_;
182+
return max_active_wgs_;
181183
}
182184

183185
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
@@ -290,8 +292,8 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::StreamK
290292
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
291293
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
292294
{ // inherit from base constructor
293-
dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
294-
extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
295+
dp_tiles_per_cta_ = this->dp_tiles_ / this->max_active_wgs_;
296+
extra_dp_tiles_ = this->dp_tiles_ % this->max_active_wgs_;
295297
}
296298

297299
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
@@ -301,7 +303,7 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::grid_si
301303
{
302304
if(extra_dp_tiles_ == 0)
303305
{
304-
return dim3(this->grid_, 1, 1);
306+
return dim3(this->max_active_wgs_, 1, 1);
305307
}
306308
else
307309
{

projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ void validate_streamk_base_constructor(
183183
EXPECT_EQ(tile_partitioner.get_iters_per_tile(), expected_values.iters_per_tile_);
184184
EXPECT_EQ(tile_partitioner.get_total_dp_iters(), expected_values.total_dp_iters_);
185185
EXPECT_EQ(tile_partitioner.get_num_tiles(), expected_values.num_tiles_);
186-
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
186+
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.grid_);
187187
EXPECT_EQ(tile_partitioner.get_n(), expected_values.n_);
188188
}
189189

@@ -446,7 +446,7 @@ void validate_streamk_persistent(
446446
{
447447
EXPECT_EQ(tile_partitioner.get_dp_tiles_per_cta(), expected_values.dp_tiles_per_cta_);
448448
EXPECT_EQ(tile_partitioner.get_extra_dp_tiles(), expected_values.extra_dp_tiles_);
449-
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
449+
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.grid_);
450450
}
451451

452452
// Non-Persistent
@@ -459,5 +459,5 @@ void validate_streamk_nonpersistent(
459459
EXPECT_EQ(tile_partitioner.get_dp_ctas(), expected_values.dp_ctas_);
460460
EXPECT_EQ(tile_partitioner.get_dp_start_block_idx(), expected_values.dp_start_block_idx_);
461461
EXPECT_EQ(tile_partitioner.get_sk_start_block_idx(), expected_values.sk_start_block_idx_);
462-
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
462+
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.grid_);
463463
}

0 commit comments

Comments
 (0)