@@ -8,10 +8,10 @@ TEST(StreamKTilePartitionerBaseConstructor, SKOnly)
88 using Config = StreamKTilePartitionerBaseConfigSKOnly;
99
1010 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
11- Config::M, Config::N, Config::K, Config::GRID };
11+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
1212
1313 StreamKTilePartitionerBaseExpected expected_values{
14- 2 , 0 , 3 , 4 , 1 , 2 , 1 , 0 , 2 , Config::GRID , Config::N};
14+ 2 , 0 , 3 , 4 , 1 , 2 , 1 , 0 , 2 , Config::MAX_ACTIVE_WGS , Config::N};
1515 validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
1616}
1717
@@ -20,10 +20,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DPOnly)
2020 using Config = StreamKTilePartitionerBaseConfigDPOnly;
2121
2222 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
23- Config::M, Config::N, Config::K, Config::GRID };
23+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
2424
2525 StreamKTilePartitionerBaseExpected expected_values{
26- 0 , 6 , 0 , 0 , 0 , 2 , 0 , 12 , 6 , Config::GRID , Config::N};
26+ 0 , 6 , 0 , 0 , 0 , 2 , 0 , 12 , 6 , Config::MAX_ACTIVE_WGS , Config::N};
2727 validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
2828}
2929
@@ -32,10 +32,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DP2TileSK)
3232 using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
3333
3434 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
35- Config::M, Config::N, Config::K, Config::GRID };
35+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
3636
3737 StreamKTilePartitionerBaseExpected expected_values{
38- 4 , 3 , 3 , 8 , 2 , 2 , 2 , 6 , 7 , Config::GRID , Config::N};
38+ 4 , 3 , 3 , 8 , 2 , 2 , 2 , 6 , 7 , Config::MAX_ACTIVE_WGS , Config::N};
3939 validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
4040}
4141
@@ -44,10 +44,10 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase)
4444 using Config = StreamKTilePartitionerBaseConfigEdgeCase;
4545
4646 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
47- Config::M, Config::N, Config::K, Config::GRID };
47+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
4848
4949 StreamKTilePartitionerBaseExpected expected_values{
50- 0 , 1 , 0 , 0 , 0 , 2 , 0 , 2 , 1 , Config::GRID , Config::N};
50+ 0 , 1 , 0 , 0 , 0 , 2 , 0 , 2 , 1 , Config::MAX_ACTIVE_WGS , Config::N};
5151 validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
5252}
5353
@@ -57,7 +57,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes)
5757
5858 ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
5959 ck_tile::StreamKReductionStrategy::Linear>
60- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
60+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
6161
6262 EXPECT_EQ (tile_partitioner.get_flags_buffer_size (), 128 );
6363}
@@ -68,7 +68,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes)
6868
6969 ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
7070 ck_tile::StreamKReductionStrategy::Linear>
71- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
71+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
7272
7373 EXPECT_EQ (tile_partitioner.get_flags_buffer_size (), 128 );
7474}
@@ -79,7 +79,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes)
7979
8080 ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
8181 ck_tile::StreamKReductionStrategy::Linear>
82- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
82+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
8383
8484 EXPECT_EQ (tile_partitioner.get_flags_buffer_size (), 256 );
8585}
@@ -89,7 +89,7 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy)
8989 using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
9090
9191 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
92- Config::M, Config::N, Config::K, Config::GRID };
92+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
9393
9494 EXPECT_EQ (tile_partitioner.get_workspace_size (sizeof (float )), 0 );
9595}
@@ -100,12 +100,12 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
100100
101101 ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
102102 ck_tile::StreamKReductionStrategy::Linear>
103- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
103+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
104104
105105 ck_tile::index_t expected_partials_size =
106- sizeof (float ) * Config::M_TILE * Config::N_TILE * Config::GRID ;
107- // Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of
108- // the flags array is 128B-aligned.
106+ sizeof (float ) * Config::M_TILE * Config::N_TILE * Config::MAX_ACTIVE_WGS ;
107+ // Since MAX_ACTIVE_WGS is 3, the final padded flags array must be 128B to ensure the total byte
108+ // size of the flags array is 128B-aligned.
109109 ck_tile::index_t expected_flags_size = 128 ;
110110
111111 EXPECT_EQ (tile_partitioner.get_workspace_size (sizeof (float )),
@@ -117,7 +117,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLower
117117 using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
118118
119119 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
120- Config::M, Config::N, Config::K, Config::GRID };
120+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
121121
122122 EXPECT_EQ (tile_partitioner.estimate_num_wgs_per_tile (), 2 );
123123}
@@ -127,7 +127,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqual
127127 using Config = StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile;
128128
129129 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
130- Config::M, Config::N, Config::K, Config::GRID };
130+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
131131
132132 EXPECT_EQ (tile_partitioner.estimate_num_wgs_per_tile (), 2 );
133133}
@@ -232,7 +232,7 @@ TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries)
232232
233233 // Test parameters
234234 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
235- Config::M, Config::N, Config::K, Config::GRID };
235+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
236236 ck_tile::DeviceMem tile_iter_start_dev (sizeof (ck_tile::index_t ));
237237 ck_tile::DeviceMem tile_iter_end_dev (sizeof (ck_tile::index_t ));
238238 ck_tile::index_t tile_idx = 1 ;
@@ -267,7 +267,7 @@ TEST(StreamKTilePartitionerBaseGetTileIndex, GetTileIndex)
267267
268268 // Test parameters
269269 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
270- Config::M, Config::N, Config::K, Config::GRID };
270+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
271271 ck_tile::DeviceMem tile_idx_dev (sizeof (ck_tile::index_t ));
272272 ck_tile::index_t iter_start = 8 ;
273273
@@ -299,7 +299,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe)
299299
300300 // Test parameters
301301 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
302- Config::M, Config::N, Config::K, Config::GRID };
302+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
303303 ck_tile::DeviceMem iter_start_dev (sizeof (ck_tile::index_t ));
304304 ck_tile::DeviceMem iter_end_dev (sizeof (ck_tile::index_t ));
305305 ck_tile::index_t cta_idx = 0 ;
@@ -333,7 +333,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe)
333333
334334 // Test parameters
335335 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
336- Config::M, Config::N, Config::K, Config::GRID };
336+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
337337 ck_tile::DeviceMem iter_start_dev (sizeof (ck_tile::index_t ));
338338 ck_tile::DeviceMem iter_end_dev (sizeof (ck_tile::index_t ));
339339 ck_tile::index_t cta_idx = 1 ;
@@ -367,7 +367,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters)
367367
368368 // Test parameters
369369 ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
370- Config::M, Config::N, Config::K, Config::GRID };
370+ Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
371371 ck_tile::DeviceMem iter_start_dev (sizeof (ck_tile::index_t ));
372372 ck_tile::DeviceMem iter_end_dev (sizeof (ck_tile::index_t ));
373373 ck_tile::index_t cta_idx = 2 ;
@@ -493,7 +493,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
493493
494494 ck_tile::
495495 StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true >
496- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
496+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
497497
498498 StreamKTilePartitionerV2PersistentExpected expected_values{0 , 0 , 3 };
499499 validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -506,7 +506,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DPOnly)
506506 ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
507507 ck_tile::StreamKReductionStrategy::Atomic,
508508 true >
509- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
509+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
510510
511511 StreamKTilePartitionerV2PersistentExpected expected_values{2 , 0 , 3 };
512512 validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -519,7 +519,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DP2TileSK)
519519 ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
520520 ck_tile::StreamKReductionStrategy::Atomic,
521521 true >
522- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
522+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
523523
524524 StreamKTilePartitionerV2PersistentExpected expected_values{1 , 0 , 3 };
525525 validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -532,7 +532,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, EdgeCase)
532532 ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
533533 ck_tile::StreamKReductionStrategy::Atomic,
534534 true >
535- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
535+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
536536
537537 StreamKTilePartitionerV2PersistentExpected expected_values{0 , 1 , 4 };
538538 validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -545,10 +545,10 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, SKOnly)
545545 ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
546546 ck_tile::StreamKReductionStrategy::Atomic,
547547 true >
548- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
548+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
549549
550550 const auto g = tile_partitioner.grid_size ();
551- EXPECT_EQ (g.x , Config::GRID );
551+ EXPECT_EQ (g.x , Config::MAX_ACTIVE_WGS );
552552}
553553
554554TEST (StreamKTilePartitioner_GridSize_Persistent, EdgeCase)
@@ -558,7 +558,7 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, EdgeCase)
558558 ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
559559 ck_tile::StreamKReductionStrategy::Atomic,
560560 true >
561- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
561+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
562562
563563 const auto g = tile_partitioner.grid_size ();
564564 EXPECT_EQ (g.x , 1 );
@@ -571,7 +571,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, SKOnly)
571571
572572 ck_tile::
573573 StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, false >
574- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
574+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
575575
576576 StreamKTilePartitionerV2NonPersistentExpected expected_values{0 , 0 , 0 , 3 };
577577 validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -584,7 +584,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DPOnly)
584584 ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
585585 ck_tile::StreamKReductionStrategy::Atomic,
586586 false >
587- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
587+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
588588
589589 StreamKTilePartitionerV2NonPersistentExpected expected_values{6 , 0 , 6 , 3 };
590590 validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -597,7 +597,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DP2TileSK)
597597 ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
598598 ck_tile::StreamKReductionStrategy::Atomic,
599599 false >
600- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
600+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
601601
602602 StreamKTilePartitionerV2NonPersistentExpected expected_values{3 , 0 , 3 , 3 };
603603 validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -610,7 +610,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, EdgeCase)
610610 ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
611611 ck_tile::StreamKReductionStrategy::Atomic,
612612 false >
613- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
613+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
614614
615615 StreamKTilePartitionerV2NonPersistentExpected expected_values{1 , 0 , 1 , 4 };
616616 validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -623,7 +623,7 @@ TEST(StreamKTilePartitioner_GridSize_NonPersistent, DP2TileSK)
623623 ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
624624 ck_tile::StreamKReductionStrategy::Atomic,
625625 false >
626- tile_partitioner{Config::M, Config::N, Config::K, Config::GRID };
626+ tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS };
627627
628628 const auto g = tile_partitioner.grid_size ();
629629 EXPECT_EQ (g.x , 6 );
0 commit comments