Skip to content

Commit e6d3564

Browse files
Generate implementations for accumulating into some wider types
Also support boolean output for boolean input (NumPy does support it)
1 parent 69f17a3 commit e6d3564

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,21 @@ template <typename T1, typename T2> struct DotAtomicOutputType
2323
T2,
2424
std::uint32_t,
2525
std::uint32_t>,
26+
td_ns::BinaryTypeMapResultEntry<T1,
27+
std::uint32_t,
28+
T2,
29+
std::uint32_t,
30+
std::uint64_t>,
2631
td_ns::BinaryTypeMapResultEntry<T1,
2732
std::int32_t,
2833
T2,
2934
std::int32_t,
3035
std::int32_t>,
36+
td_ns::BinaryTypeMapResultEntry<T1,
37+
std::int32_t,
38+
T2,
39+
std::int32_t,
40+
std::int64_t>,
3141
td_ns::BinaryTypeMapResultEntry<T1,
3242
std::uint64_t,
3343
T2,
@@ -39,6 +49,7 @@ template <typename T1, typename T2> struct DotAtomicOutputType
3949
std::int64_t,
4050
std::int64_t>,
4151
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
52+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, double>,
4253
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
4354
td_ns::DefaultResultEntry<void>>::result_type;
4455
};
@@ -49,6 +60,7 @@ template <typename T1, typename T2> struct DotNoAtomicOutputType
4960
{
5061
using value_type = typename std::disjunction< // disjunction is C++17
5162
// feature, supported by DPC++
63+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
5264
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, std::uint8_t>,
5365
td_ns::BinaryTypeMapResultEntry<T1,
5466
std::uint8_t,
@@ -75,11 +87,21 @@ template <typename T1, typename T2> struct DotNoAtomicOutputType
7587
T2,
7688
std::uint32_t,
7789
std::uint32_t>,
90+
td_ns::BinaryTypeMapResultEntry<T1,
91+
std::uint32_t,
92+
T2,
93+
std::uint32_t,
94+
std::uint64_t>,
7895
td_ns::BinaryTypeMapResultEntry<T1,
7996
std::int32_t,
8097
T2,
8198
std::int32_t,
8299
std::int32_t>,
100+
td_ns::BinaryTypeMapResultEntry<T1,
101+
std::int32_t,
102+
T2,
103+
std::int32_t,
104+
std::int64_t>,
83105
td_ns::BinaryTypeMapResultEntry<T1,
84106
std::uint64_t,
85107
T2,
@@ -102,6 +124,11 @@ template <typename T1, typename T2> struct DotNoAtomicOutputType
102124
T2,
103125
std::complex<float>,
104126
std::complex<float>>,
127+
td_ns::BinaryTypeMapResultEntry<T1,
128+
std::complex<float>,
129+
T2,
130+
std::complex<float>,
131+
std::complex<double>>,
105132
td_ns::BinaryTypeMapResultEntry<T1,
106133
std::complex<double>,
107134
T2,

0 commit comments

Comments
 (0)