Skip to content

Commit e2e5634

Browse files
Addressed some review comments
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
1 parent 12a70a0 commit e2e5634

File tree

6 files changed

+16
-26
lines changed

6 files changed

+16
-26
lines changed

projects/composablekernel/include/ck_tile/core.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
2929
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
3030
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
31+
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
3132
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
3233
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
3334
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
3435
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
3536
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
36-
#include "ck_tile/core/arch/mma/sparse_mma.hpp"
3737
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
3838
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
3939
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"

projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,14 @@
88
#include "mma_selector.hpp"
99
#include "mma_traits.hpp"
1010
#include "mma_transforms.hpp"
11-
#include <netdb.h>
12-
#include <optional>
13-
#include <type_traits>
1411

1512
namespace ck_tile::core::arch::mma {
1613

1714
enum struct MmaPipelineOptionFlag
1815
{
19-
NONE = 0x0,
20-
C_TRANSPOSE = 0x1,
21-
SWIZZLE_A = 0x2,
22-
SWIZZLE_B = 0x4,
23-
DOUBLE_ATTR_NUM_ACCESS = 0x8,
24-
QUAD_ATTR_NUM_ACCESS = 0x10,
25-
COMPRESS_A = 0x20,
16+
NONE = 0x0,
17+
C_TRANSPOSE = 0x1,
18+
COMPRESS_A = 0x2,
2619
};
2720

2821
struct MmaPipelineOptionFlags
@@ -81,12 +74,6 @@ struct MmaPipelineBase
8174
static constexpr auto Flags = MmaPipelineOptionFlags(Flags_);
8275
// TODO: Implement those cases
8376
static_assert(!(Flags & MmaPipelineOptionFlag::C_TRANSPOSE), "Flag not yet implemented");
84-
static_assert(!(Flags & MmaPipelineOptionFlag::SWIZZLE_A), "Flag not yet implemented");
85-
static_assert(!(Flags & MmaPipelineOptionFlag::SWIZZLE_B), "Flag not yet implemented");
86-
static_assert(!(Flags & MmaPipelineOptionFlag::DOUBLE_ATTR_NUM_ACCESS),
87-
"Flag not yet implemented");
88-
static_assert(!(Flags & MmaPipelineOptionFlag::QUAD_ATTR_NUM_ACCESS),
89-
"Flag not yet implemented");
9077

9178
private:
9279
template <typename DstT, typename SrcT>
@@ -106,7 +93,7 @@ struct MmaPipelineBase
10693

10794
protected:
10895
template <MmaPipelineOptionFlag Flag>
109-
CK_TILE_DEVICE static bool hasFlag()
96+
constexpr CK_TILE_DEVICE static bool hasFlag()
11097
{
11198
return Flags & Flag;
11299
}

projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ struct WaveWiseMma : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFl
8686
constexpr static uint32_t FragM = MmaOp::kM;
8787
constexpr static uint32_t FragN = MmaOp::kN;
8888
constexpr static uint32_t FragK = MmaOp::kK;
89-
89+
9090
using BlockWiseMmaOp = MmaOp;
9191
using BlockWiseMmaOpTraits = MmaOpTraits<BlockWiseMmaOp>;
9292

@@ -174,7 +174,7 @@ struct WaveWiseMma : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFl
174174
}
175175
}
176176
}
177-
else
177+
else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
178178
{
179179
// "Col-major" accumulation over the M-dimension blocks first.
180180
// Pseudo code here, but we would basically iterate over the blocks in col-major order
@@ -190,6 +190,10 @@ struct WaveWiseMma : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFl
190190
}
191191
}
192192
}
193+
else
194+
{
195+
static_assert(false);
196+
}
193197
}
194198
};
195199

projects/composablekernel/include/ck_tile/core/arch/mma/sparse_mma.hpp renamed to projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp

File renamed without changes.

projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
4141
int32_t non_zero_pos = 0;
4242

4343
static_for<0, 3, 1>{}([&](auto j) {
44-
if(a_vec[i * 4 + j] != 0.0f)
44+
if(static_cast<float>(a_vec[i * 4 + j]) != 0.0f)
4545
{
4646
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
4747
// clear the two‑bit field for this output and insert j
@@ -68,17 +68,16 @@ struct SparseCompressTransform
6868
template <typename VecType>
6969
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v, int32_t& idx)
7070
{
71-
using VecTraits = vector_traits<std::decay_t<VecType>>;
71+
using VecTraits = vector_traits<remove_cvref_t<VecType>>;
7272
using ScalarT = typename VecTraits::scalar_type;
7373
static constexpr auto VecN = VecTraits::vector_size;
7474
static constexpr index_t CompressedSize = VecN / CompressionRatio;
7575
using VecCompressed = ext_vector_t<ScalarT, CompressedSize>;
7676

7777
idx = detail::compress_a_impl<ScalarT, CompressedSize>(v);
7878

79-
VecCompressed result;
80-
__builtin_memcpy(&result, &v, sizeof(VecCompressed));
81-
return result;
79+
// TODO c++20: Use bit_cast
80+
return *std::launder(reinterpret_cast<VecCompressed*>(&v));
8281
}
8382
};
8483

projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
99
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
1010
#include "ck_tile/core/arch/mma/mma_selector.hpp"
11-
#include "ck_tile/core/arch/mma/sparse_mma.hpp"
11+
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
1212
#include <hip/hip_runtime.h>
1313
#include "ck_tile/core/numeric/integer.hpp"
1414
#include "ck_tile/host/hip_check_error.hpp"

0 commit comments

Comments
 (0)