Skip to content

Commit ed7143d

Browse files
committed
Change function signature of read/write_aligned
1 parent 79449e3 commit ed7143d

File tree

4 files changed

+363
-364
lines changed

4 files changed

+363
-364
lines changed

examples/vector_add/main.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ __global__ void my_kernel(int length, const __half* input, double constant, floa
1919
if (i * N < length) {
2020
auto a = kf::read_aligned<N>(input + i * N);
2121
auto b = (a * a) * constant;
22-
kf::write_aligned(output + i * N, b);
22+
kf::write_aligned<N>(output + i * N, b);
2323
}
2424
}
2525

include/kernel_float/memory.h

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -112,21 +112,21 @@ constexpr size_t gcd(size_t a, size_t b) {
112112
return b == 0 ? a : gcd(b, a % b);
113113
}
114114

115-
template<typename T, size_t N, size_t alignment>
115+
template<typename T, size_t N, size_t alignment, typename = void>
116116
struct copy_aligned_impl {
117-
static constexpr size_t half = N > 8 ? 8 : (N > 4 ? 4 : (N > 2 ? 2 : 1));
118-
static constexpr size_t new_alignment = gcd(alignment, sizeof(T) * half);
117+
static constexpr size_t K = N > 8 ? 8 : (N > 4 ? 4 : (N > 2 ? 2 : 1));
118+
static constexpr size_t alignment_K = gcd(alignment, sizeof(T) * K);
119119

120120
KERNEL_FLOAT_INLINE
121121
static void load(T* output, const T* input) {
122-
copy_aligned_impl<T, half, new_alignment>::load(output, input);
123-
copy_aligned_impl<T, N - half, new_alignment>::load(output + half, input + half);
122+
copy_aligned_impl<T, K, alignment>::load(output, input);
123+
copy_aligned_impl<T, N - K, alignment_K>::load(output + K, input + K);
124124
}
125125

126126
KERNEL_FLOAT_INLINE
127127
static void store(T* output, const T* input) {
128-
copy_aligned_impl<T, half, new_alignment>::store(output, input);
129-
copy_aligned_impl<T, N - half, new_alignment>::store(output + half, input + half);
128+
copy_aligned_impl<T, K, alignment>::store(output, input);
129+
copy_aligned_impl<T, N - K, alignment_K>::store(output + K, input + K);
130130
}
131131
};
132132

@@ -141,6 +141,8 @@ struct copy_aligned_impl<T, 0, alignment> {
141141

142142
template<typename T, size_t alignment>
143143
struct copy_aligned_impl<T, 1, alignment> {
144+
using storage_type = T;
145+
144146
KERNEL_FLOAT_INLINE
145147
static void load(T* output, const T* input) {
146148
output[0] = input[0];
@@ -153,9 +155,9 @@ struct copy_aligned_impl<T, 1, alignment> {
153155
};
154156

155157
template<typename T, size_t alignment>
156-
struct copy_aligned_impl<T, 2, alignment> {
157-
static constexpr size_t new_alignment = gcd(alignment, 2 * sizeof(T));
158-
struct alignas(new_alignment) storage_type {
158+
struct copy_aligned_impl<T, 2, alignment, enable_if_t<(alignment > sizeof(T))>> {
159+
static constexpr size_t storage_alignment = gcd(alignment, 2 * sizeof(T));
160+
struct alignas(storage_alignment) storage_type {
159161
T v0, v1;
160162
};
161163

@@ -173,9 +175,9 @@ struct copy_aligned_impl<T, 2, alignment> {
173175
};
174176

175177
template<typename T, size_t alignment>
176-
struct copy_aligned_impl<T, 4, alignment> {
177-
static constexpr size_t new_alignment = gcd(alignment, 4 * sizeof(T));
178-
struct alignas(new_alignment) storage_type {
178+
struct copy_aligned_impl<T, 4, alignment, enable_if_t<(alignment > 2 * sizeof(T))>> {
179+
static constexpr size_t storage_alignment = gcd(alignment, 4 * sizeof(T));
180+
struct alignas(storage_alignment) storage_type {
179181
T v0, v1, v2, v3;
180182
};
181183

@@ -199,9 +201,9 @@ struct copy_aligned_impl<T, 4, alignment> {
199201
};
200202

201203
template<typename T, size_t alignment>
202-
struct copy_aligned_impl<T, 8, alignment> {
203-
static constexpr size_t new_alignment = gcd(alignment, 8 * sizeof(T));
204-
struct alignas(new_alignment) storage_type {
204+
struct copy_aligned_impl<T, 8, alignment, enable_if_t<(alignment > 4 * sizeof(T))>> {
205+
static constexpr size_t storage_alignment = gcd(alignment, 8 * sizeof(T));
206+
struct alignas(storage_alignment) storage_type {
205207
T v0, v1, v2, v3, v4, v5, v6, v7;
206208
};
207209

@@ -248,9 +250,9 @@ struct copy_aligned_impl<T, 8, alignment> {
248250
* vec<T, 4> values2 = read_aligned<4>(data + 10);
249251
* ```
250252
*/
251-
template<size_t N, typename T>
253+
template<size_t Align, size_t N = Align, typename T>
252254
KERNEL_FLOAT_INLINE vector<T, extent<N>> read_aligned(const T* ptr) {
253-
static constexpr size_t alignment = detail::gcd(N * sizeof(T), KERNEL_FLOAT_MAX_ALIGNMENT);
255+
static constexpr size_t alignment = detail::gcd(Align * sizeof(T), KERNEL_FLOAT_MAX_ALIGNMENT);
254256
vector_storage<T, N> result;
255257
detail::copy_aligned_impl<T, N, alignment>::load(
256258
result.data(),
@@ -273,10 +275,10 @@ KERNEL_FLOAT_INLINE vector<T, extent<N>> read_aligned(const T* ptr) {
273275
* write_aligned(data + 10, values);
274276
* ```
275277
*/
276-
template<typename V, typename T>
278+
template<size_t Align, typename V, typename T>
277279
KERNEL_FLOAT_INLINE void write_aligned(T* ptr, const V& values) {
278280
static constexpr size_t N = vector_extent<V>;
279-
static constexpr size_t alignment = detail::gcd(N * sizeof(T), KERNEL_FLOAT_MAX_ALIGNMENT);
281+
static constexpr size_t alignment = detail::gcd(Align * sizeof(T), KERNEL_FLOAT_MAX_ALIGNMENT);
280282

281283
return detail::copy_aligned_impl<T, N, alignment>::store(
282284
KERNEL_FLOAT_ASSUME_ALIGNED(T, ptr, alignment),

0 commit comments

Comments
 (0)