Skip to content

Commit 229035a

Browse files
committed
Rename the `get_grid` function and the `grid_` parameter to `max_active_wgs` to avoid confusion with `grid_size`.
1 parent 208d063 commit 229035a

File tree

4 files changed

+88
-87
lines changed

4 files changed

+88
-87
lines changed

projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ struct StreamKKernel
119119

120120
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
121121
{
122-
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
122+
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t max_active_wgs)
123123
: UniversalGemmKernelArgs{host_args.as_ptr,
124124
host_args.bs_ptr,
125125
host_args.ds_ptr,
@@ -135,7 +135,8 @@ struct StreamKKernel
135135
// The workspace pointer is set to nullptr because we must first
136136
// instantiate the TilePartitioner to get the necessary size
137137
workspace_ptr{nullptr},
138-
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
138+
tile_partitioner{
139+
TilePartitioner{host_args.M, host_args.N, host_args.K, max_active_wgs}}
139140

140141
{
141142
}
@@ -206,9 +207,9 @@ struct StreamKKernel
206207
int num_cu = NumCU(),
207208
int occupancy = Occupancy())
208209
{
209-
const index_t grid = num_cu * occupancy;
210+
const index_t max_active_wgs = num_cu * occupancy;
210211

211-
return StreamKKernelArgs{host_args, grid};
212+
return StreamKKernelArgs{host_args, max_active_wgs};
212213
}
213214

214215
template <bool UseDefaultScheduler = true>

projects/composablekernel/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ namespace ck_tile {
77

88
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
99
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
10-
index_t m, index_t n, index_t k, index_t grid)
11-
: max_active_wgs_{grid}, n_{n}
10+
index_t m, index_t n, index_t k, index_t max_active_wgs)
11+
: max_active_wgs_{max_active_wgs}, n_{n}
1212
{
1313
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
1414
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);
@@ -289,8 +289,8 @@ struct StreamKTilePartitioner;
289289
// child class for Persistent Tile Partitioner
290290
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
291291
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::StreamKTilePartitioner(
292-
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
293-
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
292+
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
293+
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
294294
{ // inherit from base constructor
295295
dp_tiles_per_cta_ = this->dp_tiles_ / this->max_active_wgs_;
296296
extra_dp_tiles_ = this->dp_tiles_ % this->max_active_wgs_;
@@ -330,8 +330,8 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::get_ext
330330
// child class for Non-Persistent Tile Partitioner
331331
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
332332
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::StreamKTilePartitioner(
333-
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
334-
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
333+
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
334+
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
335335
{ // inherit from base constructor
336336
dp_ctas_ = this->dp_tiles_;
337337
dp_start_block_idx_ = 0;

projects/composablekernel/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ TEST(StreamKTilePartitionerBaseConstructor, SKOnly)
88
using Config = StreamKTilePartitionerBaseConfigSKOnly;
99

1010
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
11-
Config::M, Config::N, Config::K, Config::GRID};
11+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
1212

1313
StreamKTilePartitionerBaseExpected expected_values{
14-
2, 0, 3, 4, 1, 2, 1, 0, 2, Config::GRID, Config::N};
14+
2, 0, 3, 4, 1, 2, 1, 0, 2, Config::MAX_ACTIVE_WGS, Config::N};
1515
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
1616
}
1717

@@ -20,10 +20,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DPOnly)
2020
using Config = StreamKTilePartitionerBaseConfigDPOnly;
2121

2222
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
23-
Config::M, Config::N, Config::K, Config::GRID};
23+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
2424

2525
StreamKTilePartitionerBaseExpected expected_values{
26-
0, 6, 0, 0, 0, 2, 0, 12, 6, Config::GRID, Config::N};
26+
0, 6, 0, 0, 0, 2, 0, 12, 6, Config::MAX_ACTIVE_WGS, Config::N};
2727
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
2828
}
2929

@@ -32,10 +32,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DP2TileSK)
3232
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
3333

3434
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
35-
Config::M, Config::N, Config::K, Config::GRID};
35+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
3636

3737
StreamKTilePartitionerBaseExpected expected_values{
38-
4, 3, 3, 8, 2, 2, 2, 6, 7, Config::GRID, Config::N};
38+
4, 3, 3, 8, 2, 2, 2, 6, 7, Config::MAX_ACTIVE_WGS, Config::N};
3939
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
4040
}
4141

@@ -44,10 +44,10 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase)
4444
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
4545

4646
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
47-
Config::M, Config::N, Config::K, Config::GRID};
47+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
4848

4949
StreamKTilePartitionerBaseExpected expected_values{
50-
0, 1, 0, 0, 0, 2, 0, 2, 1, Config::GRID, Config::N};
50+
0, 1, 0, 0, 0, 2, 0, 2, 1, Config::MAX_ACTIVE_WGS, Config::N};
5151
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
5252
}
5353

@@ -57,7 +57,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes)
5757

5858
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
5959
ck_tile::StreamKReductionStrategy::Linear>
60-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
60+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
6161

6262
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
6363
}
@@ -68,7 +68,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes)
6868

6969
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
7070
ck_tile::StreamKReductionStrategy::Linear>
71-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
71+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
7272

7373
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
7474
}
@@ -79,7 +79,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes)
7979

8080
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
8181
ck_tile::StreamKReductionStrategy::Linear>
82-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
82+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
8383

8484
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256);
8585
}
@@ -89,7 +89,7 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy)
8989
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
9090

9191
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
92-
Config::M, Config::N, Config::K, Config::GRID};
92+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
9393

9494
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), 0);
9595
}
@@ -100,12 +100,12 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
100100

101101
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
102102
ck_tile::StreamKReductionStrategy::Linear>
103-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
103+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
104104

