Skip to content

Commit f95dc85

Browse files
committed
Revert "Add cmul_add<bfloat> & draft test"
This reverts commit 55e95ad.
1 parent e5f9231 commit f95dc85

File tree

2 files changed

+18
-85
lines changed

2 files changed

+18
-85
lines changed

sycl/include/syclcompat/math.hpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -946,28 +946,6 @@ inline sycl::marray<ValueT, 2> cmul_add(const sycl::marray<ValueT, 2> a,
946946
t = t * u + v;
947947
return sycl::marray<ValueT, 2>{t.real(), t.imag()};
948948
}
949-
template <>
950-
inline sycl::vec<sycl::ext::oneapi::bfloat16, 2>
951-
cmul_add(const sycl::vec<sycl::ext::oneapi::bfloat16, 2> a,
952-
const sycl::vec<sycl::ext::oneapi::bfloat16, 2> b,
953-
const sycl::vec<sycl::ext::oneapi::bfloat16, 2> c) {
954-
sycl::ext::oneapi::experimental::complex<float> t(a[0], a[1]);
955-
sycl::ext::oneapi::experimental::complex<float> u(b[0], b[1]);
956-
sycl::ext::oneapi::experimental::complex<float> v(c[0], c[1]);
957-
t = t * u + v;
958-
return sycl::vec<sycl::ext::oneapi::bfloat16, 2>{t.real(), t.imag()};
959-
}
960-
template <>
961-
inline sycl::marray<sycl::ext::oneapi::bfloat16, 2>
962-
cmul_add(const sycl::marray<sycl::ext::oneapi::bfloat16, 2> a,
963-
const sycl::marray<sycl::ext::oneapi::bfloat16, 2> b,
964-
const sycl::marray<sycl::ext::oneapi::bfloat16, 2> c) {
965-
sycl::ext::oneapi::experimental::complex<float> t(a[0], a[1]);
966-
sycl::ext::oneapi::experimental::complex<float> u(b[0], b[1]);
967-
sycl::ext::oneapi::experimental::complex<float> v(c[0], c[1]);
968-
t = t * u + v;
969-
return sycl::marray<sycl::ext::oneapi::bfloat16, 2>{t.real(), t.imag()};
970-
}
971949

