Skip to content

Commit c0939d0

Browse files
committed
Fix incorrect definition of FMA
1 parent 07af0ad commit c0939d0

File tree

3 files changed

+168
-1
lines changed

3 files changed

+168
-1
lines changed

include/kernel_float/triops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ namespace ops {
9292
// Functor for the ternary multiply-add operation: returns a * b + c.
// This hunk replaces the previous expression (the "95-" line, a + b * c) with
// the corrected operand order (the "95+" line), per the commit title
// "Fix incorrect definition of FMA".
// The body is written as a plain multiply followed by an add on T.
template<typename T>
9393
struct fma {
9494
KERNEL_FLOAT_INLINE T operator()(T a, T b, T c) {
95-
return a + b * c;
95+
return a * b + c;
9696
}
9797
};
9898

tests/constant.cu

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include "common.h"
2+
3+
// Test functor for the ternary operations kf::where (3-, 2- and 1-argument
// forms) and kf::fma on kf::vec<T, N>. Each result is compared lane by lane
// against a scalar reference expression built from the same inputs.
struct triops_tests {
4+
// T is the element type under test; I... is the lane index pack and
// N = sizeof...(I) is the vector length.
template<typename T, size_t... I, size_t N = sizeof...(I)>
5+
__host__ __device__ void operator()(generator<T> gen, std::index_sequence<I...>) {
6+
// Scalar inputs, one per lane, drawn from gen.next(I).
// NOTE(review): x, y and z use the same call; presumably the generator is
// stateful so the three arrays differ -- confirm in common.h.
T x[N] = {gen.next(I)...};
7+
T y[N] = {gen.next(I)...};
8+
T z[N] = {gen.next(I)...};
9+
10+
// Vector operands built from the scalar inputs above.
kf::vec<T, N> a = {x[I]...};
11+
kf::vec<T, N> b = {y[I]...};
12+
kf::vec<T, N> c = {z[I]...};
13+
14+
// where(cond, then, else): expected to select y where x converts to true, else z.
kf::vec<T, N> answer = kf::where(a, b, c);
15+
ASSERT_EQ_ALL(answer[I], bool(x[I]) ? y[I] : z[I]);
16+
17+
// where(cond, then): the false lanes are expected to be value-initialized T().
answer = kf::where(a, b);
18+
ASSERT_EQ_ALL(answer[I], bool(x[I]) ? y[I] : T());
19+
20+
// where(cond): expected to yield T(true) / T(false) per lane.
answer = kf::where(a);
21+
ASSERT_EQ_ALL(answer[I], T(bool(x[I])));
22+
23+
// fma(a, b, c) must equal a * b + c (the operand order fixed by this commit).
answer = kf::fma(a, b, c);
24+
ASSERT_EQ_ALL(answer[I], x[I] * y[I] + z[I]);
25+
26+
}
27+
};
28+
29+
// Register the functor under "ternary operators" for the listed element types.
// NOTE(review): presumably the _GPU variant runs the listed types on the
// device only -- confirm against the macro definitions in common.h.
REGISTER_TEST_CASE("ternary operators", triops_tests, int, float, double)
30+
REGISTER_TEST_CASE_GPU("ternary operators", triops_tests, __half, __nv_bfloat16)

tests/triops.cu

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#include "common.h"
2+
3+
// Test functor for the binary operators on kf::vec<T, N>: arithmetic (+, -, *),
// comparisons (<, >, <=, >=, ==, !=) and compound assignment (+=, -=, *=).
// Each vector result is checked lane by lane against the scalar expression
// cast to T, using the `equals` helper (declared outside this file view).
struct binops_tests {
4+
// T is the element type under test; I... is the lane index pack and
// N = sizeof...(I) is the vector length.
template<typename T, size_t... I, size_t N = sizeof...(I)>
5+
__host__ __device__ void operator()(generator<T> gen, std::index_sequence<I...>) {
6+
// Scalar inputs, one per lane, drawn from gen.next(I).
T x[N] = {gen.next(I)...};
7+
T y[N] = {gen.next(I)...};
8+
9+
// Vector operands built from the scalar inputs; c receives each result.
kf::vec<T, N> a = {x[I]...};
10+
kf::vec<T, N> b = {y[I]...};
11+
kf::vec<T, N> c;
12+
13+
// Arithmetic: each ASSERT folds the per-lane checks together with &&.
14+
c = a + b;
15+
ASSERT(equals(T(x[I] + y[I]), c[I]) && ...);
16+
17+
c = a - b;
18+
ASSERT(equals(T(x[I] - y[I]), c[I]) && ...);
19+
20+
c = a * b;
21+
ASSERT(equals(T(x[I] * y[I]), c[I]) && ...);
22+
23+
// Disabled: division results in division by zero with these inputs.
24+
// c = a / b;
25+
// ASSERT(equals(T(x[I] / y[I]), c[I]) && ...);
26+
27+
// Disabled: remainder results in division by zero with these inputs.
28+
// c = a % b;
29+
// ASSERT(equals(T(x[I] % y[I]), c[I]) && ...);
30+
31+
// Comparison: the result is stored in a kf::vec<T, N>, so each lane is
// expected to equal the boolean outcome converted to T.
32+
c = a < b;
33+
ASSERT(equals(T(x[I] < y[I]), c[I]) && ...);
34+
35+
c = a > b;
36+
ASSERT(equals(T(x[I] > y[I]), c[I]) && ...);
37+
38+
c = a <= b;
39+
ASSERT(equals(T(x[I] <= y[I]), c[I]) && ...);
40+
41+
c = a >= b;
42+
ASSERT(equals(T(x[I] >= y[I]), c[I]) && ...);
43+
44+
c = a == b;
45+
ASSERT(equals(T(x[I] == y[I]), c[I]) && ...);
46+
47+
c = a != b;
48+
ASSERT(equals(T(x[I] != y[I]), c[I]) && ...);
49+
50+
// Compound assignment: c is reset to a before each operator so every check
// starts from the same left-hand side.
51+
c = a;
52+
c += b;
53+
ASSERT(equals(T(x[I] + y[I]), c[I]) && ...);
54+
55+
c = a;
56+
c -= b;
57+
ASSERT(equals(T(x[I] - y[I]), c[I]) && ...);
58+
59+
c = a;
60+
c *= b;
61+
ASSERT(equals(T(x[I] * y[I]), c[I]) && ...);
62+
}
63+
};
64+
65+
// Register the functor under "binary operators" for the listed element types.
// NOTE(review): presumably the _GPU variant runs the listed types on the
// device only -- confirm against the macro definitions in common.h.
REGISTER_TEST_CASE("binary operators", binops_tests, bool, int, float, double)
66+
REGISTER_TEST_CASE_GPU("binary operators", binops_tests, __half, __nv_bfloat16)
67+
68+
// Test functor for binary operators that are only exercised on floating-point
// element types: currently just division. The remainder check is present but
// commented out.
struct binops_float_tests {
69+
// T is the element type under test; I... is the lane index pack and
// N = sizeof...(I) is the vector length.
template<typename T, size_t... I, size_t N = sizeof...(I)>
70+
__host__ __device__ void operator()(generator<T> gen, std::index_sequence<I...>) {
71+
// Scalar inputs, one per lane, drawn from gen.next(I).
T x[N] = {gen.next(I)...};
72+
T y[N] = {gen.next(I)...};
73+
74+
// Vector operands built from the scalar inputs; c receives the result.
kf::vec<T, N> a = {x[I]...};
75+
kf::vec<T, N> b = {y[I]...};
76+
kf::vec<T, N> c;
77+
78+
// Division: each lane must match the scalar quotient cast to T.
c = a / b;
79+
ASSERT(equals(T(x[I] / y[I]), c[I]) && ...);
80+
81+
// Remainder is not supported for fp16 (__half / __nv_bfloat16), hence the
// type guard. NOTE(review): the guarded body is entirely commented out, so
// this branch currently checks nothing.
82+
if constexpr (is_none_of<T, __half, __nv_bfloat16>) {
83+
// c = a % b;
84+
// ASSERT(equals(T(fmod(x[I], y[I])), c[I]) && ...);
85+
}
86+
}
87+
};
88+
89+
// Register the functor under "binary float operators" for the listed types.
// NOTE(review): presumably the _GPU variant runs the listed types on the
// device only -- confirm against the macro definitions in common.h.
REGISTER_TEST_CASE("binary float operators", binops_float_tests, float, double)
90+
REGISTER_TEST_CASE_GPU("binary float operators", binops_float_tests, __half, __nv_bfloat16)
91+
92+
// Test functor for the element-wise min/max functions on kf::vec<T, N>.
// The reference value depends on T: fmin/fmax for double, fminf/fmaxf for
// float, __hmin/__hmax for the 16-bit float types, and a plain `<` comparison
// for every other type.
struct minmax_tests {
93+
// T is the element type under test; I... is the lane index pack and
// N = sizeof...(I) is the vector length.
template<typename T, size_t... I, size_t N = sizeof...(I)>
94+
__host__ __device__ void operator()(generator<T> gen, std::index_sequence<I...>) {
95+
// Scalar inputs, one per lane, drawn from gen.next(I).
T x[N] = {gen.next(I)...};
96+
T y[N] = {gen.next(I)...};
97+
98+
// Vector operands built from the scalar inputs.
kf::vec<T, N> a = {x[I]...};
99+
kf::vec<T, N> b = {y[I]...};
100+
101+
// Results under test; min/max are called unqualified here.
kf::vec<T, N> lo = min(a, b);
102+
kf::vec<T, N> hi = max(a, b);
103+
104+
// Pick the scalar reference that matches T. The floating-point branches
// read the lanes back from the vectors (a[I], b[I]); the fallback branch
// uses the raw scalar inputs (x[I], y[I]).
if constexpr (is_one_of<T, double>) {
105+
ASSERT(equals(fmin(a[I], b[I]), lo[I]) && ...);
106+
ASSERT(equals(fmax(a[I], b[I]), hi[I]) && ...);
107+
} else if constexpr (is_one_of<T, float>) {
108+
ASSERT(equals(fminf(a[I], b[I]), lo[I]) && ...);
109+
ASSERT(equals(fmaxf(a[I], b[I]), hi[I]) && ...);
110+
} else if constexpr (is_one_of<T, __half, __nv_bfloat16>) {
111+
ASSERT(equals(__hmin(a[I], b[I]), lo[I]) && ...);
112+
ASSERT(equals(__hmax(a[I], b[I]), hi[I]) && ...);
113+
} else {
114+
// Remaining types: reference is a plain comparison; on ties
// (x == y) both expressions yield y[I].
ASSERT(equals(x[I] < y[I] ? x[I] : y[I], lo[I]) && ...);
115+
ASSERT(equals(x[I] < y[I] ? y[I] : x[I], hi[I]) && ...);
116+
}
117+
}
118+
};
119+
120+
// Register the functor under "min/max functions" for the listed element types.
// NOTE(review): presumably the _GPU variant runs the listed types on the
// device only -- confirm against the macro definitions in common.h.
REGISTER_TEST_CASE("min/max functions", minmax_tests, bool, int, float, double)
121+
REGISTER_TEST_CASE_GPU("min/max functions", minmax_tests, __half, __nv_bfloat16)
122+
123+
// Test functor for cross() on 3-element vectors, using fixed inputs:
// (1, 2, 3) x (4, 5, 6) = (2*6 - 3*5, 3*4 - 1*6, 1*5 - 2*4) = (-3, 6, -3).
struct cross_test {
124+
template<typename T>
125+
// The generator argument is accepted but not used; the inputs are constants.
__host__ __device__ void operator()(generator<T> gen) {
126+
kf::vec<T, 3> a = {1, 2, 3};
127+
kf::vec<T, 3> b = {4, 5, 6};
128+
kf::vec<T, 3> c = cross(a, b);
129+
130+
// Each component is compared with operator== against the expected value.
ASSERT(c[0] == T(-3));
131+
ASSERT(c[1] == T(6));
132+
ASSERT(c[2] == T(-3));
133+
}
134+
};
135+
136+
// Register the functor under "cross product" for the listed element types.
// NOTE(review): presumably the _GPU variant runs the listed types on the
// device only -- confirm against the macro definitions in common.h.
REGISTER_TEST_CASE("cross product", cross_test, float, double)
137+
REGISTER_TEST_CASE_GPU("cross product", cross_test, __half, __nv_bfloat16)

0 commit comments

Comments
 (0)