Skip to content

Commit 320ef34

Browse files
committed
Adding support for CompV4 pipeline in Stream-K GEMM through extended and
smoke tests.
1 parent 3b34060 commit 320ef34

12 files changed

+221
-7
lines changed

projects/composablekernel/test/ck_tile/gemm_streamk/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,33 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
2525
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
2626
# Base source list for the extended Stream-K GEMM tests: fp16 and bf16, each in
# persistent and non-persistent form, across the CompV3, CompV4 and Mem
# pipelines, plus the shared test_gemm_streamk_util.cpp.
set(STREAMK_EXTENDED_SOURCES
2727
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
28+
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp
2829
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
2930
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp
3032
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
3133
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp
3235
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
3336
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
37+
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp
3438
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
3539
test_gemm_streamk_util.cpp)
3640

3741
# We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a
3842
# Append the fp8/bf8 variants (persistent and non-persistent; CompV3, CompV4
# and Mem pipelines) to STREAMK_EXTENDED_SOURCES when GPU_TARGETS matches
# gfx942 or gfx950.
if(GPU_TARGETS MATCHES "gfx942|gfx950")
3943
list(APPEND STREAMK_EXTENDED_SOURCES
4044
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
45+
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp
4146
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
4247
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
48+
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp
4349
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
4450
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
51+
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp
4552
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
4653
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
54+
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp
4755
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
4856
endif()
4957

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
// Typed-test fixture for the Bf16 / NonPersistent / CompV4 Stream-K GEMM suite.
// It declares no members of its own; everything is inherited from
// TestCkTileStreamK<Tuple>.
template <typename Tuple>
7+
class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
8+
{
9+
};
10+
11+
// TEST_SUITE_NAME is defined only for the span of the .inc include below and is
// undefined again right after it. Presumably the .inc uses it to name its test
// cases (the .inc is not visible here -- confirm).
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4
12+
13+
// Bind the fixture to the type list KernelTypesStreamKBf16NonPersistentCompV4.
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4,
14+
KernelTypesStreamKBf16NonPersistentCompV4);
15+
16+
#include "test_gemm_streamk_extended_cases.inc"
17+
18+
#undef TEST_SUITE_NAME
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
// Typed-test fixture for the Bf16 / Persistent / CompV4 Stream-K GEMM suite.
// It declares no members of its own; everything is inherited from
// TestCkTileStreamK<Tuple>.
template <typename Tuple>
7+
class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK<Tuple>
8+
{
9+
};
10+
11+
// TEST_SUITE_NAME is defined only for the span of the .inc include below and is
// undefined again right after it. Presumably the .inc uses it to name its test
// cases (the .inc is not visible here -- confirm).
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4
12+
13+
// Bind the fixture to the type list KernelTypesStreamKBf16PersistentCompV4.
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4);
14+
15+
#include "test_gemm_streamk_extended_cases.inc"
16+
17+
#undef TEST_SUITE_NAME
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
// Typed-test fixture for the Bf8 / NonPersistent / CompV4 Stream-K GEMM suite.
// It declares no members of its own; everything is inherited from
// TestCkTileStreamK<Tuple>.
template <typename Tuple>
7+
class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
8+
{
9+
};
10+
11+
// TEST_SUITE_NAME is defined only for the span of the .inc include below and is
// undefined again right after it. Presumably the .inc uses it to name its test
// cases (the .inc is not visible here -- confirm).
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4
12+
13+
// Bind the fixture to the type list KernelTypesStreamKBf8NonPersistentCompV4.
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4);
14+
15+
#include "test_gemm_streamk_extended_cases.inc"
16+
17+
#undef TEST_SUITE_NAME
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
// Typed-test fixture for the Bf8 / Persistent / CompV4 Stream-K GEMM suite.
// It declares no members of its own; everything is inherited from
// TestCkTileStreamK<Tuple>.
template <typename Tuple>
7+
class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK<Tuple>
8+
{
9+
};
10+
11+
// TEST_SUITE_NAME is defined only for the span of the .inc include below and is
// undefined again right after it. Presumably the .inc uses it to name its test
// cases (the .inc is not visible here -- confirm).
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4
12+
13+
// Bind the fixture to the type list KernelTypesStreamKBf8PersistentCompV4.
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4);
14+
15+
#include "test_gemm_streamk_extended_cases.inc"
16+
17+
#undef TEST_SUITE_NAME
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
// Typed-test fixture for the Fp16 / NonPersistent / CompV4 Stream-K GEMM suite.
// It declares no members of its own; everything is inherited from
// TestCkTileStreamK<Tuple>.
template <typename Tuple>
7+
class TestCkTileStreamKFp16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
8+
{
9+
};
10+
11+
// TEST_SUITE_NAME is defined only for the span of the .inc include below and is
// undefined again right after it. Presumably the .inc uses it to name its test
// cases (the .inc is not visible here -- confirm).
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV4
12+
13+
// Bind the fixture to the type list KernelTypesStreamKFp16NonPersistentCompV4.
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV4,
14+
KernelTypesStreamKFp16NonPersistentCompV4);
15+
16+
#include "test_gemm_streamk_extended_cases.inc"
17+
18+
#undef TEST_SUITE_NAME
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
// Typed-test fixture for the Fp16 / Persistent / CompV4 Stream-K GEMM suite.
// It declares no members of its own; everything is inherited from
// TestCkTileStreamK<Tuple>.
template <typename Tuple>
7+
class TestCkTileStreamKFp16PersistentCompV4 : public TestCkTileStreamK<Tuple>
8+
{
9+
};
10+
11+
// TEST_SUITE_NAME is defined only for the span of the .inc include below and is
// undefined again right after it. Presumably the .inc uses it to name its test
// cases (the .inc is not visible here -- confirm).
#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV4
12+
13+
// Bind the fixture to the type list KernelTypesStreamKFp16PersistentCompV4.
TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV4, KernelTypesStreamKFp16PersistentCompV4);
14+
15+
#include "test_gemm_streamk_extended_cases.inc"
16+
17+
#undef TEST_SUITE_NAME
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
// Typed-test fixture for the Fp8 / NonPersistent / CompV4 Stream-K GEMM suite.
// It declares no members of its own; everything is inherited from
// TestCkTileStreamK<Tuple>.
template <typename Tuple>
7+
class TestCkTileStreamKFp8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
8+
{
9+
};
10+
11+
// TEST_SUITE_NAME is defined only for the span of the .inc include below and is
// undefined again right after it. Presumably the .inc uses it to name its test
// cases (the .inc is not visible here -- confirm).
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV4
12+
13+
// Bind the fixture to the type list KernelTypesStreamKFp8NonPersistentCompV4.
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV4, KernelTypesStreamKFp8NonPersistentCompV4);
14+
15+
#include "test_gemm_streamk_extended_cases.inc"
16+
17+
#undef TEST_SUITE_NAME
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_common_includes.hpp"
5+
6+
// Typed-test fixture for the Fp8 / Persistent / CompV4 Stream-K GEMM suite.
// It declares no members of its own; everything is inherited from
// TestCkTileStreamK<Tuple>.
template <typename Tuple>
7+
class TestCkTileStreamKFp8PersistentCompV4 : public TestCkTileStreamK<Tuple>
8+
{
9+
};
10+
11+
// TEST_SUITE_NAME is defined only for the span of the .inc include below and is
// undefined again right after it. Presumably the .inc uses it to name its test
// cases (the .inc is not visible here -- confirm).
#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV4
12+
13+
// Bind the fixture to the type list KernelTypesStreamKFp8PersistentCompV4.
TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV4, KernelTypesStreamKFp8PersistentCompV4);
14+
15+
#include "test_gemm_streamk_extended_cases.inc"
16+
17+
#undef TEST_SUITE_NAME

