Skip to content

Commit 8ce9aee

Browse files
committed
implement suggestion from the review
1 parent 82970e1 commit 8ce9aee

File tree

1 file changed

+123
-196
lines changed

1 file changed

+123
-196
lines changed

flang/lib/Evaluate/intrinsics-library.cpp

Lines changed: 123 additions & 196 deletions
Original file line numberDiff line numberDiff line change
@@ -260,25 +260,65 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
260260
static_assert(map.Verify(), "map must be sorted");
261261
};
262262

263-
enum trigFunc {
264-
Cacos,
265-
Cacosh,
266-
Casin,
267-
Casinh,
268-
Catan,
269-
Catanh,
270-
Ccos,
271-
Ccosh,
272-
Cexp,
273-
Clog,
274-
Csin,
275-
Csinh,
276-
Csqrt,
277-
Ctan,
278-
Ctanh
279-
};
263+
#define COMPLEX_SIGNATURES(HOST_T) \
264+
using F = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &>; \
265+
using F2 = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
266+
const std::complex<HOST_T> &>; \
267+
using F2A = FuncPointer<std::complex<HOST_T>, const HOST_T &, \
268+
const std::complex<HOST_T> &>; \
269+
using F2B = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
270+
const HOST_T &>;
271+
272+
#ifndef _AIX
273+
// Helpers to map complex std::pow whose resolution in F2{std::pow} is
274+
// ambiguous as of clang++ 20.
275+
template <typename HostT>
276+
static std::complex<HostT> StdPowF2(
277+
const std::complex<HostT> &x, const std::complex<HostT> &y) {
278+
return std::pow(x, y);
279+
}
280280

281-
#ifdef _AIX
281+
template <typename HostT>
282+
static std::complex<HostT> StdPowF2A(
283+
const HostT &x, const std::complex<HostT> &y) {
284+
return std::pow(x, y);
285+
}
286+
287+
template <typename HostT>
288+
static std::complex<HostT> StdPowF2B(
289+
const std::complex<HostT> &x, const HostT &y) {
290+
return std::pow(x, y);
291+
}
292+
293+
template <typename HostT>
294+
struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
295+
COMPLEX_SIGNATURES(HostT)
296+
static constexpr HostRuntimeFunction table[]{
297+
FolderFactory<F, F{std::acos}>::Create("acos"),
298+
FolderFactory<F, F{std::acosh}>::Create("acosh"),
299+
FolderFactory<F, F{std::asin}>::Create("asin"),
300+
FolderFactory<F, F{std::asinh}>::Create("asinh"),
301+
FolderFactory<F, F{std::atan}>::Create("atan"),
302+
FolderFactory<F, F{std::atanh}>::Create("atanh"),
303+
FolderFactory<F, F{std::cos}>::Create("cos"),
304+
FolderFactory<F, F{std::cosh}>::Create("cosh"),
305+
FolderFactory<F, F{std::exp}>::Create("exp"),
306+
FolderFactory<F, F{std::log}>::Create("log"),
307+
FolderFactory<F2, F2{StdPowF2}>::Create("pow"),
308+
FolderFactory<F2A, F2A{StdPowF2A}>::Create("pow"),
309+
FolderFactory<F2B, F2B{StdPowF2B}>::Create("pow"),
310+
FolderFactory<F, F{std::sin}>::Create("sin"),
311+
FolderFactory<F, F{std::sinh}>::Create("sinh"),
312+
FolderFactory<F, F{std::sqrt}>::Create("sqrt"),
313+
FolderFactory<F, F{std::tan}>::Create("tan"),
314+
FolderFactory<F, F{std::tanh}>::Create("tanh"),
315+
};
316+
static constexpr HostRuntimeMap map{table};
317+
static_assert(map.Verify(), "map must be sorted");
318+
};
319+
#else
320+
// On AIX, call libm routines to preserve consistent value between
321+
// runtime and compile time evaluation.
282322
#ifdef __clang_major__
283323
#pragma clang diagnostic ignored "-Wc99-extensions"
284324
#endif
@@ -318,193 +358,80 @@ float _Complex ctanhf(float _Complex);
318358
double _Complex ctanh(double _Complex);
319359
}
320360

