Skip to content

Commit 83bc966

Browse files
committed
Fix: bytes_per_core_optimal estimate
1 parent c8dc314 commit 83bc966

File tree

2 files changed

+36
-16
lines changed

2 files changed

+36
-16
lines changed

include/stringcuzilla/find_many.hpp

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
586586

587587
_sz_assert(counts.size() == haystacks.size());
588588
size_t const cores_total = specs.cores_total();
589+
size_t const cache_line_width = specs.cache_line_width;
589590

590591
using haystacks_t = typename std::remove_reference_t<haystacks_type_>;
591592
using haystack_t = typename haystacks_t::value_type;
@@ -609,7 +610,8 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
609610
size_t count_matches_across_cores = 0;
610611
#pragma omp parallel for reduction(+ : count_matches_across_cores) schedule(static, 1) num_threads(cores_total)
611612
for (size_t core_index = 0; core_index < cores_total; ++core_index) {
612-
size_t count_matches_on_one_core = count_matches_in_one_part(haystack, core_index, cores_total);
613+
size_t count_matches_on_one_core =
614+
count_matches_in_one_part(haystack, core_index, cores_total, cache_line_width);
613615
count_matches_across_cores += count_matches_on_one_core;
614616
}
615617

@@ -662,7 +664,7 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
662664

663665
_sz_assert(counts.size() == haystacks.size());
664666
size_t const cores_total = specs.cores_total();
665-
size_t const max_needle_length = dict_.max_needle_length();
667+
size_t const cache_line_width = specs.cache_line_width;
666668

667669
using haystacks_t = typename std::remove_reference_t<haystacks_type_>;
668670
using haystack_t = typename haystacks_t::value_type;
@@ -705,7 +707,8 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
705707
// First, on each core, estimate the number of matches in the haystack
706708
#pragma omp parallel for schedule(static, 1) num_threads(cores_total)
707709
for (size_t core_index = 0; core_index < cores_total; ++core_index)
708-
counts_per_core[core_index] = count_matches_in_one_part(haystack, core_index, cores_total);
710+
counts_per_core[core_index] =
711+
count_matches_in_one_part(haystack, core_index, cores_total, cache_line_width);
709712

710713
// Now that we know the number of matches to expect per slice, we can convert the counts
711714
// into offsets using inclusive prefix sum
@@ -716,9 +719,13 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
716719
counts_per_core[core_index] += counts_per_core[core_index - 1];
717720
}
718721

722+
// We shouldn't even consider needles longer than the haystack
723+
size_t const max_needle_length = std::min(dict_.max_needle_length(), haystack_length);
724+
719725
// On each core, pick an overlapping slice and go through all of the matches in it,
720726
// that start before the end of the private slice.
721-
size_t const bytes_per_core_optimal = haystack_length / cores_total;
727+
size_t const bytes_per_core_optimal =
728+
round_up_to_multiple(divide_round_up(haystack_length, cores_total), cache_line_width);
722729
size_t const count_matches_before_this_haystack = offsets_per_haystack[i];
723730
#pragma omp parallel for schedule(static, 1) num_threads(cores_total)
724731
for (size_t core_index = 0; core_index < cores_total; ++core_index) {
@@ -727,12 +734,12 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
727734
counts_per_core[core_index] - count_matches_before_this_core;
728735

729736
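// The last core may have a smaller slice, so we need to be careful. NOTE(review): now that `bytes_per_core_optimal` is rounded up to a cache-line multiple, `core_index * bytes_per_core_optimal` may exceed `haystack_length` for trailing cores, and the unsigned subtraction below would wrap - confirm the caller rules this out.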
730-
size_t const bytes_per_core =
737+
size_t const bytes_for_core =
731738
std::min(bytes_per_core_optimal, haystack_length - core_index * bytes_per_core_optimal);
732739
char_t const *optimal_start = haystack_begin + core_index * bytes_per_core_optimal;
733-
char_t const *optimal_end = optimal_start + bytes_per_core;
740+
char_t const *optimal_end = optimal_start + bytes_for_core;
734741
char_t const *overlapping_end =
735-
std::min(optimal_start + bytes_per_core + max_needle_length - 1, haystack_begin + haystack_length);
742+
std::min(optimal_start + bytes_for_core + max_needle_length - 1, haystack_begin + haystack_length);
736743

737744
// Iterate through the matches in the overlapping region
738745
size_t count_matches_found_on_this_core = 0;
@@ -748,6 +755,7 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
748755
}
749756
}
750757

758+
// Aggregate the results
751759
matches_count = 0;
752760
for (size_t count : counts) matches_count += count;
753761
return status_t::success_k;
@@ -763,31 +771,33 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
763771
* overlapping matches.
764772
*/
765773
template <typename char_type_>
766-
size_t count_matches_in_one_part(span<char_type_ const> haystack, size_t core_index,
767-
size_t cores_total) const noexcept {
774+
size_t count_matches_in_one_part(span<char_type_ const> haystack, size_t core_index, size_t cores_total,
775+
size_t cache_line_width) const noexcept {
768776

769777
using char_t = char_type_;
770778
char_t const *haystack_begin = haystack.data();
771779
size_t const haystack_length = haystack.size();
772780
char_t const *const haystack_end = haystack_begin + haystack_length;
773-
size_t const bytes_per_core_optimal = haystack_length / cores_total;
774-
size_t const max_needle_length = dict_.max_needle_length();
781+
size_t const bytes_per_core_optimal =
782+
round_up_to_multiple(divide_round_up(haystack_length, cores_total), cache_line_width);
783+
784+
// We shouldn't even consider needles longer than the haystack
785+
size_t const max_needle_length = std::min(dict_.max_needle_length(), haystack_length);
775786

776787
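// The last core may have a smaller slice, so we need to be careful. NOTE(review): now that `bytes_per_core_optimal` is rounded up to a cache-line multiple, `core_index * bytes_per_core_optimal` may exceed `haystack_length` for trailing cores, and the unsigned subtraction below would wrap - confirm the caller rules this out.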
777-
size_t const bytes_per_core =
788+
size_t const bytes_for_core =
778789
std::min(bytes_per_core_optimal, haystack_length - core_index * bytes_per_core_optimal);
779790

780791
// First, each core will process its own slice excluding the overlapping regions
781792
char_t const *optimal_start = haystack_begin + core_index * bytes_per_core_optimal;
782-
char_t const *optimal_end = optimal_start + bytes_per_core;
793+
char_t const *optimal_end = optimal_start + bytes_for_core;
783794
size_t const count_matches_non_overlapping = dict_.count({optimal_start, optimal_end});
784795

785796
// Now, each thread will take care of the subsequent overlapping regions,
786797
// but we must be careful for cases when the core-specific slice is shorter
787798
// than the longest needle! It's a very unlikely case in practice, but we
788799
// still may want an optimization for it down the road.
789-
char_t const *overlapping_start =
790-
std::min(optimal_start + bytes_per_core - max_needle_length + 1, haystack_end);
800+
char_t const *overlapping_start = std::max(optimal_end - max_needle_length + 1, haystack_begin);
791801
char_t const *overlapping_end = std::min(optimal_end + max_needle_length - 1, haystack_end);
792802
size_t count_matches_overlapping = 0;
793803
dict_.find({overlapping_start, overlapping_end}, [&](match_t match) noexcept {

include/stringzilla/types.hpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,13 +799,23 @@ struct cpu_specs_t {
799799
size_t cores_total() const noexcept { return cores_per_socket * sockets; }
800800
};
801801

802+
/**
803+
* @brief Divides @p x by @p divisor, rounding the quotient up to the nearest integer.
804+
* @note For non-negative integers this matches `ceil(x / divisor)` taken over real numbers, but avoids floating-point arithmetic. The intermediate `x + divisor - 1` can overflow when @p x is within `divisor - 1` of the type's maximum.
805+
*/
806+
template <typename scalar_type_>
807+
constexpr scalar_type_ divide_round_up(scalar_type_ x, scalar_type_ divisor) {
808+
_sz_assert(divisor > 0 && "Divisor must be positive");
809+
return (x + divisor - 1) / divisor;
810+
}
811+
802812
/**
803813
* @brief Rounds @p x up to the nearest multiple of @p divisor; a value that is already a multiple is returned unchanged.
804814
*/
805815
template <typename scalar_type_>
806816
constexpr scalar_type_ round_up_to_multiple(scalar_type_ x, scalar_type_ divisor) {
807817
_sz_assert(divisor > 0 && "Divisor must be positive");
808-
return ((x + divisor - 1) / divisor) * divisor;
818+
return divide_round_up(x, divisor) * divisor;
809819
}
810820

811821
} // namespace stringzilla

0 commit comments

Comments
 (0)