Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,33 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
set(STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
test_gemm_streamk_util.cpp)

# We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a
if(GPU_TARGETS MATCHES "gfx942|gfx950")
list(APPEND STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
endif()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4

TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4,
KernelTypesStreamKBf16NonPersistentCompV4);

#include "test_gemm_streamk_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4

TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4);

#include "test_gemm_streamk_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4

TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4);

#include "test_gemm_streamk_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4

TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4);

#include "test_gemm_streamk_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKFp16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistentCompV4

TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistentCompV4,
KernelTypesStreamKFp16NonPersistentCompV4);

#include "test_gemm_streamk_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKFp16PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKFp16PersistentCompV4

TYPED_TEST_SUITE(TestCkTileStreamKFp16PersistentCompV4, KernelTypesStreamKFp16PersistentCompV4);

#include "test_gemm_streamk_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKFp8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistentCompV4

TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistentCompV4, KernelTypesStreamKFp8NonPersistentCompV4);

#include "test_gemm_streamk_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_common_includes.hpp"

template <typename Tuple>
class TestCkTileStreamKFp8PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKFp8PersistentCompV4

TYPED_TEST_SUITE(TestCkTileStreamKFp8PersistentCompV4, KernelTypesStreamKFp8PersistentCompV4);

#include "test_gemm_streamk_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

using Persistent = std::true_type;
using NonPersistent = std::false_type;

using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;

using Persistent = std::true_type;
using NonPersistent = std::false_type;

using I32 = ck_tile::number<32>;
using I128 = ck_tile::number<128>;
Expand Down Expand Up @@ -89,6 +90,66 @@ using KernelTypesStreamKFp8NonPersistentCompV3 = ::testing::Types<
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV3>
>;

// ========================== CompV4 Pipeline ==========================

using KernelTypesStreamKFp16PersistentCompV4 = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent Pipeline

std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent, CompV4>
>;

using KernelTypesStreamKBf16PersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent, CompV4>
>;

using KernelTypesStreamKBf8PersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, Persistent, CompV4>
>;

using KernelTypesStreamKFp8PersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, Persistent, CompV4>
>;

using KernelTypesStreamKFp16NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent, CompV4>
>;

using KernelTypesStreamKBf16NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent, CompV4>
>;

using KernelTypesStreamKBf8NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, BF8, BF8, F32, BF16, I128, I128, I32, NonPersistent, CompV4>
>;

using KernelTypesStreamKFp8NonPersistentCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I128, I128, I32, NonPersistent, CompV4>
>;

// ============================= Mem Pipeline =============================

using KernelTypesStreamKFp16PersistentMem = ::testing::Types<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
enum struct GemmPipelineType
{
Mem,
CompV3
CompV3,
CompV4
};

template <GemmPipelineType PT, typename Problem>
Expand All @@ -32,6 +33,12 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
using pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
};

template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
{
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
};

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
Expand Down Expand Up @@ -101,8 +108,8 @@ class TestCkTileStreamK : public ::testing::Test
constexpr bool kPadK = PadK;
constexpr bool preshuffle = Preshuffle;

constexpr bool DoubleSmemBuffer = false;
constexpr int kBlockPerCu = 1;
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
constexpr int kBlockPerCu = 1;
constexpr bool StructuredSparsity = false;
constexpr bool NumWaveGroup = 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def to_dict(self) -> Dict:
class TraitConfig:
"""Represents the Trait Config section of a Tile Engine config"""

pipeline: List[str] = field(default_factory=lambda: ["compv3", "mem"])
pipeline: List[str] = field(default_factory=lambda: ["compv3", "compv4", "mem"])
epilogue: List[str] = field(default_factory=lambda: ["cshuffle"])
scheduler: List[str] = field(default_factory=lambda: ["intrawave"])
pad_m: List[bool] = field(default_factory=lambda: [False])
Expand Down
Loading