15
15
16
16
namespace Fortran ::runtime {
17
17
18
- template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
18
+ // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
19
+ // argument; MATMUL does not.
20
+
21
+ // General accumulator for any type and stride; this is not used for
22
+ // contiguous numeric vectors.
23
+ template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
19
24
class Accumulator {
20
25
public:
21
- using Result = RESULT ;
26
+ using Result = AccumulationType<RCAT, RKIND> ;
22
27
Accumulator (const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
23
- void Accumulate (SubscriptValue xAt, SubscriptValue yAt) {
24
- if constexpr (XCAT == TypeCategory::Complex) {
25
- sum_ += std::conj (static_cast <Result>(*x_.Element <XT>(&xAt))) *
26
- static_cast <Result>(*y_.Element <YT>(&yAt));
27
- } else if constexpr (XCAT == TypeCategory::Logical) {
28
+ void AccumulateIndexed (SubscriptValue xAt, SubscriptValue yAt) {
29
+ if constexpr (RCAT == TypeCategory::Logical) {
28
30
sum_ = sum_ ||
29
31
(IsLogicalElementTrue (x_, &xAt) && IsLogicalElementTrue (y_, &yAt));
30
32
} else {
31
- sum_ += static_cast <Result>(*x_.Element <XT>(&xAt)) *
32
- static_cast <Result>(*y_.Element <YT>(&yAt));
33
+ const XT &xElement{*x_.Element <XT>(&xAt)};
34
+ const YT &yElement{*y_.Element <YT>(&yAt)};
35
+ if constexpr (RCAT == TypeCategory::Complex) {
36
+ sum_ += std::conj (static_cast <Result>(xElement)) *
37
+ static_cast <Result>(yElement);
38
+ } else {
39
+ sum_ += static_cast <Result>(xElement) * static_cast <Result>(yElement);
40
+ }
33
41
}
34
42
}
35
43
Result GetResult () const { return sum_; }
@@ -39,34 +47,59 @@ class Accumulator {
39
47
Result sum_{};
40
48
};
41
49
42
- template <typename RESULT, TypeCategory XCAT , typename XT, typename YT>
43
- static inline RESULT DoDotProduct (
50
+ template <TypeCategory RCAT, int RKIND , typename XT, typename YT>
51
+ static inline CppTypeFor<RCAT, RKIND> DoDotProduct (
44
52
const Descriptor &x, const Descriptor &y, Terminator &terminator) {
53
+ using Result = CppTypeFor<RCAT, RKIND>;
45
54
RUNTIME_CHECK (terminator, x.rank () == 1 && y.rank () == 1 );
46
55
SubscriptValue n{x.GetDimension (0 ).Extent ()};
47
56
if (SubscriptValue yN{y.GetDimension (0 ).Extent ()}; yN != n) {
48
57
terminator.Crash (
49
58
" DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd" ,
50
59
static_cast <std::intmax_t >(n), static_cast <std::intmax_t >(yN));
51
60
}
52
- if constexpr (std::is_same_v<XT, YT>) {
53
- if constexpr (std::is_same_v<XT, float >) {
54
- // TODO: call BLAS-1 SDOT or SDSDOT
55
- } else if constexpr (std::is_same_v<XT, double >) {
56
- // TODO: call BLAS-1 DDOT
57
- } else if constexpr (std::is_same_v<XT, std::complex<float >>) {
58
- // TODO: call BLAS-1 CDOTC
59
- } else if constexpr (std::is_same_v<XT, std::complex<float >>) {
60
- // TODO: call BLAS-1 ZDOTC
61
+ if constexpr (RCAT != TypeCategory::Logical) {
62
+ if (x.GetDimension (0 ).ByteStride () == sizeof (XT) &&
63
+ y.GetDimension (0 ).ByteStride () == sizeof (YT)) {
64
+ // Contiguous numeric vectors
65
+ if constexpr (std::is_same_v<XT, YT>) {
66
+ // Contiguous homogeneous numeric vectors
67
+ if constexpr (std::is_same_v<XT, float >) {
68
+ // TODO: call BLAS-1 SDOT or SDSDOT
69
+ } else if constexpr (std::is_same_v<XT, double >) {
70
+ // TODO: call BLAS-1 DDOT
71
+ } else if constexpr (std::is_same_v<XT, std::complex<float >>) {
72
+ // TODO: call BLAS-1 CDOTC
73
+ } else if constexpr (std::is_same_v<XT, std::complex<double >>) {
74
+ // TODO: call BLAS-1 ZDOTC
75
+ }
76
+ }
77
+ XT *xp{x.OffsetElement <XT>(0 )};
78
+ YT *yp{y.OffsetElement <YT>(0 )};
79
+ using AccumType = AccumulationType<RCAT, RKIND>;
80
+ AccumType accum{};
81
+ if constexpr (RCAT == TypeCategory::Complex) {
82
+ for (SubscriptValue j{0 }; j < n; ++j) {
83
+ accum += std::conj (static_cast <AccumType>(*xp++)) *
84
+ static_cast <AccumType>(*yp++);
85
+ }
86
+ } else {
87
+ for (SubscriptValue j{0 }; j < n; ++j) {
88
+ accum +=
89
+ static_cast <AccumType>(*xp++) * static_cast <AccumType>(*yp++);
90
+ }
91
+ }
92
+ return static_cast <Result>(accum);
61
93
}
62
94
}
95
+ // Non-contiguous, heterogeneous, & LOGICAL cases
63
96
SubscriptValue xAt{x.GetDimension (0 ).LowerBound ()};
64
97
SubscriptValue yAt{y.GetDimension (0 ).LowerBound ()};
65
- Accumulator<RESULT, XCAT , XT, YT> accumulator{x, y};
98
+ Accumulator<RCAT, RKIND , XT, YT> accumulator{x, y};
66
99
for (SubscriptValue j{0 }; j < n; ++j) {
67
- accumulator.Accumulate (xAt++, yAt++);
100
+ accumulator.AccumulateIndexed (xAt++, yAt++);
68
101
}
69
- return accumulator.GetResult ();
102
+ return static_cast <Result>( accumulator.GetResult () );
70
103
}
71
104
72
105
template <TypeCategory RCAT, int RKIND> struct DotProduct {
@@ -79,7 +112,7 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
79
112
GetResultType (XCAT, XKIND, YCAT, YKIND)}) {
80
113
if constexpr (resultType->first == RCAT &&
81
114
resultType->second <= RKIND) {
82
- return DoDotProduct<Result, XCAT , CppTypeFor<XCAT, XKIND>,
115
+ return DoDotProduct<RCAT, RKIND , CppTypeFor<XCAT, XKIND>,
83
116
CppTypeFor<YCAT, YKIND>>(x, y, terminator);
84
117
}
85
118
}
@@ -97,26 +130,32 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
97
130
Result operator ()(const Descriptor &x, const Descriptor &y,
98
131
const char *source, int line) const {
99
132
Terminator terminator{source, line};
100
- auto xCatKind{x.type ().GetCategoryAndKind ()};
101
- auto yCatKind{y.type ().GetCategoryAndKind ()};
102
- RUNTIME_CHECK (terminator, xCatKind.has_value () && yCatKind.has_value ());
103
- return ApplyType<DP1, Result>(xCatKind->first , xCatKind->second , terminator,
104
- x, y, terminator, yCatKind->first , yCatKind->second );
133
+ if (RCAT != TypeCategory::Logical && x.type () == y.type ()) {
134
+ // No conversions needed, operands and result have same known type
135
+ return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
136
+ x, y, terminator);
137
+ } else {
138
+ auto xCatKind{x.type ().GetCategoryAndKind ()};
139
+ auto yCatKind{y.type ().GetCategoryAndKind ()};
140
+ RUNTIME_CHECK (terminator, xCatKind.has_value () && yCatKind.has_value ());
141
+ return ApplyType<DP1, Result>(xCatKind->first , xCatKind->second ,
142
+ terminator, x, y, terminator, yCatKind->first , yCatKind->second );
143
+ }
105
144
}
106
145
};
107
146
108
147
extern " C" {
109
148
std::int8_t RTNAME (DotProductInteger1)(
110
149
const Descriptor &x, const Descriptor &y, const char *source, int line) {
111
- return DotProduct<TypeCategory::Integer, 8 >{}(x, y, source, line);
150
+ return DotProduct<TypeCategory::Integer, 1 >{}(x, y, source, line);
112
151
}
113
152
std::int16_t RTNAME (DotProductInteger2)(
114
153
const Descriptor &x, const Descriptor &y, const char *source, int line) {
115
- return DotProduct<TypeCategory::Integer, 8 >{}(x, y, source, line);
154
+ return DotProduct<TypeCategory::Integer, 2 >{}(x, y, source, line);
116
155
}
117
156
std::int32_t RTNAME (DotProductInteger4)(
118
157
const Descriptor &x, const Descriptor &y, const char *source, int line) {
119
- return DotProduct<TypeCategory::Integer, 8 >{}(x, y, source, line);
158
+ return DotProduct<TypeCategory::Integer, 4 >{}(x, y, source, line);
120
159
}
121
160
std::int64_t RTNAME (DotProductInteger8)(
122
161
const Descriptor &x, const Descriptor &y, const char *source, int line) {
@@ -130,9 +169,10 @@ common::int128_t RTNAME(DotProductInteger16)(
130
169
#endif
131
170
132
171
// TODO: REAL/COMPLEX(2 & 3)
172
+ // Intermediate results and operations are at least 64 bits
133
173
float RTNAME (DotProductReal4)(
134
174
const Descriptor &x, const Descriptor &y, const char *source, int line) {
135
- return DotProduct<TypeCategory::Real, 8 >{}(x, y, source, line);
175
+ return DotProduct<TypeCategory::Real, 4 >{}(x, y, source, line);
136
176
}
137
177
double RTNAME (DotProductReal8)(
138
178
const Descriptor &x, const Descriptor &y, const char *source, int line) {
@@ -152,7 +192,7 @@ long double RTNAME(DotProductReal16)(
152
192
153
193
void RTNAME (CppDotProductComplex4)(std::complex<float > &result,
154
194
const Descriptor &x, const Descriptor &y, const char *source, int line) {
155
- auto z{DotProduct<TypeCategory::Complex, 8 >{}(x, y, source, line)};
195
+ auto z{DotProduct<TypeCategory::Complex, 4 >{}(x, y, source, line)};
156
196
result = std::complex<float >{
157
197
static_cast <float >(z.real ()), static_cast <float >(z.imag ())};
158
198
}
0 commit comments