Skip to content

Commit 5f49827

Browse files
linspace_affine should not use double precision type in kernels if HW does not support it
1 parent 09de29b commit 5f49827

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ template <typename srcT, typename dstT> class copy_cast_from_host_kernel;
4343
template <typename srcT, typename dstT, int nd> class copy_cast_spec_kernel;
4444
template <typename Ty> class copy_for_reshape_generic_kernel;
4545
template <typename Ty> class linear_sequence_step_kernel;
46-
template <typename Ty> class linear_sequence_affine_kernel;
46+
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
4747

4848
static dpctl::tensor::detail::usm_ndarray_types array_types;
4949

@@ -1526,7 +1526,7 @@ typedef sycl::event (*lin_space_affine_fn_ptr_t)(
15261526
static lin_space_affine_fn_ptr_t
15271527
lin_space_affine_dispatch_vector[_ns::num_types];
15281528

1529-
template <typename Ty> class LinearSequenceAffineFunctor
1529+
template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
15301530
{
15311531
private:
15321532
Ty *p = nullptr;
@@ -1544,8 +1544,8 @@ template <typename Ty> class LinearSequenceAffineFunctor
15441544
void operator()(sycl::id<1> wiid) const
15451545
{
15461546
auto i = wiid.get(0);
1547-
double wc = double(i) / n;
1548-
double w = double(n - i) / n;
1547+
wTy wc = wTy(i) / n;
1548+
wTy w = wTy(n - i) / n;
15491549
if constexpr (is_complex<Ty>::value) {
15501550
auto _w = static_cast<typename Ty::value_type>(w);
15511551
auto _wc = static_cast<typename Ty::value_type>(wc);
@@ -1578,13 +1578,23 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q,
15781578
throw;
15791579
}
15801580

1581+
bool device_supports_doubles = exec_q.get_device().has(sycl::aspect::fp64);
15811582
sycl::event lin_space_affine_event = exec_q.submit([&](sycl::handler &cgh) {
15821583
cgh.depends_on(depends);
1583-
cgh.parallel_for<linear_sequence_affine_kernel<Ty>>(
1584-
sycl::range<1>{nelems},
1585-
LinearSequenceAffineFunctor<Ty>(array_data, start_v, end_v,
1586-
(include_endpoint) ? nelems - 1
1587-
: nelems));
1584+
if (device_supports_doubles) {
1585+
cgh.parallel_for<linear_sequence_affine_kernel<Ty, double>>(
1586+
sycl::range<1>{nelems},
1587+
LinearSequenceAffineFunctor<Ty, double>(
1588+
array_data, start_v, end_v,
1589+
(include_endpoint) ? nelems - 1 : nelems));
1590+
}
1591+
else {
1592+
cgh.parallel_for<linear_sequence_affine_kernel<Ty, float>>(
1593+
sycl::range<1>{nelems},
1594+
LinearSequenceAffineFunctor<Ty, float>(
1595+
array_data, start_v, end_v,
1596+
(include_endpoint) ? nelems - 1 : nelems));
1597+
}
15881598
});
15891599

15901600
return lin_space_affine_event;

0 commit comments

Comments
 (0)