Skip to content

Commit cac92e7

Browse files
committed
Extend implementation and documentation of vector_ptr
1 parent 4b08356 commit cac92e7

File tree

3 files changed

+496
-216
lines changed

3 files changed

+496
-216
lines changed

examples/vector_add/main.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ void cuda_check(cudaError_t code) {
1515
template<int N>
1616
__global__ void my_kernel(
1717
int length,
18-
kf::vec_ptr<const __half, N> input,
18+
kf::vec_ptr<const half, N> input,
1919
double constant,
20-
kf::vec_ptr<float, N> output) {
20+
kf::vec_ptr<half, N, float> output) {
2121
int i = blockIdx.x * blockDim.x + threadIdx.x;
2222

2323
if (i * N < length) {
@@ -53,9 +53,9 @@ void run_kernel(int n) {
5353
int grid_size = (n + items_per_block - 1) / items_per_block;
5454
my_kernel<items_per_thread><<<grid_size, block_size>>>(
5555
n,
56-
kf::vector_ptr<const half, items_per_thread>(input_dev),
56+
kf::assert_aligned(input_dev),
5757
constant,
58-
kf::vector_ptr<float, items_per_thread>(output_dev));
58+
kf::assert_aligned(output_dev));
5959

6060
// Copy results back
6161
cuda_check(cudaMemcpy(output_dev, output_result.data(), sizeof(float) * n, cudaMemcpyDefault));

include/kernel_float/memory.h

Lines changed: 172 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -285,34 +285,67 @@ KERNEL_FLOAT_INLINE void write_aligned(T* ptr, const V& values) {
285285
convert_storage<T, N>(values).data());
286286
}
287287

288+
/**
289+
* @brief A reference wrapper that allows reading/writing a vector of type `T`and length `N` with optional data
290+
* conversion.
291+
*
292+
* @tparam T The type of the elements as seen from the user's perspective.
293+
* @tparam N The number of elements in the vector.
294+
* @tparam U The underlying storage type. Defaults to the same type as T.
295+
* @tparam Align The alignment constraint for read and write operations.
296+
*/
288297
template<typename T, size_t N, typename U = T, size_t Align = 1>
289298
struct vector_ref {
290299
using pointer_type = U*;
291300
using value_type = decay_t<T>;
292301
using vector_type = vector<value_type, extent<N>>;
293302

303+
/**
304+
* Constructs a vector_ref to manage access to a raw data pointer.
305+
*
306+
* @param data Pointer to the raw data this vector_ref will manage.
307+
*/
294308
KERNEL_FLOAT_INLINE explicit vector_ref(pointer_type data) : data_(data) {}
295309

310+
/**
311+
* Reads data from the underlying raw pointer, converting it to type `T`.
312+
*
313+
* @return vector_type A vector of type vector_type containing the read and converted data.
314+
*/
296315
KERNEL_FLOAT_INLINE vector_type read() const {
297316
return convert<value_type, N>(read_aligned<Align, N>(data_));
298317
}
299318

319+
/**
320+
* Writes data to the underlying raw pointer, converting it from the input vector if necessary.
321+
*
322+
* @tparam V The type of the input vector, defaults to `T`.
323+
* @param values The values to be written.
324+
*/
300325
template<typename V = vector_type>
301326
KERNEL_FLOAT_INLINE void write(const V& values) const {
302-
U* x = data_;
303-
write_aligned<Align>(x, convert<U, N>(values));
327+
write_aligned<Align>(data_, convert<U, N>(values));
304328
}
305329

330+
/**
331+
* Conversion operator that is shorthand for `read()`.
332+
*/
306333
KERNEL_FLOAT_INLINE operator vector_type() const {
307334
return read();
308335
}
309336

337+
/**
338+
* Assignment operator that is shorthand for `write(values)`.
339+
*/
310340
template<typename V>
311341
KERNEL_FLOAT_INLINE vector_ref operator=(const V& values) const {
312342
write(values);
313343
return *this;
314344
}
315345

346+
/**
347+
* Gets the raw data pointer managed by this vector_ref
348+
*/
316349
KERNEL_FLOAT_INLINE pointer_type get() const {
317350
return data_;
318351
}
@@ -321,6 +354,9 @@ struct vector_ref {
321354
pointer_type data_ = nullptr;
322355
};
323356

357+
/**
358+
* Specialization for `vector_ref` if the backing storage is const.
359+
*/
324360
template<typename T, size_t N, typename U, size_t Align>
325361
struct vector_ref<T, N, const U, Align> {
326362
using pointer_type = const U*;
@@ -351,43 +387,103 @@ struct vector_ref<T, N, const U, Align> {
351387
vector_ref<T, N, U, Align> ptr, \
352388
const V& value) { \
353389
ptr.write(ptr.read() OP value); \
390+
return ptr; \
354391
}
355392

356393
KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(+, +=)
357394
KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(-, -=)
358395
KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(*, *=)
359396
KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(/, /=)
360397

361-
template<typename T, size_t Align, typename U = T>
398+
/**
399+
* A wrapper for a pointer that enables vectorized access and supports type conversions..
400+
*
401+
* The `vector_ptr<T, N, U>` type is designed to function as if its a `vec<T, N>*` pointer, allowing of reading and
402+
* writing `vec<T, N>` elements. However, the actual type of underlying storage is a pointer of type `U*`, where
403+
* automatic conversion is performed between `T` and `U` when reading/writing items.
404+
*
405+
* For example, a `vector_ptr<double, N, half>` is useful where the data is stored in low precision (here 16 bit)
406+
* but it should be accessed as if it was in a higher precision format (here 64 bit).
407+
*
408+
* @tparam T The type of the elements as viewed by the user.
409+
* @tparam N The alignment of T in number of elements.
410+
* @tparam U The underlying storage type, defaults to T.
411+
*/
412+
template<typename T, size_t N, typename U = T>
362413
struct vector_ptr {
363414
using pointer_type = U*;
364415
using value_type = decay_t<T>;
365416

417+
/**
418+
* Default constructor sets the pointer to `NULL`.
419+
*/
366420
vector_ptr() = default;
421+
422+
/**
423+
* Constructor from a given pointer. It is up to the user to assert that the pointer is aligned to `Align` elements.
424+
*/
367425
KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {}
368426

369-
template<typename T2>
370-
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, Align, U> p) : data_(p.get()) {}
427+
/**
428+
* Constructs a vector_ptr from another vector_ptr with potentially different alignment and type. This constructor
429+
* only allows conversion if the alignment of the source is greater than or equal to the alignment of the target.
430+
*/
431+
template<typename T2, size_t N2>
432+
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, U> p, enable_if_t<(N2 >= N), int> = {}) :
433+
data_(p.get()) {}
434+
435+
/**
436+
* Accesses a reference to a vector at a specific index with optional alignment considerations.
437+
*
438+
* @tparam N The number of elements in the vector to access, defaults to the alignment.
439+
* @param index The index at which to access the vector.
440+
*/
441+
template<size_t K = N>
442+
KERNEL_FLOAT_INLINE vector_ref<T, K, U, N> at(size_t index) const {
443+
return vector_ref<T, K, U, N> {data_ + index * N};
444+
}
371445

372-
template<size_t N = Align>
373-
KERNEL_FLOAT_INLINE vector_ref<T, N, U, Align> at(size_t index) const {
374-
return vector_ref<T, N, U, Align> {data_ + index * Align};
446+
/**
447+
* Accesses a vector at a specific index.
448+
*
449+
* @tparam K The number of elements to read, defaults to `N`.
450+
* @param index The index from which to read the data.
451+
*/
452+
template<size_t K = N>
453+
KERNEL_FLOAT_INLINE vector<value_type, extent<K>> read(size_t index) const {
454+
return this->template at<K>(index).read();
375455
}
376456

377-
KERNEL_FLOAT_INLINE vector<value_type, extent<Align>> operator[](size_t index) const {
378-
return this->template at<Align>(index).read();
457+
/**
458+
* Shorthand for `read(index)`.
459+
*/
460+
KERNEL_FLOAT_INLINE vector<value_type, extent<N>> operator[](size_t index) const {
461+
return read(index);
379462
}
380463

381-
template<size_t N = Align>
382-
KERNEL_FLOAT_INLINE vector<value_type, extent<N>> read(size_t index = 0) const {
383-
return this->template at<N>(index).read();
464+
/**
465+
* Shorthand for `read(0)`.
466+
*/
467+
KERNEL_FLOAT_INLINE vector<value_type, extent<N>> operator*() const {
468+
return read(0);
384469
}
385470

386-
template<size_t N = Align, typename V>
471+
/**
472+
* @brief Writes data to a specific index.
473+
*
474+
* @tparam K The number of elements to write, defaults to `N`.
475+
* @tparam V The type of the values being written.
476+
* @param index The index at which to write the data.
477+
* @param values The vector of values to write.
478+
*/
479+
template<size_t K = N, typename V>
387480
KERNEL_FLOAT_INLINE void write(size_t index, const V& values) const {
388-
this->template at<N>(index).write(values);
481+
this->template at<K>(index).write(values);
389482
}
390483

484+
/**
485+
* Gets the raw data pointer managed by this `vector_ptr`.
486+
*/
391487
KERNEL_FLOAT_INLINE pointer_type get() const {
392488
return data_;
393489
}
@@ -396,29 +492,42 @@ struct vector_ptr {
396492
pointer_type data_ = nullptr;
397493
};
398494

399-
template<typename T, size_t Align, typename U>
400-
struct vector_ptr<T, Align, const U> {
495+
/**
496+
* Specialization for `vector_ptr` if the backing storage is const.
497+
*/
498+
template<typename T, size_t N, typename U>
499+
struct vector_ptr<T, N, const U> {
401500
using pointer_type = const U*;
402501
using value_type = decay_t<T>;
403502

404503
vector_ptr() = default;
405504
KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {}
406505

407-
template<typename T2>
408-
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, Align, U> p) : data_(p.get()) {}
506+
template<typename T2, size_t N2>
507+
KERNEL_FLOAT_INLINE
508+
vector_ptr(vector_ptr<T2, N2, const U> p, enable_if_t<(N2 >= N), int> = {}) :
509+
data_(p.get()) {}
510+
511+
template<typename T2, size_t N2>
512+
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, U> p, enable_if_t<(N2 >= N), int> = {}) :
513+
data_(p.get()) {}
409514

410-
template<size_t N = Align>
411-
KERNEL_FLOAT_INLINE vector_ref<T, N, const U, Align> at(size_t index) const {
412-
return vector_ref<T, N, const U, Align> {data_ + index * Align};
515+
template<size_t K = N>
516+
KERNEL_FLOAT_INLINE vector_ref<T, K, const U, N> at(size_t index) const {
517+
return vector_ref<T, K, const U, N> {data_ + index * N};
413518
}
414519

415-
KERNEL_FLOAT_INLINE vector<value_type, extent<Align>> operator[](size_t index) const {
416-
return this->template at<Align>(index).read();
520+
template<size_t K = N>
521+
KERNEL_FLOAT_INLINE vector<value_type, extent<K>> read(size_t index = 0) const {
522+
return this->template at<K>(index).read();
417523
}
418524

419-
template<size_t N = Align>
420-
KERNEL_FLOAT_INLINE vector<value_type, extent<N>> read(size_t index = 0) const {
421-
return this->template at<N>(index).read();
525+
KERNEL_FLOAT_INLINE vector<value_type, extent<N>> operator[](size_t index) const {
526+
return read(index);
527+
}
528+
529+
KERNEL_FLOAT_INLINE vector<value_type, extent<N>> operator*() const {
530+
return read(0);
422531
}
423532

424533
KERNEL_FLOAT_INLINE pointer_type get() const {
@@ -429,6 +538,42 @@ struct vector_ptr<T, Align, const U> {
429538
pointer_type data_ = nullptr;
430539
};
431540

541+
template<typename T, size_t N, typename U>
542+
KERNEL_FLOAT_INLINE vector_ptr<T, N, U> operator+(vector_ptr<T, N, U> p, size_t i) {
543+
return vector_ptr<T, N, U> {p.get() + i * N};
544+
}
545+
546+
template<typename T, size_t N, typename U>
547+
KERNEL_FLOAT_INLINE vector_ptr<T, N, U> operator+(size_t i, vector_ptr<T, N, U> p) {
548+
return p + i;
549+
}
550+
551+
/**
552+
* Creates a `vector_ptr<T, N>` from a raw pointer `U*` by asserting a specific alignment `N`.
553+
*
554+
* @tparam T The type of the elements as viewed by the user. This type may differ from `U`.
555+
* @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT.
556+
* @tparam U The type of the elements pointed to by the raw pointer.
557+
*/
558+
template<typename T, size_t N = KERNEL_FLOAT_MAX_ALIGNMENT, typename U>
559+
KERNEL_FLOAT_INLINE vector_ptr<T, N, U> assert_aligned(U* ptr) {
560+
return vector_ptr<T, N, U> {ptr};
561+
}
562+
563+
// Doxygen cannot deal with the `assert_aligned` being defined twice, we ignore the second definition.
564+
/// @cond IGNORE
565+
/**
566+
* Creates a `vector_ptr<T, N>` from a raw pointer `T*` by asserting a specific alignment `N`.
567+
*
568+
* @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT.
569+
* @tparam T The type of the elements pointed to by the raw pointer.
570+
*/
571+
template<size_t N = KERNEL_FLOAT_MAX_ALIGNMENT, typename T>
572+
KERNEL_FLOAT_INLINE vector_ptr<T, N> assert_aligned(T* ptr) {
573+
return vector_ptr<T, N> {ptr};
574+
}
575+
/// @endcond
576+
432577
template<typename T, size_t N = 1, typename U = T>
433578
using vec_ptr = vector_ptr<T, N, U>;
434579

0 commit comments

Comments
 (0)