55#include " IRMutator.h"
66#include " IROperator.h"
77#include " IRPrinter.h"
8+ #include " Util.h"
89
910namespace Halide {
1011namespace Internal {
11- namespace ApproxImpl {
1212
13+ namespace {
1314constexpr double PI = 3.14159265358979323846 ;
1415constexpr double ONE_OVER_PI = 1.0 / PI;
1516constexpr double TWO_OVER_PI = 2.0 / PI;
1617constexpr double PI_OVER_TWO = PI / 2 ;
1718
19+ float ulp_to_ae (float max, int ulp) {
20+ internal_assert (max > 0.0 );
21+ uint32_t n = reinterpret_bits<uint32_t >(max);
22+ float fn = reinterpret_bits<float >(n + ulp);
23+ return fn - max;
24+ }
25+
26+ uint32_t ae_to_ulp (float smallest, float ae) {
27+ internal_assert (smallest >= 0.0 );
28+ float fn = smallest + ae;
29+ return reinterpret_bits<uint32_t >(fn) - reinterpret_bits<uint32_t >(smallest);
30+ }
31+ } // namespace
32+
33+ namespace ApproxImpl {
34+
1835std::pair<float , float > split_float (double value) {
1936 float high = float (value); // Convert to single precision
2037 float low = float (value - double (high)); // Compute the residual part
@@ -152,7 +169,7 @@ Expr fast_sin(const Expr &x_full, ApproximationPrecision precision) {
152169 Expr k = cast<int >(k_real);
153170 Expr k_mod4 = k % 4 ; // Halide mod is always positive!
154171 Expr mirror = (k_mod4 == 1 ) || (k_mod4 == 3 );
155- Expr flip_sign = (k_mod4 > 1 ) ^ (x_full < 0 );
172+ Expr flip_sign = (k_mod4 > 1 ) != (x_full < 0 );
156173
157174 // Reduce the angle modulo pi/2: i.e., to the angle within the quadrant.
158175 Expr x = x_abs - k_real * make_const (type, PI_OVER_TWO);
@@ -417,7 +434,7 @@ Expr fast_tanh(const Expr &x, ApproximationPrecision prec) {
417434 Expr arg_exp = select(flip_exp, -abs_x, abs_x);
418435 Expr exp2xm1 = Halide::fast_expm1(2 * arg_exp, prec);
419436 Expr tanh = (exp2xm1) / (exp2xm1 + make_const(type, 2));
420- tanh = select(flip_exp ^ flip_sign, -tanh, tanh);
437+ tanh = select(flip_exp != flip_sign, -tanh, tanh);
421438 return common_subexpression_elimination(tanh, true);
422439#else
423440 // expm1 is developed around 0 and is ULP accurate in [-ln(2)/2, ln(2)/2].
@@ -465,6 +482,19 @@ struct IntrinsicsInfo {
465482 } intrinsic;
466483};
467484
485+ IntrinsicsInfo::NativeFunc MAE_func (bool fast, float mae, float smallest_output = 0 .0f ) {
486+ return IntrinsicsInfo::NativeFunc{fast, OO::MAE, mae, ae_to_ulp (smallest_output, mae)};
487+ }
488+ IntrinsicsInfo::NativeFunc MULPE_func (bool fast, uint64_t mulpe, float largest_output) {
489+ return IntrinsicsInfo::NativeFunc{fast, OO::MULPE, ulp_to_ae (largest_output, mulpe), mulpe};
490+ }
491+ IntrinsicsInfo::IntrinsicImpl MAE_intrinsic (float mae, float smallest_output = 0 .0f ) {
492+ return IntrinsicsInfo::IntrinsicImpl{OO::MAE, mae, ae_to_ulp (smallest_output, mae)};
493+ }
494+ IntrinsicsInfo::IntrinsicImpl MULPE_intrinsic (uint64_t mulpe, float largest_output) {
495+ return IntrinsicsInfo::IntrinsicImpl{OO::MULPE, ulp_to_ae (largest_output, mulpe), mulpe};
496+ }
497+
468498struct IntrinsicsInfoPerDeviceAPI {
469499 OO reasonable_behavior; // A reasonable optimization objective for a given function.
470500 float default_mae; // A reasonable desirable MAE (if specified)
@@ -475,37 +505,45 @@ struct IntrinsicsInfoPerDeviceAPI {
475505// clang-format off
476506IntrinsicsInfoPerDeviceAPI ii_sin{
477507 OO::MAE, 1e-5f , 0 , {
478- {DeviceAPI::Vulkan, { true } , {}},
479- {DeviceAPI::CUDA, {false }, {OO::MAE, 5e-7f , 1'000'000 } },
480- {DeviceAPI::Metal, {true }, {OO::MAE, 6e- 5f , 400'000 }},
508+ {DeviceAPI::Vulkan, MAE_func ( true , 5e- 4f ) , {}},
509+ {DeviceAPI::CUDA, {false }, MAE_intrinsic ( 5e-7f ) },
510+ {DeviceAPI::Metal, {true }, MAE_intrinsic ( 1 .2e- 4f )}, // 2^-13
481511 {DeviceAPI::WebGPU, {true }, {}},
482- {DeviceAPI::OpenCL, {false }, {OO::MAE, 5e-7f , 1'000'000 } },
512+ {DeviceAPI::OpenCL, {false }, MAE_intrinsic ( 5e-7f ) },
483513}};
484514
485515IntrinsicsInfoPerDeviceAPI ii_cos{
486516 OO::MAE, 1e-5f , 0 , {
487- {DeviceAPI::Vulkan, { true } , {}},
488- {DeviceAPI::CUDA, {false }, {OO::MAE, 5e-7f , 1'000'000 } },
489- {DeviceAPI::Metal, {true }, {OO::MAE, 7e-7f , 5'000 }},
517+ {DeviceAPI::Vulkan, MAE_func ( true , 5e- 4f ) , {}},
518+ {DeviceAPI::CUDA, {false }, MAE_intrinsic ( 5e-7f ) },
519+ {DeviceAPI::Metal, {true }, MAE_intrinsic ( 1 .2e- 4f )}, // Seems to be 7e-7, but spec says 2^-13...
490520 {DeviceAPI::WebGPU, {true }, {}},
491- {DeviceAPI::OpenCL, {false }, {OO::MAE, 5e-7f , 1'000'000 } },
521+ {DeviceAPI::OpenCL, {false }, MAE_intrinsic ( 5e-7f ) },
492522}};
493523
494- IntrinsicsInfoPerDeviceAPI ii_atan_atan2 {
524+ IntrinsicsInfoPerDeviceAPI ii_atan {
495525 OO::MAE, 1e-5f , 0 , {
496526 // no intrinsics available
497527 {DeviceAPI::Vulkan, {false }, {}},
498- {DeviceAPI::Metal, {true }, {OO::MAE, 5e-6f }},
528+ {DeviceAPI::Metal, {true }, MULPE_intrinsic (5 , float (PI * 0.501 ))}, // They claim <= 5 ULP!
529+ {DeviceAPI::WebGPU, {true }, {}},
530+ }};
531+
532+ IntrinsicsInfoPerDeviceAPI ii_atan2{
533+ OO::MAE, 1e-5f , 0 , {
534+ // no intrinsics available
535+ {DeviceAPI::Vulkan, {false }, {}},
536+ {DeviceAPI::Metal, {true }, MAE_intrinsic (5e-6f , 0 .0f )},
499537 {DeviceAPI::WebGPU, {true }, {}},
500538}};
501539
502540IntrinsicsInfoPerDeviceAPI ii_tan{
503541 OO::MULPE, 0 .0f , 2000 , {
504- {DeviceAPI::Vulkan, { true , OO::MAE, 2e-6f , 1'000'000 }, {}}, // Vulkan tan seems to mimic our CUDA implementation
505- {DeviceAPI::CUDA, {false }, {OO::MAE, 2e-6f , 1'000'000 } },
506- {DeviceAPI::Metal, {true }, {OO::MULPE, 2e-6f , 1'000'000 }},
542+ {DeviceAPI::Vulkan, MAE_func ( true , 2e-6f ), {}}, // Vulkan tan() seems to mimic our CUDA implementation
543+ {DeviceAPI::CUDA, {false }, MAE_intrinsic ( 2e-6f ) },
544+ {DeviceAPI::Metal, {true }, MAE_intrinsic ( 2e-6f )}, // sin()/cos()
507545 {DeviceAPI::WebGPU, {true }, {}},
508- {DeviceAPI::OpenCL, {false }, {OO::MAE, 2e-6f , 1'000'000 } },
546+ {DeviceAPI::OpenCL, {false }, MAE_intrinsic ( 2e-6f ) },
509547}};
510548
511549IntrinsicsInfoPerDeviceAPI ii_expm1{
@@ -514,16 +552,16 @@ IntrinsicsInfoPerDeviceAPI ii_expm1{
514552
515553IntrinsicsInfoPerDeviceAPI ii_exp{
516554 OO::MULPE, 0 .0f , 50 , {
517- {DeviceAPI::Vulkan, { true } , {}},
518- {DeviceAPI::CUDA, {false }, {OO::MULPE, 0 .0f , 5 } },
519- {DeviceAPI::Metal, {true }, {OO::MULPE, 0 .0f , 5 } }, // precise::exp() is fast on metal
555+ {DeviceAPI::Vulkan, MULPE_func ( true , 3 + 2 * 2 , 2 . 0f ) , {}},
556+ {DeviceAPI::CUDA, {false }, MULPE_intrinsic ( 5 , 2 .0f ) },
557+ {DeviceAPI::Metal, {true }, MULPE_intrinsic ( 5 , 2 .0f ) }, // precise::exp() is fast on metal
520558 {DeviceAPI::WebGPU, {true }, {}},
521- {DeviceAPI::OpenCL, {true }, {OO::MULPE, 0 .0f , 5 } }, // Both exp() and native_exp() are faster than polys.
559+ {DeviceAPI::OpenCL, {true }, MULPE_intrinsic ( 5 , 2 .0f ) }, // Both exp() and native_exp() are faster than polys.
522560}};
523561
524562IntrinsicsInfoPerDeviceAPI ii_log{
525563 OO::MAE, 1e-5f , 1000 , {
526- {DeviceAPI::Vulkan, {true }, {}},
564+ {DeviceAPI::Vulkan, {true , ApproximationPrecision::MULPE, 5e- 7f , 3 }, {}}, // Precision piecewise defined: 3 ULP outside the range [0.5,2.0]. Absolute error < 2^−21 inside the range [0.5,2.0].
527565 {DeviceAPI::CUDA, {false }, {OO::MAE, 0 .0f , 3'800'000 }},
528566 {DeviceAPI::Metal, {false }, {OO::MAE, 0 .0f , 3'800'000 }}, // slow log() on metal
529567 {DeviceAPI::WebGPU, {true }, {}},
@@ -551,6 +589,7 @@ IntrinsicsInfoPerDeviceAPI ii_asin_acos{
551589 OO::MULPE, 1e-5f , 500 , {
552590 {DeviceAPI::Vulkan, {true }, {}},
553591 {DeviceAPI::CUDA, {true }, {}},
592+ {DeviceAPI::Metal, {true }, MULPE_intrinsic (5 , PI)},
554593 {DeviceAPI::OpenCL, {true }, {}},
555594}};
556595// clang-format on
@@ -559,8 +598,10 @@ bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, Dev
559598 const IntrinsicsInfoPerDeviceAPI *iipda = nullptr ;
560599 switch (op) {
561600 case Call::fast_atan:
601+ iipda = &ii_atan;
602+ break ;
562603 case Call::fast_atan2:
563- iipda = &ii_atan_atan2 ;
604+ iipda = &ii_atan2 ;
564605 break ;
565606 case Call::fast_cos:
566607 iipda = &ii_cos;
@@ -858,20 +899,24 @@ class LowerFastMathFunctions : public IRMutator {
858899
859900 // No known fast version available, we will expand our own approximation.
860901 return ApproxImpl::fast_cos (mutate (op->args [0 ]), prec);
861- } else if (op->is_intrinsic (Call::fast_atan) || op-> is_intrinsic (Call::fast_atan2) ) {
902+ } else if (op->is_intrinsic (Call::fast_atan)) {
862903 // Handle fast_atan on its own: it is resolved against the ii_atan table.
863904 ApproximationPrecision prec = extract_approximation_precision (op);
864- IntrinsicsInfo ii = resolve_precision (prec, ii_atan_atan2 , for_device_api);
905+ IntrinsicsInfo ii = resolve_precision (prec, ii_atan , for_device_api);
865906 if (ii.native_func .is_fast && native_func_satisfies_precision (ii, prec)) {
866907 // The native atan is fast: fall back to native and continue lowering.
867908 return to_native_func (op);
868909 }
869-
870- if (op->is_intrinsic (Call::fast_atan)) {
871- return ApproxImpl::fast_atan (mutate (op->args [0 ]), prec);
872- } else {
873- return ApproxImpl::fast_atan2 (mutate (op->args [0 ]), mutate (op->args [1 ]), prec);
910+ return ApproxImpl::fast_atan (mutate (op->args [0 ]), prec);
911+ } else if (op->is_intrinsic (Call::fast_atan2)) {
912+ // Handle fast_atan2 on its own: it is resolved against the ii_atan2 table.
913+ ApproximationPrecision prec = extract_approximation_precision (op);
914+ IntrinsicsInfo ii = resolve_precision (prec, ii_atan2, for_device_api);
915+ if (ii.native_func .is_fast && native_func_satisfies_precision (ii, prec)) {
916+ // The native atan2 is fast: fall back to native and continue lowering.
917+ return to_native_func (op);
874918 }
919+ return ApproxImpl::fast_atan2 (mutate (op->args [0 ]), mutate (op->args [1 ]), prec);
875920 } else if (op->is_intrinsic (Call::fast_tan)) {
876921 ApproximationPrecision prec = extract_approximation_precision (op);
877922 IntrinsicsInfo ii = resolve_precision (prec, ii_tan, for_device_api);
@@ -913,7 +958,7 @@ class LowerFastMathFunctions : public IRMutator {
913958 return append_type_suffix (op);
914959 }
915960 if (ii.native_func .is_fast && native_func_satisfies_precision (ii, prec)) {
916- // The native atan is fast: fall back to native and continue lowering.
961+ // The native exp is fast: fall back to native and continue lowering.
917962 return to_native_func (op);
918963 }
919964
0 commit comments