Commit 79449e3

Refactor functions/types in memory.h

1 parent b1d8f9c

File tree

6 files changed: +407 -248 lines

examples/vector_add/main.cu

Lines changed: 11 additions & 5 deletions

@@ -13,11 +13,13 @@ void cuda_check(cudaError_t code) {
 }
 
 template<int N>
-__global__ void my_kernel(int length, const khalf<N>* input, double constant, kfloat<N>* output) {
+__global__ void my_kernel(int length, const __half* input, double constant, float* output) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
 
     if (i * N < length) {
-        kf::cast_to(output[i]) = (input[i] * input[i]) * constant;
+        auto a = kf::read_aligned<N>(input + i * N);
+        auto b = (a * a) * constant;
+        kf::write_aligned(output + i * N, b);
     }
 }

@@ -35,8 +37,8 @@ void run_kernel(int n) {
     }
 
     // Allocate device memory
-    khalf<items_per_thread>* input_dev;
-    kfloat<items_per_thread>* output_dev;
+    __half* input_dev;
+    float* output_dev;
     cuda_check(cudaMalloc(&input_dev, sizeof(half) * n));
     cuda_check(cudaMalloc(&output_dev, sizeof(float) * n));

@@ -47,7 +49,11 @@ void run_kernel(int n) {
     int block_size = 256;
     int items_per_block = block_size * items_per_thread;
     int grid_size = (n + items_per_block - 1) / items_per_block;
-    my_kernel<items_per_thread><<<grid_size, block_size>>>(n, input_dev, constant, output_dev);
+    my_kernel<items_per_thread><<<grid_size, block_size>>>(
+        n,
+        kf::aligned_ptr(input_dev),
+        constant,
+        kf::aligned_ptr(output_dev));
 
     // Copy results back
     cuda_check(cudaMemcpy(output_dev, output_result.data(), sizeof(float) * n, cudaMemcpyDefault));
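
For readers unfamiliar with the new memory API, the pattern above can be summarized as a small self-contained sketch. It only restates the calls visible in this diff (kf::read_aligned<N>, kf::write_aligned, and kf::aligned_ptr at the launch site); the kernel_float.h include and the kf namespace alias are assumed from the example sources, and the kernel name and launch parameters are illustrative.

#include <cuda_fp16.h>
#include "kernel_float.h"     // assumed single-header include, as used by the examples
namespace kf = kernel_float;  // assumed alias matching the kf:: calls in the diff

// Each thread handles N consecutive elements starting at offset i * N.
template<int N>
__global__ void square_and_scale(int length, const __half* input, double constant, float* output) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i * N < length) {
        auto a = kf::read_aligned<N>(input + i * N);  // vectorized load of N half values
        auto b = (a * a) * constant;                  // element-wise square and scale
        kf::write_aligned(output + i * N, b);         // vectorized store of N float values
    }
}

// Host-side launch, mirroring the diff: the raw device pointers from cudaMalloc are
// wrapped in kf::aligned_ptr(...) at the call site.
// square_and_scale<4><<<grid_size, 256>>>(n, kf::aligned_ptr(input_dev), 2.0, kf::aligned_ptr(output_dev));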

examples/vector_add_tiling/main.cu

Lines changed: 2 additions & 2 deletions

@@ -27,9 +27,9 @@ __global__ void my_kernel(
     auto points = int(blockIdx.x * tiling.tile_size(0)) + tiling.local_points(0);
     auto mask = tiling.local_mask();
 
-    auto a = kf::load(input.get(), points, mask);
+    auto a = input.read(points, mask);
     auto b = (a * a) * constant;
-    kf::store(b, output.get(), points, mask);
+    output.write(points, b, mask);
 }
 
 template<int items_per_thread, int block_size = 256>

include/kernel_float/macros.h

Lines changed: 19 additions & 9 deletions

@@ -9,35 +9,35 @@
 #define KERNEL_FLOAT_IS_DEVICE (1)
 #define KERNEL_FLOAT_IS_HOST (0)
 #define KERNEL_FLOAT_CUDA_ARCH (__CUDA_ARCH__)
-#else
+#else // __CUDA_ARCH__
 #define KERNEL_FLOAT_INLINE __forceinline__ __host__
 #define KERNEL_FLOAT_IS_DEVICE (0)
 #define KERNEL_FLOAT_IS_HOST (1)
 #define KERNEL_FLOAT_CUDA_ARCH (0)
-#endif
-#else
+#endif // __CUDA_ARCH__
+#else // __CUDACC__
 #define KERNEL_FLOAT_INLINE inline
 #define KERNEL_FLOAT_CUDA (0)
 #define KERNEL_FLOAT_IS_HOST (1)
 #define KERNEL_FLOAT_IS_DEVICE (0)
 #define KERNEL_FLOAT_CUDA_ARCH (0)
-#endif
+#endif // __CUDACC__
 
 #ifndef KERNEL_FLOAT_FP16_AVAILABLE
 #define KERNEL_FLOAT_FP16_AVAILABLE (1)
-#endif
+#endif // KERNEL_FLOAT_FP16_AVAILABLE
 
 #ifndef KERNEL_FLOAT_BF16_AVAILABLE
 #define KERNEL_FLOAT_BF16_AVAILABLE (1)
-#endif
+#endif // KERNEL_FLOAT_BF16_AVAILABLE
 
 #ifndef KERNEL_FLOAT_FP8_AVAILABLE
 #ifdef __CUDACC_VER_MAJOR__
 #define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12)
-#else
+#else // __CUDACC_VER_MAJOR__
 #define KERNEL_FLOAT_FP8_AVAILABLE (0)
-#endif
-#endif
+#endif // __CUDACC_VER_MAJOR__
+#endif // KERNEL_FLOAT_FP8_AVAILABLE
 
 #define KERNEL_FLOAT_ASSERT(expr) \
     do { \

@@ -49,4 +49,14 @@
 #define KERNEL_FLOAT_CONCAT(A, B) KERNEL_FLOAT_CONCAT_IMPL(A, B)
 #define KERNEL_FLOAT_CALL(F, ...) F(__VA_ARGS__)
 
+// TODO: check if this approach is supported across all compilers
+#if defined(__has_builtin) && __has_builtin(__builtin_assume_aligned) && 0
+#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) \
+    static_cast<TYPE*>(__builtin_assume_aligned(static_cast<TYPE*>(PTR), (ALIGNMENT)))
+#else
+#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR)
+#endif
+
+#define KERNEL_FLOAT_MAX_ALIGNMENT (32)
+
 #endif //KERNEL_FLOAT_MACROS_H
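
The new KERNEL_FLOAT_ASSUME_ALIGNED macro wraps the GCC/Clang builtin __builtin_assume_aligned; note that the trailing && 0 in the #if keeps the builtin path disabled, so until the TODO is resolved the macro expands to a plain (PTR). As a rough illustration of what the enabled path would provide, here is a small host-side C++ sketch using the builtin directly; the function and variable names are made up for the example, and the alignment of 32 matches KERNEL_FLOAT_MAX_ALIGNMENT.

// Illustrative only: __builtin_assume_aligned returns a copy of the pointer that the
// compiler may treat as 32-byte aligned, which can enable wider vectorized loads.
float sum8(const float* p) {
    const float* q = static_cast<const float*>(__builtin_assume_aligned(p, 32));
    float sum = 0.0f;
    for (int i = 0; i < 8; ++i) {
        sum += q[i];  // undefined behavior if p is not actually 32-byte aligned
    }
    return sum;
}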
