Skip to content

Commit 6d967a5

Browse files
Defined the sparse pipeline and used it in test
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
1 parent 55c08fe commit 6d967a5

File tree

3 files changed

+150
-22
lines changed

3 files changed

+150
-22
lines changed

projects/composablekernel/include/ck_tile/core/arch/mma/mma_selector.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,4 @@ concept MmaSelectorI = requires(MmaSelector op) {
7373
// Include the implementations
7474
#include "wmma/wmma_selector.hpp"
7575
#include "mfma/mfma_selector.hpp"
76+
#include "sparse/sparse_selector.hpp"
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
#pragma once
4+
5+
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
6+
#include "ck_tile/core/arch/mma/mma_selector.hpp"
7+
#include "ck_tile/core/arch/mma/mma_traits.hpp"
8+
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
9+
#include "ck_tile/core/numeric/vector_type.hpp"
10+
#include <cstdint>
11+
#include <type_traits>
12+
13+
namespace ck_tile::core::arch::mma {
14+
15+
template <typename ADataType,
16+
typename BDataType,
17+
typename CDataType,
18+
uint32_t FragM,
19+
uint32_t FragN,
20+
uint32_t FragK,
21+
typename CompilerTarget =
22+
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
23+
// get_compiler_target(),
24+
typename MmaOp =
25+
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
26+
// MmaDefaultSelector<ADataType,
27+
BDataType,
28+
CDataType,
29+
FragM,
30+
FragN,
31+
FragK,
32+
CompilerTarget,
33+
MmaOpFamily::SPARSE>::SelectedOp,
34+
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
35+
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
36+
struct SparseMma : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A),
37+
SparseMma<ADataType,
38+
BDataType,
39+
CDataType,
40+
FragM,
41+
FragN,
42+
FragK,
43+
CompilerTarget,
44+
MmaOp,
45+
MmaTransforms>>
46+
{
47+
static_assert(MmaOpTraits<MmaOp>::IsSupported && MmaOpTraits<MmaOp>::IsSparse);
48+
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A),
49+
SparseMma<ADataType,
50+
BDataType,
51+
CDataType,
52+
FragM,
53+
FragN,
54+
FragK,
55+
CompilerTarget,
56+
MmaOp,
57+
MmaTransforms>>;
58+
59+
using MmaOpTraits = MmaOpTraits<MmaOp>;
60+
61+
// Expose B and C vector types from the traits
62+
using BVecType = typename MmaOpTraits::BVecType;
63+
using CVecType = typename MmaOpTraits::CVecType;
64+
65+
// Calculate the uncompressed A vector type
66+
static constexpr auto COMPRESSION_RATIO = 2;
67+
using ATraits = vector_traits<typename MmaOp::AVecType>;
68+
static constexpr index_t ASize = ATraits::vector_size * COMPRESSION_RATIO;
69+
using AVecType = ext_vector_t<typename ATraits::scalar_type, ASize>;
70+
71+
// Transforms
72+
using ATransform = typename MmaTransforms::ATransform;
73+
using BTransform = typename MmaTransforms::BTransform;
74+
using CTransform = typename MmaTransforms::CTransform;
75+
using DTransform = typename MmaTransforms::DTransform;
76+
77+
template <MmaPipelineOptionFlags::Type Flags, typename VecTA, typename VecTB, typename VecTC>
78+
CK_TILE_DEVICE static decltype(auto) preApply(VecTA&& a, VecTB&& b, VecTC&& accum)
79+
{
80+
static_assert(Flags == MmaPipelineOptionFlags(MmaPipelineOptionFlag::COMPRESS_A));
81+
static_assert(std::is_same_v<ATransform, SparseCompressTransform<2>>);
82+
83+
int32_t idx{};
84+
auto a_frag =
85+
Base::template preApplyTransform<typename MmaOpTraits::AVecType, ATransform>(a, idx);
86+
auto b_frag =
87+
Base::template preApplyTransform<typename MmaOpTraits::BVecType, BTransform>(b);
88+
auto c_frag =
89+
Base::template preApplyTransform<typename MmaOpTraits::CVecType, CTransform>(accum);
90+
91+
return std::make_tuple(a_frag, b_frag, c_frag, idx);
92+
}
93+
94+
template <MmaPipelineOptionFlags::Type Flags, typename VecTA, typename VecTB, typename VecTC>
95+
CK_TILE_DEVICE static decltype(auto) postApply(std::tuple<VecTA, VecTB, VecTC, int32_t>& vecs)
96+
{
97+
static_assert(Flags == MmaPipelineOptionFlags(MmaPipelineOptionFlag::COMPRESS_A));
98+
99+
auto& [a_frag, b_frag, c_frag, idx] = vecs;
100+
// Convert native vector results back to the output fragment format
101+
// and then return after we apply the final output transform.
102+
return Base::template postApplyTransform<std::decay_t<VecTC>, DTransform>(c_frag);
103+
}
104+
105+
template <typename VecTA, typename VecTB, typename VecTC>
106+
CK_TILE_DEVICE static void execImpl(std::tuple<VecTA, VecTB, VecTC, int32_t>& vecs)
107+
{
108+
109+
static_assert(MmaOpTraits::IsSupported);
110+
static_assert(MmaOpTraits::IsSparse);
111+
if constexpr(MmaOpTraits::IsSupported && MmaOpTraits::IsSparse)
112+
{
113+
auto& [a_frag, b_frag, c_frag, idx] = vecs;
114+
MmaOp::exec(a_frag, b_frag, c_frag, idx);
115+
}
116+
}
117+
};
118+
119+
} // namespace ck_tile::core::arch::mma

projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
99
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
1010
#include "ck_tile/core/arch/mma/mma_selector.hpp"
11+
#include "ck_tile/core/arch/mma/sparse_mma.hpp"
1112
#include <hip/hip_runtime.h>
13+
#include "ck_tile/core/numeric/integer.hpp"
1214
#include "ck_tile/host/hip_check_error.hpp"
1315
#include "ck_tile/core/arch/mma/mma_traits.hpp"
1416
#include "ck_tile/core/utility/type_traits.hpp"
@@ -150,34 +152,40 @@ template <typename AType,
150152
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
151153
{
152154
using CompilerTarget = decltype(get_compiler_target());
153-
using Selector = MmaDefaultSelector<AType,
154-
BType,
155-
CType,
156-
FragM,
157-
FragN,
158-
FragK,
159-
CompilerTarget,
160-
MmaOpFamily::SPARSE>;
161-
162-
using MmaOp = typename Selector::SelectedOp;
155+
using MmaOp = typename MmaDefaultSelector<AType, // TODO: c++20 MmaOpI MmaOp = typename
156+
// MmaDefaultSelector<ADataType,
157+
BType,
158+
CType,
159+
FragM,
160+
FragN,
161+
FragK,
162+
CompilerTarget,
163+
MmaOpFamily::SPARSE>::SelectedOp;
164+
163165
using MmaTraits = MmaOpTraits<MmaOp>;
164166

165-
using CVecType = typename MmaOp::CVecType;
167+
if constexpr(MmaTraits::IsSupported)
168+
{
169+
using Pipeline = SparseMma<AType, BType, CType, FragM, FragN, FragK, CompilerTarget>;
166170

167-
static constexpr uint32_t kIters = FragK / MmaTraits::BlockK;
171+
using AVecType = typename Pipeline::AVecType;
172+
using BVecType = typename Pipeline::BVecType;
173+
using CVecType = typename Pipeline::CVecType;
168174

169-
// Initialize the accumulator
170-
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
175+
static constexpr uint32_t kIters = FragK / MmaTraits::BlockK;
171176

172-
// Accumulate input AxB over FragK/BlockK iterations
173-
for(uint32_t i = 0; i < kIters; ++i)
174-
{
175-
result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
176-
*reinterpret_cast<typename MmaOp::BVecType*>(b),
177-
result);
178-
}
177+
// Initialize the accumulator
178+
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
179179

180-
*reinterpret_cast<typename MmaOp::CVecType*>(out) = result;
180+
// Accumulate input AxB over FragK/BlockK iterations
181+
for(uint32_t i = 0; i < kIters; ++i)
182+
{
183+
result = Pipeline::exec(
184+
*reinterpret_cast<AVecType*>(a), *reinterpret_cast<BVecType*>(b), result);
185+
}
186+
187+
*reinterpret_cast<typename MmaOp::CVecType*>(out) = result;
188+
}
181189
}
182190

183191
// Live test on real hardware for sparse selection and execution.

0 commit comments

Comments
 (0)