@@ -8,23 +8,24 @@ namespace ck_tile {
88template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
99StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
1010 index_t m, index_t n, index_t k, index_t grid)
11- : grid_ {grid}, n_{n}
11+ : max_active_wgs_ {grid}, n_{n}
1212{
1313 iters_per_tile_ = integer_divide_ceil (k, KPerBlock);
1414 num_tiles_ = integer_divide_ceil (m, MPerBlock) * integer_divide_ceil (n_, NPerBlock);
1515
16- bool big_enough = num_tiles_ > grid_ ;
17- index_t remainder_tiles = num_tiles_ % grid_ ;
16+ bool big_enough = num_tiles_ > max_active_wgs_ ;
17+ index_t remainder_tiles = num_tiles_ % max_active_wgs_ ;
1818
1919 if (remainder_tiles)
2020 {
21- sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
22- sk_tiles_ = min (num_tiles_, sk_tiles_);
23- sk_ctas_ = grid_;
21+ sk_tiles_ = big_enough ? full_tiles_ * max_active_wgs_ + (num_tiles_ % max_active_wgs_)
22+ : num_tiles_;
23+ sk_tiles_ = min (num_tiles_, sk_tiles_);
24+ sk_ctas_ = max_active_wgs_;
2425 total_sk_iters_ = sk_tiles_ * iters_per_tile_;
2526
2627 // If there still isn't enough work to saturate all CUs, then just revert to DP only.
27- if (total_sk_iters_ < grid_ )
28+ if (total_sk_iters_ < max_active_wgs_ )
2829 {
2930 sk_tiles_ = 0 ;
3031 sk_ctas_ = 0 ;
@@ -175,9 +176,10 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_num_t
175176
// Accessor renamed by this hunk from get_grid() to get_max_active_wgs(); the body
// returns the max_active_wgs_ member (previously grid_) unchanged. In the base
// constructor shown in this diff, that member is initialized from the `grid` argument.
176177template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
177178CK_TILE_HOST_DEVICE index_t
178- StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_grid() const noexcept
179+ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_max_active_wgs()
180+ const noexcept
179181{
180- return grid_ ;
182+ return max_active_wgs_ ;
181183}
182184
183185template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
@@ -290,8 +292,8 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::StreamK
290292 ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
291293 : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
292294{ // inherit from base constructor
293- dp_tiles_per_cta_ = this ->dp_tiles_ / this ->grid_ ;
294- extra_dp_tiles_ = this ->dp_tiles_ % this ->grid_ ;
295+ dp_tiles_per_cta_ = this ->dp_tiles_ / this ->max_active_wgs_ ;
296+ extra_dp_tiles_ = this ->dp_tiles_ % this ->max_active_wgs_ ;
295297}
296298
297299template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
@@ -301,7 +303,7 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::grid_si
301303{
302304 if (extra_dp_tiles_ == 0 )
303305 {
304- return dim3 (this ->grid_ , 1 , 1 );
306+ return dim3 (this ->max_active_wgs_ , 1 , 1 );
305307 }
306308 else
307309 {
0 commit comments