Skip to content

Commit d83bb02

Browse files
committed
Fix: OpenMP split size
1 parent e4f601b commit d83bb02

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

reduce_cpu.hpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,20 @@ namespace ashvardanian {
3737
*/
3838
inline static std::size_t total_cores() { return std::thread::hardware_concurrency(); }
3939

40+
/**
41+
* @brief Divides a value by another value and rounds it up to the nearest integer.
42+
* Example: `divide_round_up(5, 3) == 2`
43+
*/
44+
inline static std::size_t divide_round_up(std::size_t value, std::size_t multiple) noexcept {
45+
return ((value + multiple - 1) / multiple);
46+
}
47+
4048
/**
4149
* @brief Rounds a value up to the nearest multiple of another value.
4250
* Example: `round_up_to_multiple(5, 3) == 6`
4351
*/
4452
inline static std::size_t round_up_to_multiple(std::size_t value, std::size_t multiple) noexcept {
45-
return ((value + multiple - 1) / multiple) * multiple;
53+
return divide_round_up(value, multiple) * multiple;
4654
}
4755

4856
#pragma region - Serial and Autovectorized
@@ -590,11 +598,11 @@ class openmp_gt {
590598

591599
double operator()() {
592600
auto const input_size = static_cast<std::size_t>(end_ - begin_);
593-
auto const chunk_size = input_size / total_cores_;
601+
auto const chunk_size = divide_round_up(input_size, total_cores_);
594602
#pragma omp parallel
595603
{
596604
std::size_t const thread_id = static_cast<std::size_t>(omp_get_thread_num());
597-
std::size_t const start = thread_id * chunk_size;
605+
std::size_t const start = std::min(thread_id * chunk_size);
598606
std::size_t const stop = std::min(start + chunk_size, input_size);
599607
double local_sum = serial_at {begin_ + start, begin_ + stop}();
600608
sums_[thread_id] = local_sum;
@@ -636,7 +644,7 @@ class threads_gt {
636644

637645
std::size_t count_per_thread() const noexcept {
638646
constexpr std::size_t entries_per_zmm_register = 64 / sizeof(float);
639-
std::size_t balanced_split = (end_ - begin_) / sums_.size();
647+
std::size_t balanced_split = divide_round_up(end_ - begin_, sums_.size());
640648
return round_up_to_multiple(balanced_split, entries_per_zmm_register);
641649
}
642650

0 commit comments

Comments
 (0)