Skip to content

Commit 79a353e

Browse files
committed
Rework in line with Slava's review comments
1 parent 8f24bb5 commit 79a353e

File tree

4 files changed

+138
-90
lines changed

4 files changed

+138
-90
lines changed

flang-rt/include/flang-rt/runtime/descriptor.h

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -440,34 +440,55 @@ static_assert(sizeof(Descriptor) == sizeof(ISO::CFI_cdesc_t));
440440
// Lightweight iterator-like API to simplify specialising Descriptor indexing
441441
// in cases where it can improve application performance. On account of the
442442
// purpose of this API being performance optimisation, it is up to the user to
443-
// do all the necessary checks to make sure the RANK1=true variant can be used
443+
// do all the necessary checks to make sure the specialised variants can be used
444444
// safely and that Advance() is not called more times than the number of
445445
// elements in the Descriptor allows for.
446-
template <bool RANK1 = false> class DescriptorIterator {
446+
// Default RANK=-1 supports aray descriptors of any rank up to maxRank.
447+
template <int RANK = -1> class DescriptorIterator {
447448
private:
448449
const Descriptor &descriptor;
449450
SubscriptValue subscripts[maxRank];
450451
std::size_t elementOffset = 0;
451452

452453
public:
453-
DescriptorIterator(const Descriptor &descriptor) : descriptor(descriptor) {
454+
RT_API_ATTRS DescriptorIterator(const Descriptor &descriptor)
455+
: descriptor(descriptor) {
454456
descriptor.GetLowerBounds(subscripts);
455-
if constexpr (RANK1) {
457+
if constexpr (RANK == 1) {
456458
elementOffset = descriptor.SubscriptByteOffset(0, subscripts[0]);
457459
}
458460
};
459461

460-
template <typename A> A *Get() {
461-
if constexpr (RANK1) {
462-
return descriptor.OffsetElement<A>(elementOffset);
462+
template <typename A> RT_API_ATTRS A *Get() {
463+
std::size_t offset = 0;
464+
// The rank-1 case doesn't require looping at all
465+
if constexpr (RANK == 1) {
466+
offset = elementOffset;
467+
// The compiler might be able to optimise this better if we know the rank
468+
// at compile time
469+
} else if (RANK != -1) {
470+
for (int j{0}; j < RANK; ++j) {
471+
offset += descriptor.SubscriptByteOffset(j, subscripts[j]);
472+
}
473+
// General fallback
463474
} else {
464-
return descriptor.Element<A>(subscripts);
475+
offset = descriptor.SubscriptsToByteOffset(subscripts);
465476
}
477+
478+
return descriptor.OffsetElement<A>(offset);
466479
}
467480

468-
void Advance() {
469-
if constexpr (RANK1) {
481+
RT_API_ATTRS void Advance() {
482+
if constexpr (RANK == 1) {
470483
elementOffset += descriptor.GetDimension(0).ByteStride();
484+
} else if (RANK != -1) {
485+
for (int j{0}; j < RANK; ++j) {
486+
const Dimension &dim{descriptor.GetDimension(j)};
487+
if (subscripts[j]++ < dim.UpperBound()) {
488+
break;
489+
}
490+
subscripts[j] = dim.LowerBound();
491+
}
471492
} else {
472493
descriptor.IncrementSubscripts(subscripts);
473494
}

flang-rt/include/flang-rt/runtime/tools.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -511,13 +511,13 @@ inline RT_API_ATTRS const char *FindCharacter(
511511
// Copy payload data from one allocated descriptor to another.
512512
// Assumes element counts and element sizes match, and that both
513513
// descriptors are allocated.
514-
template <bool RANK1 = false>
514+
template <typename P = char, int RANK = -1>
515515
RT_API_ATTRS void ShallowCopyDiscontiguousToDiscontiguous(
516516
const Descriptor &to, const Descriptor &from);
517-
template <bool RANK1 = false>
517+
template <typename P = char, int RANK = -1>
518518
RT_API_ATTRS void ShallowCopyDiscontiguousToContiguous(
519519
const Descriptor &to, const Descriptor &from);
520-
template <bool RANK1 = false>
520+
template <typename P = char, int RANK = -1>
521521
RT_API_ATTRS void ShallowCopyContiguousToDiscontiguous(
522522
const Descriptor &to, const Descriptor &from);
523523
RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from,

flang-rt/lib/runtime/assign.cpp

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -492,21 +492,11 @@ RT_API_ATTRS void Assign(Descriptor &to, const Descriptor &from,
492492
terminator.Crash("unexpected type code %d in blank padded Assign()",
493493
to.type().raw());
494494
}
495-
} else {
496-
// We can't simply call ShallowCopy due to edge cases such as character
497-
// truncation or assignments where the RHS is a scalar.
498-
if (toElementBytes == fromElementBytes && to.IsContiguous()) {
499-
if (to.rank() == 1 && from.rank() == 1) {
500-
ShallowCopyDiscontiguousToContiguous<true>(to, from);
501-
} else {
502-
ShallowCopyDiscontiguousToContiguous<false>(to, from);
503-
}
504-
} else {
505-
if (to.rank() == 1 && from.rank() == 1) {
506-
ShallowCopyDiscontiguousToDiscontiguous<true>(to, from);
507-
} else {
508-
ShallowCopyDiscontiguousToDiscontiguous<false>(to, from);
509-
}
495+
} else { // elemental copies, possibly with character truncation
496+
for (std::size_t n{toElements}; n-- > 0;
497+
to.IncrementSubscripts(toAt), from.IncrementSubscripts(fromAt)) {
498+
memmoveFct(to.Element<char>(toAt), from.Element<const char>(fromAt),
499+
toElementBytes);
510500
}
511501
}
512502
}
@@ -598,7 +588,8 @@ void RTDEF(CopyInAssign)(Descriptor &temp, const Descriptor &var,
598588
temp = var;
599589
temp.set_base_addr(nullptr);
600590
temp.raw().attribute = CFI_attribute_allocatable;
601-
RTNAME(AssignTemporary)(temp, var, sourceFile, sourceLine);
591+
temp.Allocate(kNoAsyncId);
592+
ShallowCopy(temp, var);
602593
}
603594

604595
void RTDEF(CopyOutAssign)(
@@ -607,9 +598,10 @@ void RTDEF(CopyOutAssign)(
607598

608599
// Copyout from the temporary must not cause any finalizations
609600
// for LHS. The variable must be properly initialized already.
610-
if (var)
611-
Assign(*var, temp, terminator, NoAssignFlags);
612-
temp.Destroy(/*finalize=*/false, /*destroyPointers=*/false, &terminator);
601+
if (var) {
602+
ShallowCopy(*var, temp);
603+
}
604+
temp.Deallocate();
613605
}
614606

615607
void RTDEF(AssignExplicitLengthCharacter)(Descriptor &to,

flang-rt/lib/runtime/tools.cpp

Lines changed: 93 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -114,111 +114,146 @@ RT_API_ATTRS void CheckIntegerKind(
114114
}
115115
}
116116

117-
template <bool RANK1>
117+
template <typename P, int RANK>
118118
RT_API_ATTRS void ShallowCopyDiscontiguousToDiscontiguous(
119119
const Descriptor &to, const Descriptor &from) {
120-
DescriptorIterator<RANK1> toIt{to};
121-
DescriptorIterator<RANK1> fromIt{from};
120+
DescriptorIterator<RANK> toIt{to};
121+
DescriptorIterator<RANK> fromIt{from};
122+
// Knowing the size at compile time can enable memcpy inlining optimisations
123+
constexpr std::size_t typeElementBytes{sizeof(P)};
124+
// We might still need to check the actual size as a fallback
122125
std::size_t elementBytes{to.ElementBytes()};
123126
for (std::size_t n{to.Elements()}; n-- > 0;
124127
toIt.Advance(), fromIt.Advance()) {
125-
// Checking the size at runtime and making sure the pointer passed to memcpy
126-
// has a type that matches the element size makes it possible for the
127-
// compiler to optimise out the memcpy calls altogether and can
128-
// substantially improve performance for some applications.
129-
if (elementBytes == 16) {
130-
std::memcpy(toIt.template Get<__int128_t>(),
131-
fromIt.template Get<__int128_t>(), elementBytes);
132-
} else if (elementBytes == 8) {
133-
std::memcpy(toIt.template Get<int64_t>(), fromIt.template Get<int64_t>(),
134-
elementBytes);
135-
} else if (elementBytes == 4) {
136-
std::memcpy(toIt.template Get<int32_t>(), fromIt.template Get<int32_t>(),
137-
elementBytes);
138-
} else if (elementBytes == 2) {
139-
std::memcpy(toIt.template Get<int16_t>(), fromIt.template Get<int16_t>(),
140-
elementBytes);
128+
// typeElementBytes == 1 when P is a char - the non-specialised case
129+
if constexpr (typeElementBytes != 1) {
130+
std::memcpy(
131+
toIt.template Get<P>(), fromIt.template Get<P>(), typeElementBytes);
141132
} else {
142133
std::memcpy(
143-
toIt.template Get<char>(), fromIt.template Get<char>(), elementBytes);
134+
toIt.template Get<P>(), fromIt.template Get<P>(), elementBytes);
144135
}
145136
}
146137
}
147138

148-
template <bool RANK1>
139+
template <typename P, int RANK>
149140
RT_API_ATTRS void ShallowCopyDiscontiguousToContiguous(
150141
const Descriptor &to, const Descriptor &from) {
151142
char *toAt{to.OffsetElement()};
143+
constexpr std::size_t typeElementBytes{sizeof(P)};
152144
std::size_t elementBytes{to.ElementBytes()};
153-
DescriptorIterator<RANK1> fromIt{from};
145+
DescriptorIterator<RANK> fromIt{from};
154146
for (std::size_t n{to.Elements()}; n-- > 0;
155147
toAt += elementBytes, fromIt.Advance()) {
156-
if (elementBytes == 16) {
157-
std::memcpy(toAt, fromIt.template Get<__int128_t>(), elementBytes);
158-
} else if (elementBytes == 8) {
159-
std::memcpy(toAt, fromIt.template Get<int64_t>(), elementBytes);
160-
} else if (elementBytes == 4) {
161-
std::memcpy(toAt, fromIt.template Get<int32_t>(), elementBytes);
162-
} else if (elementBytes == 2) {
163-
std::memcpy(toAt, fromIt.template Get<int16_t>(), elementBytes);
148+
if constexpr (typeElementBytes != 1) {
149+
std::memcpy(toAt, fromIt.template Get<P>(), typeElementBytes);
164150
} else {
165-
std::memcpy(toAt, fromIt.template Get<char>(), elementBytes);
151+
std::memcpy(toAt, fromIt.template Get<P>(), elementBytes);
166152
}
167153
}
168154
}
169155

170-
template <bool RANK1>
156+
template <typename P, int RANK>
171157
RT_API_ATTRS void ShallowCopyContiguousToDiscontiguous(
172158
const Descriptor &to, const Descriptor &from) {
173159
char *fromAt{from.OffsetElement()};
174-
DescriptorIterator<RANK1> toIt{to};
160+
DescriptorIterator<RANK> toIt{to};
161+
constexpr std::size_t typeElementBytes{sizeof(P)};
175162
std::size_t elementBytes{to.ElementBytes()};
176163
for (std::size_t n{to.Elements()}; n-- > 0;
177164
toIt.Advance(), fromAt += elementBytes) {
178-
if (elementBytes == 16) {
179-
std::memcpy(toIt.template Get<__int128_t>(), fromAt, elementBytes);
180-
} else if (elementBytes == 8) {
181-
std::memcpy(toIt.template Get<int64_t>(), fromAt, elementBytes);
182-
} else if (elementBytes == 4) {
183-
std::memcpy(toIt.template Get<int32_t>(), fromAt, elementBytes);
184-
} else if (elementBytes == 2) {
185-
std::memcpy(toIt.template Get<int16_t>(), fromAt, elementBytes);
165+
if constexpr (typeElementBytes != 1) {
166+
std::memcpy(toIt.template Get<P>(), fromAt, typeElementBytes);
186167
} else {
187-
std::memcpy(toIt.template Get<char>(), fromAt, elementBytes);
168+
std::memcpy(toIt.template Get<P>(), fromAt, elementBytes);
188169
}
189170
}
190171
}
191172

192-
RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from,
173+
// ShallowCopy helper for calling the correct specialised variant based on
174+
// scenario
175+
template <typename P, int RANK = -1>
176+
RT_API_ATTRS void ShallowCopyInner(const Descriptor &to, const Descriptor &from,
193177
bool toIsContiguous, bool fromIsContiguous) {
194178
if (toIsContiguous) {
195179
if (fromIsContiguous) {
196180
std::memcpy(to.OffsetElement(), from.OffsetElement(),
197181
to.Elements() * to.ElementBytes());
198182
} else {
199-
if (to.rank() == 1 && from.rank() == 1) {
200-
ShallowCopyDiscontiguousToContiguous<true>(to, from);
201-
} else {
202-
ShallowCopyDiscontiguousToContiguous<false>(to, from);
203-
}
183+
ShallowCopyDiscontiguousToContiguous<P, RANK>(to, from);
204184
}
205185
} else {
206186
if (fromIsContiguous) {
207-
if (to.rank() == 1 && from.rank() == 1) {
208-
ShallowCopyContiguousToDiscontiguous<true>(to, from);
209-
} else {
210-
ShallowCopyContiguousToDiscontiguous<false>(to, from);
211-
}
187+
ShallowCopyContiguousToDiscontiguous<P, RANK>(to, from);
212188
} else {
213-
if (to.rank() == 1 && from.rank() == 1) {
214-
ShallowCopyDiscontiguousToDiscontiguous<true>(to, from);
215-
} else {
216-
ShallowCopyDiscontiguousToDiscontiguous<false>(to, from);
217-
}
189+
ShallowCopyDiscontiguousToDiscontiguous<P, RANK>(to, from);
218190
}
219191
}
220192
}
221193

194+
// ShallowCopy helper for specialising the variants based on array rank
195+
template <typename P>
196+
RT_API_ATTRS void ShallowCopyRank(const Descriptor &to, const Descriptor &from,
197+
bool toIsContiguous, bool fromIsContiguous) {
198+
if (to.rank() == 1 && from.rank() == 1) {
199+
ShallowCopyInner<P, 1>(to, from, toIsContiguous, fromIsContiguous);
200+
} else if (to.rank() == 2 && from.rank() == 2) {
201+
ShallowCopyInner<P, 2>(to, from, toIsContiguous, fromIsContiguous);
202+
} else if (to.rank() == 3 && from.rank() == 3) {
203+
ShallowCopyInner<P, 3>(to, from, toIsContiguous, fromIsContiguous);
204+
} else if (to.rank() == 4 && from.rank() == 4) {
205+
ShallowCopyInner<P, 4>(to, from, toIsContiguous, fromIsContiguous);
206+
} else if (to.rank() == 5 && from.rank() == 5) {
207+
ShallowCopyInner<P, 5>(to, from, toIsContiguous, fromIsContiguous);
208+
} else if (to.rank() == 6 && from.rank() == 6) {
209+
ShallowCopyInner<P, 6>(to, from, toIsContiguous, fromIsContiguous);
210+
} else if (to.rank() == 7 && from.rank() == 7) {
211+
ShallowCopyInner<P, 7>(to, from, toIsContiguous, fromIsContiguous);
212+
} else if (to.rank() == 8 && from.rank() == 8) {
213+
ShallowCopyInner<P, 8>(to, from, toIsContiguous, fromIsContiguous);
214+
} else if (to.rank() == 9 && from.rank() == 9) {
215+
ShallowCopyInner<P, 9>(to, from, toIsContiguous, fromIsContiguous);
216+
} else if (to.rank() == 10 && from.rank() == 10) {
217+
ShallowCopyInner<P, 10>(to, from, toIsContiguous, fromIsContiguous);
218+
} else {
219+
ShallowCopyInner<P>(to, from, toIsContiguous, fromIsContiguous);
220+
}
221+
}
222+
223+
RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from,
224+
bool toIsContiguous, bool fromIsContiguous) {
225+
std::size_t elementBytes{to.ElementBytes()};
226+
// Checking the type at runtime and making sure the pointer passed to memcpy
227+
// has a type that matches the element type makes it possible for the compiler
228+
// to optimise out the memcpy calls altogether and can substantially improve
229+
// performance for some applications.
230+
if (to.type().IsInteger()) {
231+
if (elementBytes == sizeof(int64_t)) {
232+
ShallowCopyRank<int64_t>(to, from, toIsContiguous, fromIsContiguous);
233+
} else if (elementBytes == sizeof(int32_t)) {
234+
ShallowCopyRank<int32_t>(to, from, toIsContiguous, fromIsContiguous);
235+
} else if (elementBytes == sizeof(int16_t)) {
236+
ShallowCopyRank<int16_t>(to, from, toIsContiguous, fromIsContiguous);
237+
#if defined USING_NATIVE_INT128_T
238+
} else if (elementBytes == sizeof(__int128_t)) {
239+
ShallowCopyRank<__int128_t>(to, from, toIsContiguous, fromIsContiguous);
240+
#endif
241+
} else {
242+
ShallowCopyRank<char>(to, from, toIsContiguous, fromIsContiguous);
243+
}
244+
} else if (to.type().IsReal()) {
245+
if (elementBytes == sizeof(double)) {
246+
ShallowCopyRank<double>(to, from, toIsContiguous, fromIsContiguous);
247+
} else if (elementBytes == sizeof(float)) {
248+
ShallowCopyRank<float>(to, from, toIsContiguous, fromIsContiguous);
249+
} else {
250+
ShallowCopyRank<char>(to, from, toIsContiguous, fromIsContiguous);
251+
}
252+
} else {
253+
ShallowCopyRank<char>(to, from, toIsContiguous, fromIsContiguous);
254+
}
255+
}
256+
222257
RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from) {
223258
ShallowCopy(to, from, to.IsContiguous(), from.IsContiguous());
224259
}

0 commit comments

Comments
 (0)