Skip to content

Commit 820cd7a

Browse files
committed
Add fma as member function to vector
1 parent c0c7d10 commit 820cd7a

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

include/kernel_float/triops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ struct fma<double> {
114114
} // namespace ops
115115

116116
/**
117-
* Computes the result of `a * b + c`. This is done in a single operation if possible.
117+
* Computes the result of `a * b + c`. This is done in a single operation if the given vector type supports it.
118118
*/
119119
template<
120120
typename A,

include/kernel_float/vector.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,21 @@ struct vector: public S {
276276
KERNEL_FLOAT_INLINE void for_each(F fun) const {
277277
return kernel_float::for_each(*this, std::move(fun));
278278
}
279+
280+
/**
281+
* Returns the result of `*this + lhs * rhs`.
282+
*
283+
* The operation is performed using a single `kernel_float::fma` call, which may be faster than performing
284+
* the addition and multiplication separately.
285+
*/
286+
template<
287+
typename L,
288+
typename R,
289+
typename T2 = promote_t<T, vector_value_type<L>, vector_value_type<R>>,
290+
typename E2 = broadcast_extent<E, vector_extent_type<L>, vector_extent_type<R>>>
291+
KERNEL_FLOAT_INLINE vector<T2, E2> fma(const L& lhs, const R& rhs) const {
292+
return ::kernel_float::fma(lhs, rhs, *this);
293+
}
279294
};
280295

281296
/**

single_include/kernel_float.h

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2024-05-17 10:55:41.948281
20-
// git hash: 41246ab6db9fcc24639342c439e606ba143ee346
19+
// date: 2024-05-17 11:44:08.292272
20+
// git hash: c0c7d100e3ee5bc187211e3d76b1fccc73c2fa5e
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -1890,6 +1890,7 @@ struct apply_fastmath_impl<ops::divide<T>, N, T, T, T> {
18901890
call(ops::divide<T> fun, T* result, const T* lhs, const T* rhs) {
18911891
T rhs_rcp[N];
18921892

1893+
// A fast way to perform division is to multiply by the reciprocal
18931894
apply_fastmath_impl<ops::rcp<T>, N, T, T, T>::call({}, rhs_rcp, rhs);
18941895
apply_fastmath_impl<ops::multiply<T>, N, T, T, T>::call({}, result, lhs, rhs_rcp);
18951896
}
@@ -3430,7 +3431,7 @@ struct fma<double> {
34303431
} // namespace ops
34313432

34323433
/**
3433-
* Computes the result of `a * b + c`. This is done in a single operation if possible.
3434+
* Computes the result of `a * b + c`. This is done in a single operation if the given vector type supports it.
34343435
*/
34353436
template<
34363437
typename A,
@@ -3739,6 +3740,21 @@ struct vector: public S {
37393740
KERNEL_FLOAT_INLINE void for_each(F fun) const {
37403741
return kernel_float::for_each(*this, std::move(fun));
37413742
}
3743+
3744+
/**
3745+
* Returns the result of `*this + lhs * rhs`.
3746+
*
3747+
* The operation is performed using a single `kernel_float::fma` call, which may be faster than performing
3748+
* the addition and multiplication separately.
3749+
*/
3750+
template<
3751+
typename L,
3752+
typename R,
3753+
typename T2 = promote_t<T, vector_value_type<L>, vector_value_type<R>>,
3754+
typename E2 = broadcast_extent<E, vector_extent_type<L>, vector_extent_type<R>>>
3755+
KERNEL_FLOAT_INLINE vector<T2, E2> fma(const L& lhs, const R& rhs) const {
3756+
return ::kernel_float::fma(lhs, rhs, *this);
3757+
}
37423758
};
37433759

37443760
/**

0 commit comments

Comments
 (0)