Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
475 changes: 220 additions & 255 deletions dpctl/tensor/libtensor/include/kernels/accumulators.hpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ struct MaskedPlaceStridedFunctor

// ======= Masked extraction ================================

namespace
namespace detail
{

template <std::size_t I, std::size_t... IR>
Expand All @@ -234,7 +234,7 @@ std::size_t get_lws(std::size_t n)
return _get_lws_impl<lws0, lws1, lws2>(n);
}

} // end of anonymous namespace
} // end of namespace detail

template <typename MaskedDstIndexerT, typename dataT, typename indT>
class masked_extract_all_slices_contig_impl_krn;
Expand Down Expand Up @@ -278,7 +278,7 @@ sycl::event masked_extract_all_slices_contig_impl(

const std::size_t masked_extent = iteration_size;

const std::size_t lws = get_lws(masked_extent);
const std::size_t lws = detail::get_lws(masked_extent);

const std::size_t n_groups = (iteration_size + lws - 1) / lws;

Expand Down Expand Up @@ -357,7 +357,7 @@ sycl::event masked_extract_all_slices_strided_impl(

const std::size_t masked_nelems = iteration_size;

const std::size_t lws = get_lws(masked_nelems);
const std::size_t lws = detail::get_lws(masked_nelems);

const std::size_t n_groups = (masked_nelems + lws - 1) / lws;

Expand Down Expand Up @@ -452,7 +452,7 @@ sycl::event masked_extract_some_slices_strided_impl(

const std::size_t masked_extent = masked_nelems;

const std::size_t lws = get_lws(masked_extent);
const std::size_t lws = detail::get_lws(masked_extent);

const std::size_t n_groups = ((masked_extent + lws - 1) / lws);
const std::size_t orthog_extent = static_cast<std::size_t>(orthog_nelems);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ template <typename T> struct AbsOutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand All @@ -141,7 +141,7 @@ template <typename argTy> struct AbsContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // namespace
} // namespace hyperparam_detail

template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
class abs_contig_kernel;
Expand All @@ -153,8 +153,9 @@ sycl::event abs_contig_impl(sycl::queue &exec_q,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
constexpr std::uint8_t vec_sz = AbsContigHyperparameterSet<argTy>::vec_sz;
constexpr std::uint8_t n_vec = AbsContigHyperparameterSet<argTy>::n_vecs;
using AbsHS = hyperparam_detail::AbsContigHyperparameterSet<argTy>;
constexpr std::uint8_t vec_sz = AbsHS::vec_sz;
constexpr std::uint8_t n_vec = AbsHS::n_vecs;

return elementwise_common::unary_contig_impl<
argTy, AbsOutputType, AbsContigFunctor, abs_contig_kernel, vec_sz,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ template <typename T> struct AcosOutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand All @@ -175,7 +175,7 @@ template <typename argTy> struct AcosContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // namespace
} // end of namespace hyperparam_detail

template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
class acos_contig_kernel;
Expand All @@ -187,8 +187,9 @@ sycl::event acos_contig_impl(sycl::queue &exec_q,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
constexpr std::uint8_t vec_sz = AcosContigHyperparameterSet<argTy>::vec_sz;
constexpr std::uint8_t n_vec = AcosContigHyperparameterSet<argTy>::n_vecs;
using AcosHS = hyperparam_detail::AcosContigHyperparameterSet<argTy>;
constexpr std::uint8_t vec_sz = AcosHS::vec_sz;
constexpr std::uint8_t n_vec = AcosHS::n_vecs;

return elementwise_common::unary_contig_impl<
argTy, AcosOutputType, AcosContigFunctor, acos_contig_kernel, vec_sz,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ template <typename T> struct AcoshOutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand All @@ -203,7 +203,7 @@ template <typename argTy> struct AcoshContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // namespace
} // end of namespace hyperparam_detail

template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
class acosh_contig_kernel;
Expand All @@ -215,8 +215,9 @@ sycl::event acosh_contig_impl(sycl::queue &exec_q,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
constexpr std::uint8_t vec_sz = AcoshContigHyperparameterSet<argTy>::vec_sz;
constexpr std::uint8_t n_vec = AcoshContigHyperparameterSet<argTy>::n_vecs;
using AcoshHS = hyperparam_detail::AcoshContigHyperparameterSet<argTy>;
constexpr std::uint8_t vec_sz = AcoshHS::vec_sz;
constexpr std::uint8_t n_vec = AcoshHS::n_vecs;

return elementwise_common::unary_contig_impl<
argTy, AcoshOutputType, AcoshContigFunctor, acosh_contig_kernel, vec_sz,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ template <typename T1, typename T2> struct AddOutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand Down Expand Up @@ -252,7 +252,7 @@ template <typename argTy1, typename argTy2> struct AddContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // end of anonymous namespace
} // end of namespace hyperparam_detail

template <typename argT1,
typename argT2,
Expand All @@ -272,8 +272,9 @@ sycl::event add_contig_impl(sycl::queue &exec_q,
ssize_t res_offset,
const std::vector<sycl::event> &depends = {})
{
constexpr auto vec_sz = AddContigHyperparameterSet<argTy1, argTy2>::vec_sz;
constexpr auto n_vecs = AddContigHyperparameterSet<argTy1, argTy2>::n_vecs;
using AddHS = hyperparam_detail::AddContigHyperparameterSet<argTy1, argTy2>;
constexpr auto vec_sz = AddHS::vec_sz;
constexpr auto n_vecs = AddHS::n_vecs;

return elementwise_common::binary_contig_impl<
argTy1, argTy2, AddOutputType, AddContigFunctor, add_contig_kernel,
Expand Down Expand Up @@ -550,8 +551,10 @@ add_inplace_contig_impl(sycl::queue &exec_q,
ssize_t res_offset,
const std::vector<sycl::event> &depends = {})
{
constexpr auto vec_sz = AddContigHyperparameterSet<resTy, argTy>::vec_sz;
constexpr auto n_vecs = AddContigHyperparameterSet<resTy, argTy>::n_vecs;
constexpr auto vec_sz =
hyperparam_detail::AddContigHyperparameterSet<resTy, argTy>::vec_sz;
constexpr auto n_vecs =
hyperparam_detail::AddContigHyperparameterSet<resTy, argTy>::n_vecs;

return elementwise_common::binary_inplace_contig_impl<
argTy, resTy, AddInplaceContigFunctor, add_inplace_contig_kernel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ template <typename T> struct AngleOutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand All @@ -119,7 +119,7 @@ template <typename argTy> struct AngleContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // end of anonymous namespace
} // end of namespace hyperparam_detail

template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
class angle_contig_kernel;
Expand All @@ -131,8 +131,9 @@ sycl::event angle_contig_impl(sycl::queue &exec_q,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
constexpr std::uint8_t vec_sz = AngleContigHyperparameterSet<argTy>::vec_sz;
constexpr std::uint8_t n_vec = AngleContigHyperparameterSet<argTy>::n_vecs;
using AngleHS = hyperparam_detail::AngleContigHyperparameterSet<argTy>;
constexpr std::uint8_t vec_sz = AngleHS::vec_sz;
constexpr std::uint8_t n_vec = AngleHS::n_vecs;

return elementwise_common::unary_contig_impl<
argTy, AngleOutputType, AngleContigFunctor, angle_contig_kernel, vec_sz,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ template <typename T> struct AsinOutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand All @@ -196,7 +196,7 @@ template <typename argTy> struct AsinContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // end of anonymous namespace
} // end of namespace hyperparam_detail

template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
class asin_contig_kernel;
Expand All @@ -208,8 +208,9 @@ sycl::event asin_contig_impl(sycl::queue &exec_q,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
constexpr std::uint8_t vec_sz = AsinContigHyperparameterSet<argTy>::vec_sz;
constexpr std::uint8_t n_vec = AsinContigHyperparameterSet<argTy>::n_vecs;
using AddHS = hyperparam_detail::AsinContigHyperparameterSet<argTy>;
constexpr std::uint8_t vec_sz = AddHS::vec_sz;
constexpr std::uint8_t n_vec = AddHS::n_vecs;

return elementwise_common::unary_contig_impl<
argTy, AsinOutputType, AsinContigFunctor, asin_contig_kernel, vec_sz,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ template <typename T> struct AsinhOutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand All @@ -179,7 +179,7 @@ template <typename argTy> struct AsinhContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // end of anonymous namespace
} // end of namespace hyperparam_detail

template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
class asinh_contig_kernel;
Expand All @@ -191,8 +191,9 @@ sycl::event asinh_contig_impl(sycl::queue &exec_q,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
constexpr std::uint8_t vec_sz = AsinhContigHyperparameterSet<argTy>::vec_sz;
constexpr std::uint8_t n_vec = AsinhContigHyperparameterSet<argTy>::n_vecs;
using AsinhHS = hyperparam_detail::AsinhContigHyperparameterSet<argTy>;
constexpr std::uint8_t vec_sz = AsinhHS::vec_sz;
constexpr std::uint8_t n_vec = AsinhHS::n_vecs;

return elementwise_common::unary_contig_impl<
argTy, AsinhOutputType, AsinhContigFunctor, asinh_contig_kernel, vec_sz,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ template <typename T> struct AtanOutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand All @@ -189,7 +189,7 @@ template <typename argTy> struct AtanContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // end of anonymous namespace
} // end of namespace hyperparam_detail

template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
class atan_contig_kernel;
Expand All @@ -201,8 +201,9 @@ sycl::event atan_contig_impl(sycl::queue &exec_q,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
constexpr std::uint8_t vec_sz = AtanContigHyperparameterSet<argTy>::vec_sz;
constexpr std::uint8_t n_vec = AtanContigHyperparameterSet<argTy>::n_vecs;
using AtanHS = hyperparam_detail::AtanContigHyperparameterSet<argTy>;
constexpr std::uint8_t vec_sz = AtanHS::vec_sz;
constexpr std::uint8_t n_vec = AtanHS::n_vecs;

return elementwise_common::unary_contig_impl<
argTy, AtanOutputType, AtanContigFunctor, atan_contig_kernel, vec_sz,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ template <typename T1, typename T2> struct Atan2OutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand All @@ -123,7 +123,7 @@ template <typename argTy1, typename argTy2> struct Atan2ContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // end of anonymous namespace
} // end of namespace hyperparam_detail

template <typename argT1,
typename argT2,
Expand All @@ -143,10 +143,10 @@ sycl::event atan2_contig_impl(sycl::queue &exec_q,
ssize_t res_offset,
const std::vector<sycl::event> &depends = {})
{
constexpr std::uint8_t vec_sz =
Atan2ContigHyperparameterSet<argTy1, argTy2>::vec_sz;
constexpr std::uint8_t n_vecs =
Atan2ContigHyperparameterSet<argTy1, argTy2>::n_vecs;
using Atan2HS =
hyperparam_detail::Atan2ContigHyperparameterSet<argTy1, argTy2>;
constexpr std::uint8_t vec_sz = Atan2HS::vec_sz;
constexpr std::uint8_t n_vecs = Atan2HS::n_vecs;

return elementwise_common::binary_contig_impl<
argTy1, argTy2, Atan2OutputType, Atan2ContigFunctor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ template <typename T> struct AtanhOutputType
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

namespace
namespace hyperparam_detail
{

namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
Expand All @@ -180,7 +180,7 @@ template <typename argTy> struct AtanhContigHyperparameterSet
constexpr static auto n_vecs = value_type::n_vecs;
};

} // end of anonymous namespace
} // end of namespace hyperparam_detail

template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
class atanh_contig_kernel;
Expand All @@ -192,8 +192,9 @@ sycl::event atanh_contig_impl(sycl::queue &exec_q,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
constexpr std::uint8_t vec_sz = AtanhContigHyperparameterSet<argTy>::vec_sz;
constexpr std::uint8_t n_vec = AtanhContigHyperparameterSet<argTy>::n_vecs;
using AtanhHS = hyperparam_detail::AtanhContigHyperparameterSet<argTy>;
constexpr std::uint8_t vec_sz = AtanhHS::vec_sz;
constexpr std::uint8_t n_vec = AtanhHS::n_vecs;

return elementwise_common::unary_contig_impl<
argTy, AtanhOutputType, AtanhContigFunctor, atanh_contig_kernel, vec_sz,
Expand Down
Loading
Loading