Skip to content

Commit b372b29

Browse files
committed
Add cast_to function
1 parent f2b1f23 commit b372b29

File tree

4 files changed

+58
-1
lines changed

4 files changed

+58
-1
lines changed

docs/build_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def build_index_page(groups):
199199
("where", "where(const C&)"),
200200
],
201201
"Memory read/write": [
202+
"cast_to",
202203
("load", "load(const T*, const I&)"),
203204
("load", "load(const T*, const I&, const M&)"),
204205
("loadn", "loadn(const T*, ptrdiff_t)"),

include/kernel_float/memory.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,47 @@ KERNEL_FLOAT_INLINE void storen(const V& values, T* ptr, ptrdiff_t offset, ptrdi
223223
#define KERNEL_FLOAT_ASSUME_ALIGNED(ptr, alignment) (ptr)
224224
#endif
225225

226+
template<typename T, size_t N>
227+
struct AssignConversionProxy {
228+
KERNEL_FLOAT_INLINE
229+
explicit AssignConversionProxy(T* ptr) : ptr_(ptr) {}
230+
231+
template<typename U>
232+
KERNEL_FLOAT_INLINE AssignConversionProxy& operator=(U&& values) {
233+
auto indices = range<ptrdiff_t, N>();
234+
detail::store_impl<T, N>::call(
235+
ptr_,
236+
convert_storage<T, N>(std::forward<U>(values)).data(),
237+
indices.data());
238+
239+
return *this;
240+
}
241+
242+
private:
243+
T* ptr_;
244+
};
245+
246+
/**
247+
* Takes a reference to a vector and returns a special proxy object that automatically performs the correct conversion
248+
* when a vector of a different element type is assigned. This is useful to perform implicit type conversions.
249+
*
250+
* For example, let us assume that a line like `x = expression;` would not compile since `x` and `expression` are
251+
* vectors of different element types. Then it is possible to use `cast_to(x) = expression;` to fix this error,
252+
* which possibly introduces a type conversion.
253+
*
254+
* Example
255+
* =======
256+
* ```
257+
* vec<float, 2> x;
258+
* vec<double, 2> y = {1.0, 2.0};
259+
* cast_to(x) = y;  // normally, the line `x = y;` would not compile, but `cast_to` makes this possible
260+
* ```
261+
*/
262+
template<typename T, typename E>
263+
KERNEL_FLOAT_INLINE AssignConversionProxy<T, E::value> cast_to(vector<T, E>& input) {
264+
return AssignConversionProxy<T, E::value>(input.data());
265+
}
266+
226267
/**
227268
* Represents a pointer of type ``T*`` that is guaranteed to be aligned to ``alignment`` bytes.
228269
*/

include/kernel_float/tiling.h

Whitespace-only changes.

tests/memory.cu

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,19 @@ struct store_test {
127127
};
128128

129129
REGISTER_TEST_CASE("store", store_test, int, float, double)
130-
REGISTER_TEST_CASE_GPU("store", store_test, __half, __nv_bfloat16)
130+
REGISTER_TEST_CASE_GPU("store", store_test, __half, __nv_bfloat16)
131+
132+
// Test functor for `kf::cast_to`: builds a `kf::vec<T, N>` from the generator,
// assigns it to a `kf::vec<float, N>` through `kf::cast_to`, and asserts that
// every element of the result equals the source element converted to `float`.
// `N` is taken from the length of the index sequence passed to `operator()`.
struct assign_conversion_test {
133+
template<typename T, size_t... I, size_t N = sizeof...(I)>
134+
__host__ __device__ void operator()(generator<T> gen, std::index_sequence<I...>) {
135+
kf::vec<T, N> x = {gen.next(I)...};
136+
kf::vec<float, N> y;
137+
138+
kf::cast_to(y) = x;
139+
140+
ASSERT_EQ_ALL(float(x[I]), y[I]);
141+
}
142+
};
143+
144+
REGISTER_TEST_CASE("assign conversion", assign_conversion_test, int, float, double)
145+
REGISTER_TEST_CASE_GPU("assign conversion", assign_conversion_test, __half, __nv_bfloat16)

0 commit comments

Comments
 (0)