Skip to content

Commit 2067a2d

Browse files
De-couple intertwined compile passes for host/device code
Allow pipeline instantiations on the host side without having to constexpr conditionaly use them within the kernel. Also utilize the warning added to the host/unsupported amdgcn struct. Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
1 parent 2187669 commit 2067a2d

File tree

5 files changed

+53
-57
lines changed

5 files changed

+53
-57
lines changed

projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,22 @@ struct MmaPipelineBase
124124
template <typename VecTA, typename VecTB, typename VecTC>
125125
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
126126
{
127-
// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
128-
auto pre = Derived::template preApply<Flags_>(
129-
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
130-
Derived::execImpl(pre);
131-
return Derived::template postApply<Flags_>(std::move(pre));
127+
if constexpr(MmaOpTraits<typename Derived::FragWiseMmaOp>::IsSupported)
128+
{
129+
// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
130+
auto pre = Derived::template preApply<Flags_>(
131+
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
132+
Derived::execImpl(pre);
133+
return Derived::template postApply<Flags_>(std::move(pre));
134+
}
135+
else
136+
{
137+
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
138+
// Code should not reach here, but HOST/DEVICE compile passes are
139+
// weirdly intertwined and instead of having constexpr in the calling
140+
// site (tests) we do this. See also changes by this commit.
141+
return Derived::FragWiseMmaOp::exec({}, {}, {});
142+
}
132143
}
133144
};
134145

projects/composablekernel/include/ck_tile/core/arch/mma/mma_transforms.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ template <typename MmaOp, typename CompilerTarget, typename Enable = void>
4242
struct MmaTransformsDefaultSelector
4343
{
4444
using SelectedTransforms = MmaDefaultPassThroughTransforms;
45+
static_assert(CompilerTarget::TARGET_ID == amdgcn_target_id::HOST,
46+
"Device code should use another specialization.");
4547
};
4648

4749
#if CK_TILE_CONCEPTS

projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,12 @@ template <typename ADataType,
3737
struct SparseMma : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A), // TODO: c++20: use MmaPipelineOptionFlags directly
3838
SparseMma<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp, MmaTransforms>>
3939
{
40-
static_assert(MmaOpTraits<MmaOp>::IsSupported && MmaOpTraits<MmaOp>::IsSparse);
4140
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A), // TODO: c++20: use MmaPipelineOptionFlags directly
4241
SparseMma<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp, MmaTransforms>>;
4342
// clang-format on
4443

44+
using FragWiseMmaOp = MmaOp; // Expose the selected MmaOp
45+
4546
// Calculate the uncompressed A vector type
4647
struct InternalAVecCalculator
4748
{

projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -185,30 +185,26 @@ __global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
185185
CompilerTarget,
186186
MmaOpFamily::SPARSE>::SelectedOp;
187187

188-
using MmaTraits = MmaOpTraits<MmaOp>;
188+
using Pipeline =
189+
SparseMma<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK, CompilerTarget>;
189190

190-
if constexpr(MmaTraits::IsSupported)
191-
{
192-
using Pipeline = SparseMma<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK, CompilerTarget>;
193-
194-
using AVecType = typename Pipeline::AVecType;
195-
using BVecType = typename Pipeline::BVecType;
196-
using CVecType = typename Pipeline::CVecType;
191+
using AVecType = typename Pipeline::AVecType;
192+
using BVecType = typename Pipeline::BVecType;
193+
using CVecType = typename Pipeline::CVecType;
197194

198-
static constexpr uint32_t kIters = WaveTileK / MmaOp::kK;
195+
static constexpr uint32_t kIters = WaveTileK / MmaOp::kK;
199196

200-
// Initialize the accumulator
201-
CVecType result = *reinterpret_cast<CVecType*>(c);
197+
// Initialize the accumulator
198+
CVecType result = *reinterpret_cast<CVecType*>(c);
202199

203-
// Accumulate input AxB over FragK/BlockK iterations
204-
for(uint32_t i = 0; i < kIters; ++i)
205-
{
206-
result = Pipeline::exec(
207-
*reinterpret_cast<AVecType*>(a), *reinterpret_cast<BVecType*>(b), result);
208-
}
209-
210-
*reinterpret_cast<CVecType*>(out) = result;
200+
// Accumulate input AxB over FragK/BlockK iterations
201+
for(uint32_t i = 0; i < kIters; ++i)
202+
{
203+
result = Pipeline::exec(
204+
*reinterpret_cast<AVecType*>(a), *reinterpret_cast<BVecType*>(b), result);
211205
}
206+
207+
*reinterpret_cast<CVecType*>(out) = result;
212208
}
213209

214210
// Live test on real hardware for sparse selection and execution.

projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,38 +23,24 @@ template <typename AType,
2323
__global__ void test_pipeline(void* a, void* b, void* c)
2424
{
2525
using CompilerTarget = decltype(get_compiler_target());
26-
using MmaOp = typename MmaDefaultSelector<AType, // TODO: c++20 MmaOpI MmaOp = typename
27-
// MmaDefaultSelector<ADataType,
28-
BType,
29-
CType,
30-
WaveTileM,
31-
WaveTileN,
32-
WaveTileK,
33-
CompilerTarget,
34-
MmaOpFamily::DENSE>::SelectedOp;
35-
36-
using MmaTraits = MmaOpTraits<MmaOp>;
37-
38-
if constexpr(MmaTraits::IsSupported)
39-
{
40-
using Pipeline = WaveWiseMma<AType,
41-
BType,
42-
CType,
43-
WaveTileM,
44-
WaveTileN,
45-
WaveTileK,
46-
MmaOpFamily::DENSE,
47-
MmaAccumPolicy::ROW_MAJOR,
48-
CompilerTarget>;
49-
50-
using AVecType = typename Pipeline::AVecType;
51-
using BVecType = typename Pipeline::BVecType;
52-
using CVecType = typename Pipeline::CVecType;
53-
54-
Pipeline::exec(*reinterpret_cast<AVecType(*)[Pipeline::FragsM][Pipeline::FragsK]>(a),
55-
*reinterpret_cast<BVecType(*)[Pipeline::FragsN][Pipeline::FragsK]>(b),
56-
*reinterpret_cast<CVecType(*)[Pipeline::FragsM][Pipeline::FragsN]>(c));
57-
}
26+
27+
using Pipeline = WaveWiseMma<AType,
28+
BType,
29+
CType,
30+
WaveTileM,
31+
WaveTileN,
32+
WaveTileK,
33+
MmaOpFamily::DENSE,
34+
MmaAccumPolicy::ROW_MAJOR,
35+
CompilerTarget>;
36+
37+
using AVecType = typename Pipeline::AVecType;
38+
using BVecType = typename Pipeline::BVecType;
39+
using CVecType = typename Pipeline::CVecType;
40+
41+
Pipeline::exec(*reinterpret_cast<AVecType(*)[Pipeline::FragsM][Pipeline::FragsK]>(a),
42+
*reinterpret_cast<BVecType(*)[Pipeline::FragsN][Pipeline::FragsK]>(b),
43+
*reinterpret_cast<CVecType(*)[Pipeline::FragsM][Pipeline::FragsN]>(c));
5844
}
5945

6046
TEST(WaveWiseMmaPipeline, testKIter)

0 commit comments

Comments
 (0)