105105
ck_tile::index_t expected_partials_size =
106-
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID;
107-
// Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of
108-
// the flags array is 128B-aligned.
106+
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::MAX_ACTIVE_WGS;
107+
// Since MAX_ACTIVE_WGS is 3, the final padded flags array must be 128B to ensure the total byte
108+
// size of the flags array is 128B-aligned.
109109
ck_tile::index_t expected_flags_size = 128;
110110

111111
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)),
@@ -117,7 +117,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLower
117117
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
118118

119119
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
120-
Config::M, Config::N, Config::K, Config::GRID};
120+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
121121

122122
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
123123
}
@@ -127,7 +127,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqual
127127
using Config = StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile;
128128

129129
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
130-
Config::M, Config::N, Config::K, Config::GRID};
130+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
131131

132132
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
133133
}
@@ -232,7 +232,7 @@ TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries)
232232

233233
// Test parameters
234234
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
235-
Config::M, Config::N, Config::K, Config::GRID};
235+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
236236
ck_tile::DeviceMem tile_iter_start_dev(sizeof(ck_tile::index_t));
237237
ck_tile::DeviceMem tile_iter_end_dev(sizeof(ck_tile::index_t));
238238
ck_tile::index_t tile_idx = 1;
@@ -267,7 +267,7 @@ TEST(StreamKTilePartitionerBaseGetTileIndex, GetTileIndex)
267267

268268
// Test parameters
269269
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
270-
Config::M, Config::N, Config::K, Config::GRID};
270+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
271271
ck_tile::DeviceMem tile_idx_dev(sizeof(ck_tile::index_t));
272272
ck_tile::index_t iter_start = 8;
273273

@@ -299,7 +299,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe)
299299

300300
// Test parameters
301301
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
302-
Config::M, Config::N, Config::K, Config::GRID};
302+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
303303
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
304304
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
305305
ck_tile::index_t cta_idx = 0;
@@ -333,7 +333,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe)
333333

334334
// Test parameters
335335
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
336-
Config::M, Config::N, Config::K, Config::GRID};
336+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
337337
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
338338
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
339339
ck_tile::index_t cta_idx = 1;
@@ -367,7 +367,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters)
367367

368368
// Test parameters
369369
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
370-
Config::M, Config::N, Config::K, Config::GRID};
370+
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
371371
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
372372
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
373373
ck_tile::index_t cta_idx = 2;
@@ -493,7 +493,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
493493

494494
ck_tile::
495495
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>
496-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
496+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
497497

498498
StreamKTilePartitionerV2PersistentExpected expected_values{0, 0, 3};
499499
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -506,7 +506,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DPOnly)
506506
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
507507
ck_tile::StreamKReductionStrategy::Atomic,
508508
true>
509-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
509+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
510510

511511
StreamKTilePartitionerV2PersistentExpected expected_values{2, 0, 3};
512512
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -519,7 +519,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DP2TileSK)
519519
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
520520
ck_tile::StreamKReductionStrategy::Atomic,
521521
true>
522-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
522+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
523523

524524
StreamKTilePartitionerV2PersistentExpected expected_values{1, 0, 3};
525525
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -532,7 +532,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, EdgeCase)
532532
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
533533
ck_tile::StreamKReductionStrategy::Atomic,
534534
true>
535-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
535+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
536536

537537
StreamKTilePartitionerV2PersistentExpected expected_values{0, 1, 4};
538538
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -545,10 +545,10 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, SKOnly)
545545
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
546546
ck_tile::StreamKReductionStrategy::Atomic,
547547
true>
548-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
548+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
549549

550550
const auto g = tile_partitioner.grid_size();
551-
EXPECT_EQ(g.x, Config::GRID);
551+
EXPECT_EQ(g.x, Config::MAX_ACTIVE_WGS);
552552
}
553553

554554
TEST(StreamKTilePartitioner_GridSize_Persistent, EdgeCase)
@@ -558,7 +558,7 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, EdgeCase)
558558
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
559559
ck_tile::StreamKReductionStrategy::Atomic,
560560
true>
561-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
561+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
562562

563563
const auto g = tile_partitioner.grid_size();
564564
EXPECT_EQ(g.x, 1);
@@ -571,7 +571,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, SKOnly)
571571

572572
ck_tile::
573573
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, false>
574-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
574+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
575575

576576
StreamKTilePartitionerV2NonPersistentExpected expected_values{0, 0, 0, 3};
577577
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -584,7 +584,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DPOnly)
584584
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
585585
ck_tile::StreamKReductionStrategy::Atomic,
586586
false>
587-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
587+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
588588

589589
StreamKTilePartitionerV2NonPersistentExpected expected_values{6, 0, 6, 3};
590590
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -597,7 +597,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DP2TileSK)
597597
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
598598
ck_tile::StreamKReductionStrategy::Atomic,
599599
false>
600-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
600+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
601601

602602
StreamKTilePartitionerV2NonPersistentExpected expected_values{3, 0, 3, 3};
603603
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -610,7 +610,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, EdgeCase)
610610
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
611611
ck_tile::StreamKReductionStrategy::Atomic,
612612
false>
613-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
613+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
614614

615615
StreamKTilePartitionerV2NonPersistentExpected expected_values{1, 0, 1, 4};
616616
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -623,7 +623,7 @@ TEST(StreamKTilePartitioner_GridSize_NonPersistent, DP2TileSK)
623623
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
624624
ck_tile::StreamKReductionStrategy::Atomic,
625625
false>
626-
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
626+
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
627627

628628
const auto g = tile_partitioner.grid_size();
629629
EXPECT_EQ(g.x, 6);

0 commit comments

Comments
 (0)