321-
enum CRI { Real, Imag };
322-
template <typename TR, typename TA> static TR &reIm(TA &x, CRI n) {
323-
return reinterpret_cast<TR(&)[2]>(x)[n];
324-
}
325-
template <typename TR, typename T> static TR CppToC(const std::complex<T> &x) {
326-
TR r;
327-
reIm<T, TR>(r, CRI::Real) = x.real();
328-
reIm<T, TR>(r, CRI::Imag) = x.imag();
329-
return r;
330-
}
331-
template <typename T, typename TA> static std::complex<T> CToCpp(const TA &x) {
332-
TA &z{const_cast<TA &>(x)};
333-
return std::complex<T>(reIm<T, TA>(z, CRI::Real), reIm<T, TA>(z, CRI::Imag));
334-
}
335-
336-
using FTypeCmplxFlt = _Complex float (*)(_Complex float);
337-
using FTypeCmplxDble = _Complex double (*)(_Complex double);
338-
template <typename T>
339-
using FTypeStdCmplx = std::complex<T> (*)(const std::complex<T> &);
340-
341-
std::map<trigFunc, std::tuple<FTypeCmplxFlt, FTypeCmplxDble>> mapLibmTrigFunc{
342-
{Cacos, {&cacosf, &cacos}}, {Cacosh, {&cacoshf, &cacosh}},
343-
{Casin, {&casinf, &casin}}, {Casinh, {&casinhf, &casinh}},
344-
{Catan, {&catanf, &catan}}, {Catanh, {&catanhf, &catanh}},
345-
{Ccos, {&ccosf, &ccos}}, {Ccosh, {&ccoshf, &ccosh}},
346-
{Cexp, {&cexpf, &cexp}}, {Clog, {&clogf, &__clog}}, {Csin, {&csinf, &csin}},
347-
{Csinh, {&csinhf, &csinh}}, {Csqrt, {&csqrtf, &csqrt}},
348-
{Ctan, {&ctanf, &ctan}}, {Ctanh, {&ctanhf, &ctanh}}};
349-
350-
template <trigFunc TF, typename HostT>
351-
std::complex<HostT> LibmTrigFunc(const std::complex<HostT> &x) {
352-
if constexpr (std::is_same_v<HostT, float>) {
353-
float _Complex r{std::get<FTypeCmplxFlt>(mapLibmTrigFunc[TF])(
354-
CppToC<float _Complex, float>(x))};
355-
return CToCpp<float, float _Complex>(r);
356-
} else if constexpr (std::is_same_v<HostT, double>) {
357-
double _Complex r{std::get<FTypeCmplxDble>(mapLibmTrigFunc[TF])(
358-
CppToC<double _Complex, double>(x))};
359-
return CToCpp<double, double _Complex>(r);
360-
}
361-
DIE("bad complex component type");
362-
}
363-
#endif
364-
365-
template <trigFunc TF, typename HostT>
366-
std::complex<HostT> StdTrigFunc(const std::complex<HostT> &x) {
367-
if constexpr (TF == Cacos) {
368-
return std::acos(x);
369-
} else if constexpr (TF == Cacosh) {
370-
return std::acosh(x);
371-
} else if constexpr (TF == Casin) {
372-
return std::asin(x);
373-
} else if constexpr (TF == Casinh) {
374-
return std::asinh(x);
375-
} else if constexpr (TF == Catan) {
376-
return std::atan(x);
377-
} else if constexpr (TF == Catanh) {
378-
return std::atanh(x);
379-
} else if constexpr (TF == Ccos) {
380-
return std::cos(x);
381-
} else if constexpr (TF == Ccosh) {
382-
return std::cosh(x);
383-
} else if constexpr (TF == Cexp) {
384-
return std::exp(x);
385-
} else if constexpr (TF == Clog) {
386-
return std::log(x);
387-
} else if constexpr (TF == Csin) {
388-
return std::sin(x);
389-
} else if constexpr (TF == Csinh) {
390-
return std::sinh(x);
391-
} else if constexpr (TF == Csqrt) {
392-
return std::sqrt(x);
393-
} else if constexpr (TF == Ctan) {
394-
return std::tan(x);
395-
} else if constexpr (TF == Ctanh) {
396-
return std::tanh(x);
397-
}
398-
DIE("unknown function");
399-
}
400-
401-
template <trigFunc TF> struct X {
402-
template <typename HostT>
403-
static std::complex<HostT> f(const std::complex<HostT> &x) {
404-
std::complex<HostT> res;
405-
#ifdef _AIX
406-
// On AIX, the implementation in libm is different from that of the STL
407-
// routines, use the libm routines here in folding for consistent results.
408-
res = LibmTrigFunc<TF>(x);
409-
#else
410-
res = StdTrigFunc<TF, HostT>(x);
411-
#endif
412-
return res;
413-
}
361+
template <typename T> struct ToStdComplex {
362+
using Type = T;
363+
using AType = Type;
364+
};
365+
template <> struct ToStdComplex<float _Complex> {
366+
using Type = std::complex<float>;
367+
using AType = const Type &;
368+
};
369+
template <> struct ToStdComplex<double _Complex> {
370+
using Type = std::complex<double>;
371+
using AType = const Type &;
414372
};
415373

416-
// Helpers to map complex std::pow whose resolution in F2{std::pow} is
417-
// ambiguous as of clang++ 20.
418-
template <typename HostT>
419-
static std::complex<HostT> StdPowF2(
420-
const std::complex<HostT> &x, const std::complex<HostT> &y) {
421-
#ifdef _AIX
422-
if constexpr (std::is_same_v<HostT, float>) {
423-
float _Complex r{cpowf(
424-
CppToC<float _Complex, float>(x), CppToC<float _Complex, float>(y))};
425-
return CToCpp<float, float _Complex>(r);
426-
} else if constexpr (std::is_same_v<HostT, double>) {
427-
double _Complex r{cpow(CppToC<double _Complex, double>(x),
428-
CppToC<double _Complex, double>(y))};
429-
return CToCpp<double, double _Complex>(r);
430-
}
431-
#else
432-
return std::pow(x, y);
433-
#endif
434-
}
435-
436-
template <typename HostT>
437-
static std::complex<HostT> StdPowF2A(
438-
const HostT &x, const std::complex<HostT> &y) {
439-
#ifdef _AIX
440-
constexpr HostT zero{0.0};
441-
std::complex<HostT> z(x, zero);
442-
if constexpr (std::is_same_v<HostT, float>) {
443-
float _Complex r{cpowf(
444-
CppToC<float _Complex, float>(z), CppToC<float _Complex, float>(y))};
445-
return CToCpp<float, float _Complex>(r);
446-
} else if constexpr (std::is_same_v<HostT, double>) {
447-
double _Complex r{cpow(CppToC<double _Complex, double>(z),
448-
CppToC<double _Complex, double>(y))};
449-
return CToCpp<double, double _Complex>(r);
450-
}
451-
#else
452-
return std::pow(x, y);
453-
#endif
454-
}
455-
456-
template <typename HostT>
457-
static std::complex<HostT> StdPowF2B(
458-
const std::complex<HostT> &x, const HostT &y) {
459-
#ifdef _AIX
460-
constexpr HostT zero{0.0};
461-
std::complex<HostT> z(y, zero);
462-
if constexpr (std::is_same_v<HostT, float>) {
463-
float _Complex r{cpowf(
464-
CppToC<float _Complex, float>(x), CppToC<float _Complex, float>(z))};
465-
return CToCpp<float, float _Complex>(r);
466-
} else if constexpr (std::is_same_v<HostT, double>) {
467-
double _Complex r{cpow(CppToC<double _Complex, double>(x),
468-
CppToC<double _Complex, double>(z))};
469-
return CToCpp<double, double _Complex>(r);
374+
template <typename F, F func> struct CComplexFunc {};
375+
template <typename R, typename... A, FuncPointer<R, A...> func>
376+
struct CComplexFunc<FuncPointer<R, A...>, func> {
377+
static typename ToStdComplex<R>::Type wrapper(
378+
typename ToStdComplex<A>::AType... args) {
379+
R res{func(*reinterpret_cast<const A *>(&args)...)};
380+
return *reinterpret_cast<typename ToStdComplex<R>::Type *>(&res);
470381
}
471-
#else
472-
return std::pow(x, y);
473-
#endif
474-
}
382+
};
383+
#define C_COMPLEX_FUNC(func) CComplexFunc<decltype(&func), &func>::wrapper
475384

476-
template <typename HostT>
477-
struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
478-
using F = FuncPointer<std::complex<HostT>, const std::complex<HostT> &>;
479-
using F2 = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
480-
const std::complex<HostT> &>;
481-
using F2A = FuncPointer<std::complex<HostT>, const HostT &,
482-
const std::complex<HostT> &>;
483-
using F2B = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
484-
const HostT &>;
385+
template <>
386+
struct HostRuntimeLibrary<std::complex<float>, LibraryVersion::Libm> {
387+
COMPLEX_SIGNATURES(float)
485388
static constexpr HostRuntimeFunction table[]{
486-
FolderFactory<F, F{X<Cacos>::f}>::Create("acos"),
487-
FolderFactory<F, F{X<Cacosh>::f}>::Create("acosh"),
488-
FolderFactory<F, F{X<Casin>::f}>::Create("asin"),
489-
FolderFactory<F, F{X<Casinh>::f}>::Create("asinh"),
490-
FolderFactory<F, F{X<Catan>::f}>::Create("atan"),
491-
FolderFactory<F, F{X<Catanh>::f}>::Create("atanh"),
492-
FolderFactory<F, F{X<Ccos>::f}>::Create("cos"),
493-
FolderFactory<F, F{X<Ccosh>::f}>::Create("cosh"),
494-
FolderFactory<F, F{X<Cexp>::f}>::Create("exp"),
495-
FolderFactory<F, F{X<Clog>::f}>::Create("log"),
496-
FolderFactory<F2, F2{StdPowF2}>::Create("pow"),
497-
FolderFactory<F2A, F2A{StdPowF2A}>::Create("pow"),
498-
FolderFactory<F2B, F2B{StdPowF2B}>::Create("pow"),
499-
FolderFactory<F, F{X<Csin>::f}>::Create("sin"),
500-
FolderFactory<F, F{X<Csinh>::f}>::Create("sinh"),
501-
FolderFactory<F, F{X<Csqrt>::f}>::Create("sqrt"),
502-
FolderFactory<F, F{X<Ctan>::f}>::Create("tan"),
503-
FolderFactory<F, F{X<Ctanh>::f}>::Create("tanh"),
389+
FolderFactory<F, C_COMPLEX_FUNC(cacosf)>::Create("acos"),
390+
FolderFactory<F, C_COMPLEX_FUNC(cacoshf)>::Create("acosh"),
391+
FolderFactory<F, C_COMPLEX_FUNC(casinf)>::Create("asin"),
392+
FolderFactory<F, C_COMPLEX_FUNC(casinhf)>::Create("asinh"),
393+
FolderFactory<F, C_COMPLEX_FUNC(catanf)>::Create("atan"),
394+
FolderFactory<F, C_COMPLEX_FUNC(catanhf)>::Create("atanh"),
395+
FolderFactory<F, C_COMPLEX_FUNC(ccosf)>::Create("cos"),
396+
FolderFactory<F, C_COMPLEX_FUNC(ccoshf)>::Create("cosh"),
397+
FolderFactory<F, C_COMPLEX_FUNC(cexpf)>::Create("exp"),
398+
FolderFactory<F, C_COMPLEX_FUNC(clogf)>::Create("log"),
399+
FolderFactory<F2, C_COMPLEX_FUNC(cpowf)>::Create("pow"),
400+
FolderFactory<F, C_COMPLEX_FUNC(csinf)>::Create("sin"),
401+
FolderFactory<F, C_COMPLEX_FUNC(csinhf)>::Create("sinh"),
402+
FolderFactory<F, C_COMPLEX_FUNC(csqrtf)>::Create("sqrt"),
403+
FolderFactory<F, C_COMPLEX_FUNC(ctanf)>::Create("tan"),
404+
FolderFactory<F, C_COMPLEX_FUNC(ctanhf)>::Create("tanh"),
504405
};
505406
static constexpr HostRuntimeMap map{table};
506407
static_assert(map.Verify(), "map must be sorted");
507408
};
409+
template <>
410+
struct HostRuntimeLibrary<std::complex<double>, LibraryVersion::Libm> {
411+
COMPLEX_SIGNATURES(double)
412+
static constexpr HostRuntimeFunction table[]{
413+
FolderFactory<F, C_COMPLEX_FUNC(cacos)>::Create("acos"),
414+
FolderFactory<F, C_COMPLEX_FUNC(cacosh)>::Create("acosh"),
415+
FolderFactory<F, C_COMPLEX_FUNC(casin)>::Create("asin"),
416+
FolderFactory<F, C_COMPLEX_FUNC(casinh)>::Create("asinh"),
417+
FolderFactory<F, C_COMPLEX_FUNC(catan)>::Create("atan"),
418+
FolderFactory<F, C_COMPLEX_FUNC(catanh)>::Create("atanh"),
419+
FolderFactory<F, C_COMPLEX_FUNC(ccos)>::Create("cos"),
420+
FolderFactory<F, C_COMPLEX_FUNC(ccosh)>::Create("cosh"),
421+
FolderFactory<F, C_COMPLEX_FUNC(cexp)>::Create("exp"),
422+
FolderFactory<F, C_COMPLEX_FUNC(__clog)>::Create("log"),
423+
FolderFactory<F2, C_COMPLEX_FUNC(cpow)>::Create("pow"),
424+
FolderFactory<F, C_COMPLEX_FUNC(csin)>::Create("sin"),
425+
FolderFactory<F, C_COMPLEX_FUNC(csinh)>::Create("sinh"),
426+
FolderFactory<F, C_COMPLEX_FUNC(csqrt)>::Create("sqrt"),
427+
FolderFactory<F, C_COMPLEX_FUNC(ctan)>::Create("tan"),
428+
FolderFactory<F, C_COMPLEX_FUNC(ctanh)>::Create("tanh"),
429+
};
430+
static constexpr HostRuntimeMap map{table};
431+
static_assert(map.Verify(), "map must be sorted");
432+
};
433+
#endif // _AIX
434+
508435
// Note regarding cmath:
509436
// - cmath does not have modulo and erfc_scaled equivalent
510437
// - C++17 defined standard Bessel math functions std::cyl_bessel_j

0 commit comments

Comments
 (0)