projects/composablekernel/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ using F32 = float;
1717
using Row = ck_tile::tensor_layout::gemm::RowMajor;
1818
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
1919

20-
using Persistent = std::true_type;
21-
using NonPersistent = std::false_type;
22-
2320
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
2421
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
22+
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
23+
24+
using Persistent = std::true_type;
25+
using NonPersistent = std::false_type;
2526

2627
using I32 = ck_tile::number<32>;
2728
using I128 = ck_tile::number<128>;
@@ -89,6 +90,66 @@ using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types<
8990
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>
9091
>;
9192

93+
// ========================== CompV4 Pipeline ==========================
94+
95+
// Type list for the Fp16 persistent CompV4 suite: F16 A/B/C with F32
// accumulation and an I256 x I256 x I32 macro tile. One tuple per A/B layout
// combination (Row/Row, Row/Col, Col/Col, Col/Row); C is always Row.
using KernelTypesStreamKFp16PersistentCompV4 = ::testing::Types<
96+
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline
97+
98+
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
99+
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
100+
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
101+
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>
102+
>;
103+
104+
// Type list for the Bf16 persistent CompV4 suite: BF16 A/B/C with F32
// accumulation and an I256 x I256 x I32 macro tile. One tuple per A/B layout
// combination (Row/Row, Row/Col, Col/Col, Col/Row); C is always Row.
// Tuple fields, in order: ALayout, BLayout, CLayout, ADataType, BDataType,
// AccDataType, CDataType, M_MacroTile, N_MacroTile, K_MacroTile, Persistent,
// Pipeline.
using KernelTypesStreamKBf16PersistentCompV4 = ::testing::Types<
105+
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
106+
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
107+
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
108+
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>
109+
>;
110+
111+
// Type list for the Bf8 persistent CompV4 suite: BF8 A/B inputs, F32
// accumulation, BF16 C output and an I128 x I128 x I32 macro tile. One tuple
// per A/B layout combination (Row/Row, Row/Col, Col/Col, Col/Row); C is
// always Row.
// Tuple fields, in order: ALayout, BLayout, CLayout, ADataType, BDataType,
// AccDataType, CDataType, M_MacroTile, N_MacroTile, K_MacroTile, Persistent,
// Pipeline.
using KernelTypesStreamKBf8PersistentCompV4 = ::testing::Types<
112+
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
113+
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
114+
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
115+
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>
116+
>;
117+
118+
// Type list for the Fp8 persistent CompV4 suite: F8 A/B inputs, F32
// accumulation, F16 C output and an I128 x I128 x I32 macro tile. One tuple
// per A/B layout combination (Row/Row, Row/Col, Col/Col, Col/Row); C is
// always Row.
// Tuple fields, in order: ALayout, BLayout, CLayout, ADataType, BDataType,
// AccDataType, CDataType, M_MacroTile, N_MacroTile, K_MacroTile, Persistent,
// Pipeline.
using KernelTypesStreamKFp8PersistentCompV4 = ::testing::Types<
119+
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
120+
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
121+
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
122+
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>
123+
>;
124+
125+
// Type list for the Fp16 non-persistent CompV4 suite: F16 A/B/C with F32
// accumulation and an I256 x I256 x I32 macro tile. One tuple per A/B layout
// combination (Row/Row, Row/Col, Col/Col, Col/Row); C is always Row.
// Tuple fields, in order: ALayout, BLayout, CLayout, ADataType, BDataType,
// AccDataType, CDataType, M_MacroTile, N_MacroTile, K_MacroTile, Persistent,
// Pipeline.
using KernelTypesStreamKFp16NonPersistentCompV4 = ::testing::Types<
126+
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
127+
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
128+
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
129+
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>
130+
>;
131+
132+
// Type list for the Bf16 non-persistent CompV4 suite: BF16 A/B/C with F32
// accumulation and an I256 x I256 x I32 macro tile. One tuple per A/B layout
// combination (Row/Row, Row/Col, Col/Col, Col/Row); C is always Row.
// Tuple fields, in order: ALayout, BLayout, CLayout, ADataType, BDataType,
// AccDataType, CDataType, M_MacroTile, N_MacroTile, K_MacroTile, Persistent,
// Pipeline.
using KernelTypesStreamKBf16NonPersistentCompV4 = ::testing::Types<
133+
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
134+
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
135+
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
136+
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>
137+
>;
138+
139+
// Type list for the Bf8 non-persistent CompV4 suite: BF8 A/B inputs, F32
// accumulation, BF16 C output and an I128 x I128 x I32 macro tile. One tuple
// per A/B layout combination (Row/Row, Row/Col, Col/Col, Col/Row); C is
// always Row.
// Tuple fields, in order: ALayout, BLayout, CLayout, ADataType, BDataType,
// AccDataType, CDataType, M_MacroTile, N_MacroTile, K_MacroTile, Persistent,
// Pipeline.
using KernelTypesStreamKBf8NonPersistentCompV4 = ::testing::Types<
140+
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
141+
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
142+
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
143+
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>
144+
>;
145+
146+
// Type list for the Fp8 non-persistent CompV4 suite: F8 A/B inputs, F32
// accumulation, F16 C output and an I128 x I128 x I32 macro tile. One tuple
// per A/B layout combination (Row/Row, Row/Col, Col/Col, Col/Row); C is
// always Row.
// Tuple fields, in order: ALayout, BLayout, CLayout, ADataType, BDataType,
// AccDataType, CDataType, M_MacroTile, N_MacroTile, K_MacroTile, Persistent,
// Pipeline.
using KernelTypesStreamKFp8NonPersistentCompV4 = ::testing::Types<
147+
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
148+
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
149+
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
150+
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>
151+
>;
152+
92153
// ============================= Mem Pipeline =============================
93154

94155
using KernelTypesStreamKFp16PersistentMem = ::testing::Types<

0 commit comments

Comments
 (0)