Skip to content

Commit aaf8645

Browse files
committed
Fix bugs in aligned_ptr
1 parent 4b68f09 commit aaf8645

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

include/kernel_float/memory.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,6 @@ KERNEL_FLOAT_INLINE void storen(const V& values, T* ptr, size_t offset, size_t m
215215
return store(values, ptr, indices, indices < max_length);
216216
}
217217

218-
// TODO: check if this way is supported across all compilers
219-
#if defined(__has_builtin) && __has_builtin(__builtin_assume_aligned)
220-
#define KERNEL_FLOAT_ASSUME_ALIGNED(ptr, alignment) (__builtin_assume_aligned(ptr, alignment))
221-
#else
222-
#define KERNEL_FLOAT_ASSUME_ALIGNED(ptr, alignment) (ptr)
223-
#endif
224-
225218
template<typename T, size_t N>
226219
struct AssignConversionProxy {
227220
KERNEL_FLOAT_INLINE
@@ -263,6 +256,20 @@ KERNEL_FLOAT_INLINE AssignConversionProxy<T, E::value> cast_to(vector<T, E>& inp
263256
return AssignConversionProxy<T, E::value>(input.data());
264257
}
265258

259+
/**
260+
* Returns the original pointer ``ptr`` and hints to the compiler that this pointer is aligned to ``alignment`` bytes.
261+
* If this is not actually the case, compiler optimizations will break things and generate invalid code. Be careful!
262+
*/
263+
template<typename T>
264+
KERNEL_FLOAT_INLINE T* unsafe_assume_aligned(T* ptr, size_t alignment) {
265+
// TOOD: check if this way is support across all compilers
266+
#if defined(__has_builtin) && __has_builtin(__builtin_assume_aligned)
267+
return static_cast<T*>(__builtin_assume_aligned(ptr, alignment));
268+
#else
269+
return ptr;
270+
#endif
271+
}
272+
266273
/**
267274
* Represents a pointer of type ``T*`` that is guaranteed to be aligned to ``alignment`` bytes.
268275
*/
@@ -281,7 +288,7 @@ struct aligned_ptr {
281288
*/
282289
KERNEL_FLOAT_INLINE
283290
T* get() const {
284-
return KERNEL_FLOAT_ASSUME_ALIGNED(ptr_, alignment);
291+
return unsafe_assume_aligned(ptr_, alignment);
285292
}
286293

287294
KERNEL_FLOAT_INLINE
@@ -360,12 +367,18 @@ struct aligned_ptr<const T, alignment> {
360367
KERNEL_FLOAT_INLINE
361368
explicit aligned_ptr(const T* ptr) : ptr_(ptr) {}
362369

370+
KERNEL_FLOAT_INLINE
371+
aligned_ptr(const aligned_ptr<T>& ptr) : ptr_(ptr.get()) {}
372+
373+
KERNEL_FLOAT_INLINE
374+
aligned_ptr(const aligned_ptr<const T>& ptr) : ptr_(ptr.get()) {}
375+
363376
/**
364377
* Return the pointer value.
365378
*/
366379
KERNEL_FLOAT_INLINE
367380
const T* get() const {
368-
return KERNEL_FLOAT_ASSUME_ALIGNED(ptr_, alignment);
381+
return unsafe_assume_aligned(ptr_, alignment);
369382
}
370383

371384
KERNEL_FLOAT_INLINE
@@ -406,6 +419,9 @@ struct aligned_ptr<const T, alignment> {
406419
const T* ptr_ = nullptr;
407420
};
408421

422+
// Class template argument deduction guide: constructing an `aligned_ptr` from a raw `T*` without
// explicit template arguments deduces `aligned_ptr<T>`. Only the element type is named here, so
// the alignment parameter is left at whatever default the class template declares. For a
// pointer-to-const argument, `T` is deduced as the const-qualified type, which selects the
// `aligned_ptr<const T, alignment>` specialization.
template<typename T>
423+
aligned_ptr(T*) -> aligned_ptr<T>;
424+
409425
} // namespace kernel_float
410426

411427
#endif //KERNEL_FLOAT_MEMORY_H

0 commit comments

Comments
 (0)