@@ -37,12 +37,20 @@ namespace ashvardanian {
3737 */
3838inline static std::size_t total_cores () { return std::thread::hardware_concurrency (); }
3939
40+ /* *
41+ * @brief Divides a value by another value and rounds it up to the nearest integer.
42+ * Example: `round_up_to_multiple(5, 3) == 2`
43+ */
44+ inline static std::size_t divide_round_up (std::size_t value, std::size_t multiple) noexcept {
45+ return ((value + multiple - 1 ) / multiple);
46+ }
47+
4048/* *
4149 * @brief Rounds a value up to the nearest multiple of another value.
4250 * Example: `round_up_to_multiple(5, 3) == 6`
4351 */
4452inline static std::size_t round_up_to_multiple (std::size_t value, std::size_t multiple) noexcept {
45- return (( value + multiple - 1 ) / multiple) * multiple;
53+ return divide_round_up ( value, multiple) * multiple;
4654}
4755
4856#pragma region - Serial and Autovectorized
@@ -590,11 +598,11 @@ class openmp_gt {
590598
591599 double operator ()() {
592600 auto const input_size = static_cast <std::size_t >(end_ - begin_);
593- auto const chunk_size = input_size / total_cores_;
601+ auto const chunk_size = divide_round_up ( input_size, total_cores_) ;
594602#pragma omp parallel
595603 {
596604 std::size_t const thread_id = static_cast <std::size_t >(omp_get_thread_num ());
597- std::size_t const start = thread_id * chunk_size;
605+ std::size_t const start = std::min ( thread_id * chunk_size) ;
598606 std::size_t const stop = std::min (start + chunk_size, input_size);
599607 double local_sum = serial_at {begin_ + start, begin_ + stop}();
600608 sums_[thread_id] = local_sum;
@@ -636,7 +644,7 @@ class threads_gt {
636644
637645 std::size_t count_per_thread () const noexcept {
638646 constexpr std::size_t entries_per_zmm_register = 64 / sizeof (float );
639- std::size_t balanced_split = (end_ - begin_) / sums_.size ();
647+ std::size_t balanced_split = divide_round_up (end_ - begin_, sums_.size () );
640648 return round_up_to_multiple (balanced_split, entries_per_zmm_register);
641649 }
642650
0 commit comments