Skip to content

Commit 3f3edaa

Browse files
committed
Rewrite magnitude_impl and dot_impl to take direct pointers instead of vector_storage
1 parent da0a46b commit 3f3edaa

File tree

4 files changed

+76
-98
lines changed

4 files changed

+76
-98
lines changed

include/kernel_float/bf16.h

Lines changed: 11 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -259,20 +259,16 @@ namespace detail {
259259
template<>
260260
struct dot_impl<__nv_bfloat16, 0> {
261261
KERNEL_FLOAT_INLINE
262-
static __nv_bfloat16 call(
263-
const vector_storage<__nv_bfloat16, 0>& left,
264-
const vector_storage<__nv_bfloat16, 0>& right) {
262+
static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) {
265263
return __nv_bfloat16(0);
266264
}
267265
};
268266

269267
template<>
270268
struct dot_impl<__nv_bfloat16, 1> {
271269
KERNEL_FLOAT_INLINE
272-
static __nv_bfloat16 call(
273-
const vector_storage<__nv_bfloat16, 1>& left,
274-
const vector_storage<__nv_bfloat16, 1>& right) {
275-
return __hmul(left.data()[0], right.data()[0]);
270+
static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) {
271+
return __hmul(left[0], right[0]);
276272
}
277273
};
278274

@@ -281,25 +277,23 @@ struct dot_impl<__nv_bfloat16, N> {
281277
static_assert(N >= 2, "internal error");
282278

283279
KERNEL_FLOAT_INLINE
284-
static __nv_bfloat16 call(
285-
const vector_storage<__nv_bfloat16, N>& left,
286-
const vector_storage<__nv_bfloat16, N>& right) {
287-
__nv_bfloat162 first_a = {left.data()[0], left.data()[1]};
288-
__nv_bfloat162 first_b = {right.data()[0], right.data()[1]};
280+
static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) {
281+
__nv_bfloat162 first_a = {left[0], left[1]};
282+
__nv_bfloat162 first_b = {right[0], right[1]};
289283
__nv_bfloat162 accum = __hmul2(first_a, first_b);
290284

291285
#pragma unroll
292-
for (size_t i = 2; i + 2 <= N; i += 2) {
293-
__nv_bfloat162 a = {left.data()[i], left.data()[i + 1]};
294-
__nv_bfloat162 b = {right.data()[i], right.data()[i + 1]};
286+
for (size_t i = 2; i + 1 < N; i += 2) {
287+
__nv_bfloat162 a = {left[i], left[i + 1]};
288+
__nv_bfloat162 b = {right[i], right[i + 1]};
295289
accum = __hfma2(a, b, accum);
296290
}
297291

298292
__nv_bfloat16 result = __hadd(accum.x, accum.y);
299293

300294
if (N % 2 != 0) {
301-
__nv_bfloat16 a = left.data()[N - 1];
302-
__nv_bfloat16 b = right.data()[N - 1];
295+
__nv_bfloat16 a = left[N - 1];
296+
__nv_bfloat16 b = right[N - 1];
303297
result = __hfma(a, b, result);
304298
}
305299

include/kernel_float/fp16.h

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -251,18 +251,16 @@ namespace detail {
251251
template<>
252252
struct dot_impl<__half, 0> {
253253
KERNEL_FLOAT_INLINE
254-
static __half
255-
call(const vector_storage<__half, 0>& left, const vector_storage<__half, 0>& right) {
254+
static __half call(const __half* left, const __half* right) {
256255
return __half(0);
257256
}
258257
};
259258

260259
template<>
261260
struct dot_impl<__half, 1> {
262261
KERNEL_FLOAT_INLINE
263-
static __half
264-
call(const vector_storage<__half, 1>& left, const vector_storage<__half, 1>& right) {
265-
return __hmul(left.data()[0], right.data()[0]);
262+
static __half call(const __half* left, const __half* right) {
263+
return __hmul(left[0], right[0]);
266264
}
267265
};
268266

@@ -271,24 +269,23 @@ struct dot_impl<__half, N> {
271269
static_assert(N >= 2, "internal error");
272270

273271
KERNEL_FLOAT_INLINE
274-
static __half
275-
call(const vector_storage<__half, N>& left, const vector_storage<__half, N>& right) {
276-
__half2 first_a = {left.data()[0], left.data()[1]};
277-
__half2 first_b = {right.data()[0], right.data()[1]};
272+
static __half call(const __half* left, const __half* right) {
273+
__half2 first_a = {left[0], left[1]};
274+
__half2 first_b = {right[0], right[1]};
278275
__half2 accum = __hmul2(first_a, first_b);
279276

280277
#pragma unroll
281278
for (size_t i = 2; i + 2 <= N; i += 2) {
282-
__half2 a = {left.data()[i], left.data()[i + 1]};
283-
__half2 b = {right.data()[i], right.data()[i + 1]};
279+
__half2 a = {left[i], left[i + 1]};
280+
__half2 b = {right[i], right[i + 1]};
284281
accum = __hfma2(a, b, accum);
285282
}
286283

287284
__half result = __hadd(accum.x, accum.y);
288285

289286
if (N % 2 != 0) {
290-
__half a = left.data()[N - 1];
291-
__half b = right.data()[N - 1];
287+
__half a = left[N - 1];
288+
__half b = right[N - 1];
292289
result = __hfma(a, b, result);
293290
}
294291

include/kernel_float/reduce.h

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,13 @@ namespace detail {
143143
template<typename T, size_t N>
144144
struct dot_impl {
145145
KERNEL_FLOAT_INLINE
146-
static T call(const vector_storage<T, N>& left, const vector_storage<T, N>& right) {
146+
static T call(const T* left, const T* right) {
147147
vector_storage<T, N> intermediate;
148148
detail::apply_impl<ops::multiply<T>, N, T, T, T>::call(
149149
ops::multiply<T>(),
150150
intermediate.data(),
151-
left.data(),
152-
right.data());
151+
left,
152+
right);
153153

154154
return detail::reduce_impl<ops::add<T>, N, T>::call(ops::add<T>(), intermediate.data());
155155
}
@@ -171,58 +171,56 @@ template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>
171171
KERNEL_FLOAT_INLINE T dot(const L& left, const R& right) {
172172
using E = broadcast_vector_extent_type<L, R>;
173173
return detail::dot_impl<T, E::value>::call(
174-
convert_storage<T>(left, E {}),
175-
convert_storage<T>(right, E {}));
174+
convert_storage<T>(left, E {}).data(),
175+
convert_storage<T>(right, E {}).data());
176176
}
177177

178178
namespace detail {
179179
template<typename T, size_t N>
180180
struct magnitude_impl {
181181
KERNEL_FLOAT_INLINE
182-
static T call(const vector_storage<T, N>& input) {
182+
static T call(const T* input) {
183183
return ops::sqrt<T> {}(detail::dot_impl<T, N>::call(input, input));
184184
}
185185
};
186186

187187
template<typename T>
188188
struct magnitude_impl<T, 0> {
189189
KERNEL_FLOAT_INLINE
190-
static T call(const vector_storage<T, 0>& input) {
190+
static T call(const T* input) {
191191
return T {};
192192
}
193193
};
194194

195195
template<typename T>
196196
struct magnitude_impl<T, 1> {
197197
KERNEL_FLOAT_INLINE
198-
static T call(const vector_storage<T, 1>& input) {
199-
return ops::abs<T> {}(input);
198+
static T call(const T* input) {
199+
return ops::abs<T> {}(input[0]);
200200
}
201201
};
202202

203203
template<typename T>
204204
struct magnitude_impl<T, 2> {
205205
KERNEL_FLOAT_INLINE
206-
static T call(const vector_storage<T, 2>& input) {
207-
return ops::hypot<T>()(input.data()[0], input.data()[1]);
206+
static T call(const T* input) {
207+
return ops::hypot<T>()(input[0], input[1]);
208208
}
209209
};
210210

211211
// The 3-argument overload of hypot is only available on host from C++17
212212
#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST
213213
template<>
214214
struct magnitude_impl<float, 3> {
215-
KERNEL_FLOAT_INLINE
216-
static float call(const vector_storage<float, 3>& input) {
217-
return ::hypot(input.data()[0], input.data()[1], input.data()[2]);
215+
static float call(const float* input) {
216+
return ::hypot(input[0], input[1], input[2]);
218217
}
219218
};
220219

221220
template<>
222221
struct magnitude_impl<double, 3> {
223-
KERNEL_FLOAT_INLINE
224-
static float call(const vector_storage<double, 3>& input) {
225-
return ::hypot(input.data()[0], input.data()[1], input.data()[2]);
222+
static double call(const double* input) {
223+
return ::hypot(input[0], input[1], input[2]);
226224
}
227225
};
228226
#endif
@@ -242,7 +240,7 @@ struct magnitude_impl<double, 3> {
242240
*/
243241
template<typename V, typename T = vector_value_type<V>>
244242
KERNEL_FLOAT_INLINE T mag(const V& input) {
245-
return detail::magnitude_impl<T, vector_extent<V>>::call(into_vector_storage(input));
243+
return detail::magnitude_impl<T, vector_extent<V>>::call(into_vector_storage(input).data());
246244
}
247245
} // namespace kernel_float
248246

0 commit comments

Comments (0)