Skip to content

Commit 931c392

Browse files
committed
2:4 sparse for int8/fp8/bf16/fp16 dtype
1 parent c27a341 commit 931c392

File tree

9 files changed

+2175
-0
lines changed

9 files changed

+2175
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
16+
#pragma once
17+
18+
#include <cute/tensor.hpp>
19+
namespace cute {
20+
21+
////////////////////////////////////////////////////////////////////
22+
// layout utils
23+
////////////////////////////////////////////////////////////////////
24+
25+
// Permute layout based on indices, example:
26+
// permute_layout<1, 0>(layout) will swap the two dimensions
27+
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
28+
template <size_t... I, typename Layout>
29+
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
30+
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
31+
return cute::make_layout(cute::get<I>(l)...);
32+
}
33+
34+
// is the layout f(x) = x
35+
template <typename Layout>
36+
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
37+
if constexpr (std::is_same_v<Layout, void>) {
38+
return true;
39+
} else {
40+
constexpr auto coalesced_layout = coalesce(Layout{});
41+
if constexpr (rank(coalesced_layout) == 1 &&
42+
stride<0>(coalesced_layout) == 1) {
43+
return true;
44+
}
45+
return false;
46+
}
47+
}
48+
49+
////////////////////////////////////////////////////////////////////
50+
// Pointer utils
51+
////////////////////////////////////////////////////////////////////
52+
53+
template <class PointerType>
54+
static constexpr auto get_logical_ptr(PointerType* ptr) {
55+
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
56+
return cute::subbyte_iterator<PointerType>(ptr);
57+
} else {
58+
return ptr;
59+
}
60+
}
61+
62+
////////////////////////////////////////////////////////////////////
63+
// Misc utils
64+
////////////////////////////////////////////////////////////////////
65+
66+
template <typename T, typename Elements>
67+
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
68+
constexpr auto bits = sizeof_bits_v<T> * Elements{};
69+
if constexpr (bits % 128 == 0) {
70+
return AutoVectorizingCopyWithAssumedAlignment<128>{};
71+
} else if constexpr (bits % 64 == 0) {
72+
return AutoVectorizingCopyWithAssumedAlignment<64>{};
73+
} else if constexpr (bits % 32 == 0) {
74+
return AutoVectorizingCopyWithAssumedAlignment<32>{};
75+
} else if constexpr (bits % 16 == 0) {
76+
return AutoVectorizingCopyWithAssumedAlignment<16>{};
77+
} else {
78+
return AutoVectorizingCopyWithAssumedAlignment<8>{};
79+
}
80+
}
81+
82+
}; // namespace cute

0 commit comments

Comments
 (0)