Skip to content

Commit baa91b3

Browse files
ariostasfwyzard
andcommitted
Added lower_bound function that works in device code
Co-authored-by: Andrea Bocci <[email protected]>
1 parent 1416747 commit baa91b3

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

HeterogeneousCore/AlpakaInterface/interface/alpakastdAlgorithm.h

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,28 @@
77

88
#include <alpaka/alpaka.hpp>
99

10-
// reimplementation of std algorithms able to compile with Alpaka,
11-
// mostly by declaring them constexpr (until C++20, which will make it
12-
// constexpr by default. TODO: drop when moving to C++20)
10+
// reimplementation of std algorithms able to work on device code
1311

1412
namespace alpaka_std {
1513

14+
template <typename RandomIt, typename T, typename Compare = std::less<T>>
15+
ALPAKA_FN_HOST_ACC constexpr RandomIt lower_bound(RandomIt first, RandomIt last, const T &value, Compare comp = {}) {
16+
auto count = last - first;
17+
18+
while (count > 0) {
19+
auto it = first;
20+
auto step = count / 2;
21+
it += step;
22+
if (comp(*it, value)) {
23+
first = ++it;
24+
count -= step + 1;
25+
} else {
26+
count = step;
27+
}
28+
}
29+
return first;
30+
}
31+
1632
template <typename RandomIt, typename T, typename Compare = std::less<T>>
1733
ALPAKA_FN_HOST_ACC constexpr RandomIt upper_bound(RandomIt first, RandomIt last, const T &value, Compare comp = {}) {
1834
auto count = last - first;

RecoTracker/LSTCore/src/alpaka/Hit.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define RecoTracker_LSTCore_src_alpaka_Hit_h
33

44
#include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
5+
#include "HeterogeneousCore/AlpakaInterface/interface/alpakastdAlgorithm.h"
56

67
#include "RecoTracker/LSTCore/interface/alpaka/Common.h"
78
#include "RecoTracker/LSTCore/interface/ModulesSoA.h"
@@ -103,7 +104,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
103104
((ihit_z > 0) - (ihit_z < 0)) *
104105
alpaka::math::acosh(
105106
acc, alpaka::math::sqrt(acc, ihit_x * ihit_x + ihit_y * ihit_y + ihit_z * ihit_z) / hits.rts()[ihit]);
106-
auto found_pointer = std::lower_bound(modules.mapdetId(), modules.mapdetId() + nModules, iDetId);
107+
auto found_pointer = alpaka_std::lower_bound(modules.mapdetId(), modules.mapdetId() + nModules, iDetId);
107108
int found_index = std::distance(modules.mapdetId(), found_pointer);
108109
if (found_pointer == modules.mapdetId() + nModules)
109110
found_index = -1;
@@ -112,7 +113,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
112113
hits.moduleIndices()[ihit] = lastModuleIndex;
113114

114115
if (modules.subdets()[lastModuleIndex] == Endcap && modules.moduleType()[lastModuleIndex] == TwoS) {
115-
found_pointer = std::lower_bound(geoMapDetId, geoMapDetId + nEndCapMap, iDetId);
116+
found_pointer = alpaka_std::lower_bound(geoMapDetId, geoMapDetId + nEndCapMap, iDetId);
116117
found_index = std::distance(geoMapDetId, found_pointer);
117118
if (found_pointer == geoMapDetId + nEndCapMap)
118119
found_index = -1;

0 commit comments

Comments
 (0)