Skip to content

Commit 55e95ad

Browse files
committed
Add cmul_add<bfloat> & draft test
Function casts to sycl::complex<float> (no native bfloat16 support)
1 parent 847d45d commit 55e95ad

File tree

2 files changed

+85
-18
lines changed

2 files changed

+85
-18
lines changed

sycl/include/syclcompat/math.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,28 @@ inline sycl::marray<ValueT, 2> cmul_add(const sycl::marray<ValueT, 2> a,
929929
t = t * u + v;
930930
return sycl::marray<ValueT, 2>{t.real(), t.imag()};
931931
}
932+
template <>
933+
inline sycl::vec<sycl::ext::oneapi::bfloat16, 2>
934+
cmul_add(const sycl::vec<sycl::ext::oneapi::bfloat16, 2> a,
935+
const sycl::vec<sycl::ext::oneapi::bfloat16, 2> b,
936+
const sycl::vec<sycl::ext::oneapi::bfloat16, 2> c) {
937+
sycl::ext::oneapi::experimental::complex<float> t(a[0], a[1]);
938+
sycl::ext::oneapi::experimental::complex<float> u(b[0], b[1]);
939+
sycl::ext::oneapi::experimental::complex<float> v(c[0], c[1]);
940+
t = t * u + v;
941+
return sycl::vec<sycl::ext::oneapi::bfloat16, 2>{t.real(), t.imag()};
942+
}
943+
template <>
944+
inline sycl::marray<sycl::ext::oneapi::bfloat16, 2>
945+
cmul_add(const sycl::marray<sycl::ext::oneapi::bfloat16, 2> a,
946+
const sycl::marray<sycl::ext::oneapi::bfloat16, 2> b,
947+
const sycl::marray<sycl::ext::oneapi::bfloat16, 2> c) {
948+
sycl::ext::oneapi::experimental::complex<float> t(a[0], a[1]);
949+
sycl::ext::oneapi::experimental::complex<float> u(b[0], b[1]);
950+
sycl::ext::oneapi::experimental::complex<float> v(c[0], c[1]);
951+
t = t * u + v;
952+
return sycl::marray<sycl::ext::oneapi::bfloat16, 2>{t.real(), t.imag()};
953+
}
932954

