Skip to content

Commit d768c9b

Browse files
Special-case b == 1 in the y[i] = a*x[i] + b*y[i] kernel
1 parent d31977f commit d768c9b

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ sub(sycl::queue q,
240240
return std::make_pair(ht_event, res_ev);
241241
}
242242

243+
template <typename T> class axpy_inplace_kern;
243244
template <typename T> class axpby_inplace_kern;
244245

245246
template <typename T>
@@ -259,11 +260,20 @@ sycl::event axpby_inplace_impl(sycl::queue q,
259260

260261
sycl::event res_ev = q.submit([&](sycl::handler &cgh) {
261262
cgh.depends_on(depends);
262-
cgh.parallel_for<axpby_inplace_kern<T>>(sycl::range<1>{nelems},
263-
[=](sycl::id<1> id) {
264-
auto i = id.get(0);
265-
y[i] = a * x[i] + b * y[i];
266-
});
263+
if (b == T(1)) {
264+
cgh.parallel_for<axpy_inplace_kern<T>>(sycl::range<1>{nelems},
265+
[=](sycl::id<1> id) {
266+
auto i = id.get(0);
267+
y[i] += a * x[i];
268+
});
269+
}
270+
else {
271+
cgh.parallel_for<axpby_inplace_kern<T>>(
272+
sycl::range<1>{nelems}, [=](sycl::id<1> id) {
273+
auto i = id.get(0);
274+
y[i] = b * y[i] + a * x[i];
275+
});
276+
}
267277
});
268278

269279
return res_ev;

0 commit comments

Comments
 (0)