972950
/// A sycl::abs wrapper functors.
973951
struct abs {

sycl/test-e2e/syclcompat/math/math_complex.cpp

Lines changed: 18 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -184,62 +184,19 @@ void kernel_mul(int *result) {
184184
*result = r;
185185
}
186186

187-
template <typename T>
188-
std::array<T,2> complex_mul(std::array<T, 2> a, std::array<T, 2> b){
189-
std::array<T, 2> result;
190-
result[0] = (a[0] * b[0]) - (a[1] * b[1]);
191-
result[1] = (a[0] * b[1]) + (a[1] * b[0]);
192-
return result;
193-
}
194-
195-
template <typename T>
196-
std::array<T,2> complex_add(std::array<T, 2> a, std::array<T, 2> b){
197-
return {a[0] + b[0], a[1] + b[1]};
198-
}
199-
200-
template <typename T> void mul_add_groundtruth() {
201-
202-
using complex_t = std::complex<double>;
203-
using arr_t = std::array<T, 2>;
204-
205-
arr_t d1 = arr_t({static_cast<T>(5.4), static_cast<T>(-6.3)});
206-
arr_t d2 = arr_t({static_cast<T>(-7.2), static_cast<T>(8.1)});
207-
arr_t d3 = arr_t({static_cast<T>(1.0), static_cast<T>(-1.0)});
208-
209-
arr_t f1 = arr_t({static_cast<T>(1.8), static_cast<T>(-2.7)});
210-
arr_t f2 = arr_t({static_cast<T>(-3.6), static_cast<T>(4.5)});
211-
arr_t f3 = arr_t({static_cast<T>(1.0), static_cast<T>(-1.0)});
212-
213-
arr_t ra1 = complex_add(complex_mul(d1, d2), d3);
214-
arr_t ra2 = complex_add(complex_mul(f1, f2), f3);
215-
216-
T expect[4] = {13.150000, 88.100000, 6.670001, 16.820000};
217-
218-
// complex_t r1 = d1 * d2 + d3;
219-
// complex_t r2 = f1 * f2 + f3;
220-
221-
std::cout << "r1: " << static_cast<T>(ra1[0]) << ", "
222-
<< static_cast<T>(ra1[1]) << std::endl;
223-
std::cout << "Expect 1: " << expect[0] << ", " << expect[1] << std::endl;
224-
std::cout << "r2: " << static_cast<T>(ra2[0]) << ", "
225-
<< static_cast<T>(ra2[1]) << std::endl;
226-
std::cout << "Expect 2: " << expect[2] << ", " << expect[3] << std::endl;
227-
}
228-
229-
template <typename T>
230187
void kernel_mul_add(int *result) {
231-
sycl::vec<T, 2> d1, d2, d3;
232-
sycl::vec<T, 2> f1, f2, f3;
233-
sycl::marray<T, 2> m_d1, m_d2, m_d3;
234-
sycl::marray<T, 2> m_f1, m_f2, m_f3;
188+
sycl::double2 d1, d2, d3;
189+
sycl::float2 f1, f2, f3;
190+
sycl::marray<double, 2> m_d1, m_d2, m_d3;
191+
sycl::marray<float, 2> m_f1, m_f2, m_f3;
235192

236-
d1 = sycl::vec<T, 2>(5.4, -6.3);
237-
d2 = sycl::vec<T, 2>(-7.2, 8.1);
238-
d3 = sycl::vec<T, 2>(1.0, -1.0);
193+
d1 = sycl::double2(5.4, -6.3);
194+
d2 = sycl::double2(-7.2, 8.1);
195+
d3 = sycl::double2(1.0, -1.0);
239196

240-
f1 = sycl::vec<T, 2>(1.8, -2.7);
241-
f2 = sycl::vec<T, 2>(-3.6, 4.5);
242-
f3 = sycl::vec<T, 2>(1.0, -1.0);
197+
f1 = sycl::float2(1.8, -2.7);
198+
f2 = sycl::float2(-3.6, 4.5);
199+
f3 = sycl::float2(1.0, -1.0);
243200

244201
bool r = true;
245202
float expect[4] = {13.150000, 88.100000, 6.670001, 16.820000};
@@ -250,13 +207,13 @@ void kernel_mul_add(int *result) {
250207
auto a2 = syclcompat::cmul_add(f1, f2, f3);
251208
r = r && check(a2, expect + 2);
252209

253-
m_d1 = sycl::marray<T, 2>(5.4, -6.3);
254-
m_d2 = sycl::marray<T, 2>(-7.2, 8.1);
255-
m_d3 = sycl::marray<T, 2>(1.0, -1.0);
210+
m_d1 = sycl::marray<double, 2>(5.4, -6.3);
211+
m_d2 = sycl::marray<double, 2>(-7.2, 8.1);
212+
m_d3 = sycl::marray<double, 2>(1.0, -1.0);
256213

257-
m_f1 = sycl::marray<T, 2>(1.8, -2.7);
258-
m_f2 = sycl::marray<T, 2>(-3.6, 4.5);
259-
m_f3 = sycl::marray<T, 2>(1.0, -1.0);
214+
m_f1 = sycl::marray<float, 2>(1.8, -2.7);
215+
m_f2 = sycl::marray<float, 2>(-3.6, 4.5);
216+
m_f3 = sycl::marray<float, 2>(1.0, -1.0);
260217

261218
auto a3 = syclcompat::cmul_add(m_d1, m_d2, m_d3);
262219
r = r && check(a3, expect);
@@ -284,19 +241,17 @@ void test_conj() {
284241
ComplexLauncher<kernel_conj>().launch();
285242
}
286243

287-
template <typename T>
288244
void test_mul_add() {
289245
std::cout << __PRETTY_FUNCTION__ << std::endl;
290-
mul_add_groundtruth<T>();
291-
ComplexLauncher<kernel_mul_add<T>>().launch();
246+
ComplexLauncher<kernel_mul_add>().launch();
292247
}
293248

294249
int main() {
295250
test_abs();
296251
test_mul();
297252
test_div();
298253
test_conj();
299-
INSTANTIATE_ALL_TYPES(fp_type_list, test_mul_add);
254+
test_mul_add();
300255

301256
return 0;
302257
}

0 commit comments

Comments
 (0)