Skip to content

Commit 7740ce5

Browse files
klauslerjeanPerier
authored andcommitted
[flang] Speed common runtime cases of DOT_PRODUCT & MATMUL
Look for contiguous numeric argument arrays at runtime and use specialized code for them. Differential Revision: https://reviews.llvm.org/D112239
1 parent 7749fa2 commit 7740ce5

File tree

5 files changed

+272
-90
lines changed

5 files changed

+272
-90
lines changed

flang/include/flang/Runtime/c-or-cpp.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
#define IF_CPLUSPLUS(x) x
1414
#define IF_NOT_CPLUSPLUS(x)
1515
#define DEFAULT_VALUE(x) = (x)
16+
#define RESTRICT __restrict
1617
#else
1718
#include <stdbool.h>
1819
#define IF_CPLUSPLUS(x)
1920
#define IF_NOT_CPLUSPLUS(x) x
2021
#define DEFAULT_VALUE(x)
22+
#define RESTRICT restrict
2123
#endif
2224

2325
#define FORTRAN_EXTERN_C_BEGIN IF_CPLUSPLUS(extern "C" {)

flang/include/flang/Runtime/descriptor.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,10 @@ class Descriptor {
304304

305305
bool IsContiguous(int leadingDimensions = maxRank) const {
306306
auto bytes{static_cast<SubscriptValue>(ElementBytes())};
307-
for (int j{0}; j < leadingDimensions && j < raw_.rank; ++j) {
307+
if (leadingDimensions > raw_.rank) {
308+
leadingDimensions = raw_.rank;
309+
}
310+
for (int j{0}; j < leadingDimensions; ++j) {
308311
const Dimension &dim{GetDimension(j)};
309312
if (bytes != dim.ByteStride()) {
310313
return false;

flang/runtime/dot-product.cpp

Lines changed: 74 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,29 @@
1515

1616
namespace Fortran::runtime {
1717

18-
template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
18+
// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
19+
// argument; MATMUL does not.
20+
21+
// General accumulator for any type and stride; this is not used for
22+
// contiguous numeric vectors.
23+
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
1924
class Accumulator {
2025
public:
21-
using Result = RESULT;
26+
using Result = AccumulationType<RCAT, RKIND>;
2227
Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
23-
void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
24-
if constexpr (XCAT == TypeCategory::Complex) {
25-
sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
26-
static_cast<Result>(*y_.Element<YT>(&yAt));
27-
} else if constexpr (XCAT == TypeCategory::Logical) {
28+
void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
29+
if constexpr (RCAT == TypeCategory::Logical) {
2830
sum_ = sum_ ||
2931
(IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
3032
} else {
31-
sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
32-
static_cast<Result>(*y_.Element<YT>(&yAt));
33+
const XT &xElement{*x_.Element<XT>(&xAt)};
34+
const YT &yElement{*y_.Element<YT>(&yAt)};
35+
if constexpr (RCAT == TypeCategory::Complex) {
36+
sum_ += std::conj(static_cast<Result>(xElement)) *
37+
static_cast<Result>(yElement);
38+
} else {
39+
sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
40+
}
3341
}
3442
}
3543
Result GetResult() const { return sum_; }
@@ -39,34 +47,59 @@ class Accumulator {
3947
Result sum_{};
4048
};
4149

42-
template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
43-
static inline RESULT DoDotProduct(
50+
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
51+
static inline CppTypeFor<RCAT, RKIND> DoDotProduct(
4452
const Descriptor &x, const Descriptor &y, Terminator &terminator) {
53+
using Result = CppTypeFor<RCAT, RKIND>;
4554
RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
4655
SubscriptValue n{x.GetDimension(0).Extent()};
4756
if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
4857
terminator.Crash(
4958
"DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
5059
static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
5160
}
52-
if constexpr (std::is_same_v<XT, YT>) {
53-
if constexpr (std::is_same_v<XT, float>) {
54-
// TODO: call BLAS-1 SDOT or SDSDOT
55-
} else if constexpr (std::is_same_v<XT, double>) {
56-
// TODO: call BLAS-1 DDOT
57-
} else if constexpr (std::is_same_v<XT, std::complex<float>>) {
58-
// TODO: call BLAS-1 CDOTC
59-
} else if constexpr (std::is_same_v<XT, std::complex<float>>) {
60-
// TODO: call BLAS-1 ZDOTC
61+
if constexpr (RCAT != TypeCategory::Logical) {
62+
if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
63+
y.GetDimension(0).ByteStride() == sizeof(YT)) {
64+
// Contiguous numeric vectors
65+
if constexpr (std::is_same_v<XT, YT>) {
66+
// Contiguous homogeneous numeric vectors
67+
if constexpr (std::is_same_v<XT, float>) {
68+
// TODO: call BLAS-1 SDOT or SDSDOT
69+
} else if constexpr (std::is_same_v<XT, double>) {
70+
// TODO: call BLAS-1 DDOT
71+
} else if constexpr (std::is_same_v<XT, std::complex<float>>) {
72+
// TODO: call BLAS-1 CDOTC
73+
} else if constexpr (std::is_same_v<XT, std::complex<double>>) {
74+
// TODO: call BLAS-1 ZDOTC
75+
}
76+
}
77+
XT *xp{x.OffsetElement<XT>(0)};
78+
YT *yp{y.OffsetElement<YT>(0)};
79+
using AccumType = AccumulationType<RCAT, RKIND>;
80+
AccumType accum{};
81+
if constexpr (RCAT == TypeCategory::Complex) {
82+
for (SubscriptValue j{0}; j < n; ++j) {
83+
accum += std::conj(static_cast<AccumType>(*xp++)) *
84+
static_cast<AccumType>(*yp++);
85+
}
86+
} else {
87+
for (SubscriptValue j{0}; j < n; ++j) {
88+
accum +=
89+
static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
90+
}
91+
}
92+
return static_cast<Result>(accum);
6193
}
6294
}
95+
// Non-contiguous, heterogeneous, & LOGICAL cases
6396
SubscriptValue xAt{x.GetDimension(0).LowerBound()};
6497
SubscriptValue yAt{y.GetDimension(0).LowerBound()};
65-
Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y};
98+
Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
6699
for (SubscriptValue j{0}; j < n; ++j) {
67-
accumulator.Accumulate(xAt++, yAt++);
100+
accumulator.AccumulateIndexed(xAt++, yAt++);
68101
}
69-
return accumulator.GetResult();
102+
return static_cast<Result>(accumulator.GetResult());
70103
}
71104

72105
template <TypeCategory RCAT, int RKIND> struct DotProduct {
@@ -79,7 +112,7 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
79112
GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
80113
if constexpr (resultType->first == RCAT &&
81114
resultType->second <= RKIND) {
82-
return DoDotProduct<Result, XCAT, CppTypeFor<XCAT, XKIND>,
115+
return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
83116
CppTypeFor<YCAT, YKIND>>(x, y, terminator);
84117
}
85118
}
@@ -97,26 +130,32 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
97130
Result operator()(const Descriptor &x, const Descriptor &y,
98131
const char *source, int line) const {
99132
Terminator terminator{source, line};
100-
auto xCatKind{x.type().GetCategoryAndKind()};
101-
auto yCatKind{y.type().GetCategoryAndKind()};
102-
RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
103-
return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, terminator,
104-
x, y, terminator, yCatKind->first, yCatKind->second);
133+
if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
134+
// No conversions needed, operands and result have same known type
135+
return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
136+
x, y, terminator);
137+
} else {
138+
auto xCatKind{x.type().GetCategoryAndKind()};
139+
auto yCatKind{y.type().GetCategoryAndKind()};
140+
RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
141+
return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
142+
terminator, x, y, terminator, yCatKind->first, yCatKind->second);
143+
}
105144
}
106145
};
107146

108147
extern "C" {
109148
std::int8_t RTNAME(DotProductInteger1)(
110149
const Descriptor &x, const Descriptor &y, const char *source, int line) {
111-
return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
150+
return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
112151
}
113152
std::int16_t RTNAME(DotProductInteger2)(
114153
const Descriptor &x, const Descriptor &y, const char *source, int line) {
115-
return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
154+
return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
116155
}
117156
std::int32_t RTNAME(DotProductInteger4)(
118157
const Descriptor &x, const Descriptor &y, const char *source, int line) {
119-
return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
158+
return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
120159
}
121160
std::int64_t RTNAME(DotProductInteger8)(
122161
const Descriptor &x, const Descriptor &y, const char *source, int line) {
@@ -130,9 +169,10 @@ common::int128_t RTNAME(DotProductInteger16)(
130169
#endif
131170

132171
// TODO: REAL/COMPLEX(2 & 3)
172+
// Intermediate results and operations are at least 64 bits
133173
float RTNAME(DotProductReal4)(
134174
const Descriptor &x, const Descriptor &y, const char *source, int line) {
135-
return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
175+
return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
136176
}
137177
double RTNAME(DotProductReal8)(
138178
const Descriptor &x, const Descriptor &y, const char *source, int line) {
@@ -152,7 +192,7 @@ long double RTNAME(DotProductReal16)(
152192

153193
void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
154194
const Descriptor &x, const Descriptor &y, const char *source, int line) {
155-
auto z{DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line)};
195+
auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)};
156196
result = std::complex<float>{
157197
static_cast<float>(z.real()), static_cast<float>(z.imag())};
158198
}

0 commit comments

Comments
 (0)