Skip to content

Commit c8d55d8

Browse files
authored
Merge pull request #1064 from stephenswat/build/cuda_fast_math
Enable fast math library in CUDA compilation
2 parents 9457b4f + a270507 commit c8d55d8

File tree

4 files changed

+66
-10
lines changed

4 files changed

+66
-10
lines changed

cmake/traccc-compiler-options-cuda.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ endif()
2525

2626
# Allow to use functions in device code that are constexpr, even if they are
2727
# not marked with __device__.
28-
traccc_add_flag( CMAKE_CUDA_FLAGS "--expt-relaxed-constexpr" )
28+
traccc_add_flag( CMAKE_CUDA_FLAGS "--expt-relaxed-constexpr --use_fast_math" )
2929

3030
# Make CUDA generate debug symbols for the device code as well in a debug
3131
# build.

core/include/traccc/definitions/math.hpp

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
#pragma once
99

10+
#include "traccc/definitions/qualifiers.hpp"
11+
1012
// SYCL include(s).
1113
#if defined(CL_SYCL_LANGUAGE_VERSION) || defined(SYCL_LANGUAGE_VERSION)
1214
#include <sycl/sycl.hpp>
@@ -15,13 +17,63 @@
1517
// System include(s).
1618
#include <cmath>
1719

18-
namespace traccc {
20+
namespace traccc::math {
1921

20-
/// Namespace to pick up math functions from
2122
#if defined(CL_SYCL_LANGUAGE_VERSION) || defined(SYCL_LANGUAGE_VERSION)
22-
namespace math = ::sycl;
23+
using ::sycl::abs;
24+
using ::sycl::acos;
25+
using ::sycl::asin;
26+
using ::sycl::atan;
27+
using ::sycl::atan2;
28+
using ::sycl::cos;
29+
using ::sycl::exp;
30+
using ::sycl::fabs;
31+
using ::sycl::floor;
32+
using ::sycl::fmod;
33+
using ::sycl::log;
34+
using ::sycl::max;
35+
using ::sycl::min;
36+
using ::sycl::pow;
37+
using ::sycl::sin;
38+
using ::sycl::sqrt;
39+
using ::sycl::tan;
2340
#else
24-
namespace math = std;
25-
#endif // SYCL
41+
using std::abs;
42+
using std::acos;
43+
using std::asin;
44+
using std::atan;
45+
using std::atan2;
46+
using std::cos;
47+
using std::exp;
48+
using std::fabs;
49+
using std::floor;
50+
using std::fmod;
51+
using std::log;
52+
using std::max;
53+
using std::min;
54+
using std::pow;
55+
using std::sin;
56+
using std::sqrt;
57+
using std::tan;
58+
#endif
2659

27-
} // namespace traccc
60+
/**
61+
* @brief Perform IEEE-754 division, even if fast math is enabled.
62+
*
63+
* @returns x divided by y
64+
*/
65+
template <typename T>
66+
TRACCC_HOST_DEVICE inline __attribute__((always_inline)) T div_ieee754(T x,
67+
T y) {
68+
static_assert(std::is_same_v<T, double> || std::is_same_v<T, float>);
69+
#ifdef __CUDA_ARCH__
70+
if constexpr (std::is_same_v<T, double>) {
71+
return __ddiv_rn(x, y);
72+
} else {
73+
return __fdiv_rn(x, y);
74+
}
75+
#else
76+
return x / y;
77+
#endif
78+
}
79+
} // namespace traccc::math

device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "./kernels/sort_tracks_per_measurement.cuh"
2626
#include "./kernels/sort_updated_tracks.cuh"
2727
#include "traccc/cuda/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
28+
#include "traccc/definitions/math.hpp"
2829

2930
// Thrust include(s).
3031
#include <thrust/execution_policy.h>
@@ -42,7 +43,8 @@ namespace traccc::cuda {
4243
struct devide_op {
4344
TRACCC_HOST_DEVICE
4445
traccc::scalar operator()(unsigned int a, unsigned int b) const {
45-
return static_cast<traccc::scalar>(a) / static_cast<traccc::scalar>(b);
46+
return math::div_ieee754(static_cast<traccc::scalar>(a),
47+
static_cast<traccc::scalar>(b));
4648
}
4749
};
4850

device/cuda/src/ambiguity_resolution/kernels/remove_tracks.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
*/
77

88
// Project include(s).
9+
#include "traccc/definitions/math.hpp"
910
#include "traccc/utils/pair.hpp"
1011

1112
// Local include(s).
@@ -433,8 +434,9 @@ __launch_bounds__(512) __global__
433434
updated_tracks[pos2] = tid;
434435
is_updated[tid] = 1;
435436

436-
rel_shared.at(tid) = static_cast<traccc::scalar>(n_shared.at(tid)) /
437-
static_cast<traccc::scalar>(n_meas.at(tid));
437+
rel_shared.at(tid) =
438+
math::div_ieee754(static_cast<traccc::scalar>(n_shared.at(tid)),
439+
static_cast<traccc::scalar>(n_meas.at(tid)));
438440
}
439441
}
440442
}

0 commit comments

Comments
 (0)