933955
/// A sycl::abs wrapper functor.
934956
struct abs {

sycl/test-e2e/syclcompat/math/math_complex.cpp

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -184,19 +184,62 @@ void kernel_mul(int *result) {
184184
*result = r;
185185
}
186186

187+
// Host-side reference complex multiplication.
//
// Both operands hold a complex number as {real, imaginary}. Returns
// {a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re}, with every
// intermediate evaluated using T's own operators and the two components
// converted back to T.
template <typename T>
std::array<T, 2> complex_mul(std::array<T, 2> a, std::array<T, 2> b) {
  return {static_cast<T>((a[0] * b[0]) - (a[1] * b[1])),
          static_cast<T>((a[0] * b[1]) + (a[1] * b[0]))};
}
194+
195+
// Host-side reference complex addition.
//
// Both operands hold a complex number as {real, imaginary}; the result is the
// component-wise sum. Each sum is cast back to T explicitly so the braced
// return stays well-formed for element types whose operator+ yields a wider
// type (for example short, which promotes to int).
template <typename T>
std::array<T, 2> complex_add(std::array<T, 2> a, std::array<T, 2> b) {
  return {static_cast<T>(a[0] + b[0]), static_cast<T>(a[1] + b[1])};
}
199+
200+
// Prints a host-side reference for the multiply-add test, evaluated in T.
//
// Computes x * y + z for two fixed operand triples using complex_mul and
// complex_add, then writes each result next to the hard-coded expected value
// to std::cout. Nothing is compared or returned; the output is for manual
// inspection only.
template <typename T> void mul_add_groundtruth() {
  using arr_t = std::array<T, 2>;

  // First operand triple.
  const arr_t d1{static_cast<T>(5.4), static_cast<T>(-6.3)};
  const arr_t d2{static_cast<T>(-7.2), static_cast<T>(8.1)};
  const arr_t d3{static_cast<T>(1.0), static_cast<T>(-1.0)};

  // Second operand triple.
  const arr_t f1{static_cast<T>(1.8), static_cast<T>(-2.7)};
  const arr_t f2{static_cast<T>(-3.6), static_cast<T>(4.5)};
  const arr_t f3{static_cast<T>(1.0), static_cast<T>(-1.0)};

  const arr_t ra1 = complex_add(complex_mul(d1, d2), d3);
  const arr_t ra2 = complex_add(complex_mul(f1, f2), f3);

  // Expected {re, im} of the first result followed by {re, im} of the second,
  // converted to T the same way as the inputs.
  const T expect[4] = {static_cast<T>(13.150000), static_cast<T>(88.100000),
                       static_cast<T>(6.670001), static_cast<T>(16.820000)};

  std::cout << "r1: " << ra1[0] << ", " << ra1[1] << std::endl;
  std::cout << "Expect 1: " << expect[0] << ", " << expect[1] << std::endl;
  std::cout << "r2: " << ra2[0] << ", " << ra2[1] << std::endl;
  std::cout << "Expect 2: " << expect[2] << ", " << expect[3] << std::endl;
}
228+
229+
template <typename T>
187230
void kernel_mul_add(int *result) {
188-
sycl::double2 d1, d2, d3;
189-
sycl::float2 f1, f2, f3;
190-
sycl::marray<double, 2> m_d1, m_d2, m_d3;
191-
sycl::marray<float, 2> m_f1, m_f2, m_f3;
231+
sycl::vec<T, 2> d1, d2, d3;
232+
sycl::vec<T, 2> f1, f2, f3;
233+
sycl::marray<T, 2> m_d1, m_d2, m_d3;
234+
sycl::marray<T, 2> m_f1, m_f2, m_f3;
192235

193-
d1 = sycl::double2(5.4, -6.3);
194-
d2 = sycl::double2(-7.2, 8.1);
195-
d3 = sycl::double2(1.0, -1.0);
236+
d1 = sycl::vec<T, 2>(5.4, -6.3);
237+
d2 = sycl::vec<T, 2>(-7.2, 8.1);
238+
d3 = sycl::vec<T, 2>(1.0, -1.0);
196239

197-
f1 = sycl::float2(1.8, -2.7);
198-
f2 = sycl::float2(-3.6, 4.5);
199-
f3 = sycl::float2(1.0, -1.0);
240+
f1 = sycl::vec<T, 2>(1.8, -2.7);
241+
f2 = sycl::vec<T, 2>(-3.6, 4.5);
242+
f3 = sycl::vec<T, 2>(1.0, -1.0);
200243

201244
bool r = true;
202245
float expect[4] = {13.150000, 88.100000, 6.670001, 16.820000};
@@ -207,13 +250,13 @@ void kernel_mul_add(int *result) {
207250
auto a2 = syclcompat::cmul_add(f1, f2, f3);
208251
r = r && check(a2, expect + 2);
209252

210-
m_d1 = sycl::marray<double, 2>(5.4, -6.3);
211-
m_d2 = sycl::marray<double, 2>(-7.2, 8.1);
212-
m_d3 = sycl::marray<double, 2>(1.0, -1.0);
253+
m_d1 = sycl::marray<T, 2>(5.4, -6.3);
254+
m_d2 = sycl::marray<T, 2>(-7.2, 8.1);
255+
m_d3 = sycl::marray<T, 2>(1.0, -1.0);
213256

214-
m_f1 = sycl::marray<float, 2>(1.8, -2.7);
215-
m_f2 = sycl::marray<float, 2>(-3.6, 4.5);
216-
m_f3 = sycl::marray<float, 2>(1.0, -1.0);
257+
m_f1 = sycl::marray<T, 2>(1.8, -2.7);
258+
m_f2 = sycl::marray<T, 2>(-3.6, 4.5);
259+
m_f3 = sycl::marray<T, 2>(1.0, -1.0);
217260

218261
auto a3 = syclcompat::cmul_add(m_d1, m_d2, m_d3);
219262
r = r && check(a3, expect);
@@ -241,17 +284,19 @@ void test_conj() {
241284
ComplexLauncher<kernel_conj>().launch();
242285
}
243286

287+
// Host entry point for the cmul_add test, parameterised on the element type T.
// Per this diff: prints the function signature, then (added lines) prints the
// host-side reference via mul_add_groundtruth<T>() and launches
// kernel_mul_add<T> through ComplexLauncher. The removed line launched the
// former non-template kernel_mul_add.
template <typename T>
244288
void test_mul_add() {
245289
std::cout << __PRETTY_FUNCTION__ << std::endl;
246-
ComplexLauncher<kernel_mul_add>().launch();
290+
mul_add_groundtruth<T>();
291+
ComplexLauncher<kernel_mul_add<T>>().launch();
247292
}
248293

249294
// Runs the complex-math tests in order and returns 0.
// Per this diff, the direct test_mul_add() call is replaced by
// INSTANTIATE_ALL_TYPES(fp_type_list, test_mul_add).
// NOTE(review): INSTANTIATE_ALL_TYPES and fp_type_list are not defined in this
// view; presumably they invoke test_mul_add<T> for each listed type — confirm.
int main() {
250295
test_abs();
251296
test_mul();
252297
test_div();
253298
test_conj();
254-
test_mul_add();
299+
INSTANTIATE_ALL_TYPES(fp_type_list, test_mul_add);
255300

256301
return 0;
257302
}

0 commit comments

Comments
 (0)