@@ -43,7 +43,7 @@ template <typename srcT, typename dstT> class copy_cast_from_host_kernel;
43
43
template <typename srcT, typename dstT, int nd> class copy_cast_spec_kernel ;
44
44
template <typename Ty> class copy_for_reshape_generic_kernel ;
45
45
template <typename Ty> class linear_sequence_step_kernel ;
46
- template <typename Ty> class linear_sequence_affine_kernel ;
46
+ template <typename Ty, typename wTy > class linear_sequence_affine_kernel ;
47
47
48
48
static dpctl::tensor::detail::usm_ndarray_types array_types;
49
49
@@ -1526,7 +1526,7 @@ typedef sycl::event (*lin_space_affine_fn_ptr_t)(
1526
1526
static lin_space_affine_fn_ptr_t
1527
1527
lin_space_affine_dispatch_vector[_ns::num_types];
1528
1528
1529
- template <typename Ty> class LinearSequenceAffineFunctor
1529
+ template <typename Ty, typename wTy > class LinearSequenceAffineFunctor
1530
1530
{
1531
1531
private:
1532
1532
Ty *p = nullptr ;
@@ -1544,8 +1544,8 @@ template <typename Ty> class LinearSequenceAffineFunctor
1544
1544
void operator ()(sycl::id<1 > wiid) const
1545
1545
{
1546
1546
auto i = wiid.get (0 );
1547
- double wc = double (i) / n;
1548
- double w = double (n - i) / n;
1547
+ wTy wc = wTy (i) / n;
1548
+ wTy w = wTy (n - i) / n;
1549
1549
if constexpr (is_complex<Ty>::value) {
1550
1550
auto _w = static_cast <typename Ty::value_type>(w);
1551
1551
auto _wc = static_cast <typename Ty::value_type>(wc);
@@ -1578,13 +1578,23 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q,
1578
1578
throw ;
1579
1579
}
1580
1580
1581
+ bool device_supports_doubles = exec_q.get_device ().has (sycl::aspect::fp64);
1581
1582
sycl::event lin_space_affine_event = exec_q.submit ([&](sycl::handler &cgh) {
1582
1583
cgh.depends_on (depends);
1583
- cgh.parallel_for <linear_sequence_affine_kernel<Ty>>(
1584
- sycl::range<1 >{nelems},
1585
- LinearSequenceAffineFunctor<Ty>(array_data, start_v, end_v,
1586
- (include_endpoint) ? nelems - 1
1587
- : nelems));
1584
+ if (device_supports_doubles) {
1585
+ cgh.parallel_for <linear_sequence_affine_kernel<Ty, double >>(
1586
+ sycl::range<1 >{nelems},
1587
+ LinearSequenceAffineFunctor<Ty, double >(
1588
+ array_data, start_v, end_v,
1589
+ (include_endpoint) ? nelems - 1 : nelems));
1590
+ }
1591
+ else {
1592
+ cgh.parallel_for <linear_sequence_affine_kernel<Ty, float >>(
1593
+ sycl::range<1 >{nelems},
1594
+ LinearSequenceAffineFunctor<Ty, float >(
1595
+ array_data, start_v, end_v,
1596
+ (include_endpoint) ? nelems - 1 : nelems));
1597
+ }
1588
1598
});
1589
1599
1590
1600
return lin_space_affine_event;
0 commit comments