@@ -586,6 +586,7 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
586586
587587 _sz_assert (counts.size () == haystacks.size ());
588588 size_t const cores_total = specs.cores_total ();
589+ size_t const cache_line_width = specs.cache_line_width ;
589590
590591 using haystacks_t = typename std::remove_reference_t <haystacks_type_>;
591592 using haystack_t = typename haystacks_t ::value_type;
@@ -609,7 +610,8 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
609610 size_t count_matches_across_cores = 0 ;
610611#pragma omp parallel for reduction(+ : count_matches_across_cores) schedule(static, 1) num_threads(cores_total)
611612 for (size_t core_index = 0 ; core_index < cores_total; ++core_index) {
612- size_t count_matches_on_one_core = count_matches_in_one_part (haystack, core_index, cores_total);
613+ size_t count_matches_on_one_core =
614+ count_matches_in_one_part (haystack, core_index, cores_total, cache_line_width);
613615 count_matches_across_cores += count_matches_on_one_core;
614616 }
615617
@@ -662,7 +664,7 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
662664
663665 _sz_assert (counts.size () == haystacks.size ());
664666 size_t const cores_total = specs.cores_total ();
665- size_t const max_needle_length = dict_. max_needle_length () ;
667+ size_t const cache_line_width = specs. cache_line_width ;
666668
667669 using haystacks_t = typename std::remove_reference_t <haystacks_type_>;
668670 using haystack_t = typename haystacks_t ::value_type;
@@ -705,7 +707,8 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
705707 // First, on each core, estimate the number of matches in the haystack
706708#pragma omp parallel for schedule(static, 1) num_threads(cores_total)
707709 for (size_t core_index = 0 ; core_index < cores_total; ++core_index)
708- counts_per_core[core_index] = count_matches_in_one_part (haystack, core_index, cores_total);
710+ counts_per_core[core_index] =
711+ count_matches_in_one_part (haystack, core_index, cores_total, cache_line_width);
709712
710713 // Now that we know the number of matches to expect per slice, we can convert the counts
711714 // into offsets using inclusive prefix sum
@@ -716,9 +719,13 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
716719 counts_per_core[core_index] += counts_per_core[core_index - 1 ];
717720 }
718721
722+ // We shouldn't even consider needles longer than the haystack
723+ size_t const max_needle_length = std::min (dict_.max_needle_length (), haystack_length);
724+
719725 // On each core, pick an overlapping slice and go through all of the matches in it,
720726 // that start before the end of the private slice.
721- size_t const bytes_per_core_optimal = haystack_length / cores_total;
727+ size_t const bytes_per_core_optimal =
728+ round_up_to_multiple (divide_round_up (haystack_length, cores_total), cache_line_width);
722729 size_t const count_matches_before_this_haystack = offsets_per_haystack[i];
723730#pragma omp parallel for schedule(static, 1) num_threads(cores_total)
724731 for (size_t core_index = 0 ; core_index < cores_total; ++core_index) {
@@ -727,12 +734,12 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
727734 counts_per_core[core_index] - count_matches_before_this_core;
728735
729736 // The last core may have a smaller slice, so we need to be careful
730- size_t const bytes_per_core =
737+ size_t const bytes_for_core =
731738 std::min (bytes_per_core_optimal, haystack_length - core_index * bytes_per_core_optimal);
732739 char_t const *optimal_start = haystack_begin + core_index * bytes_per_core_optimal;
733- char_t const *optimal_end = optimal_start + bytes_per_core ;
740+ char_t const *optimal_end = optimal_start + bytes_for_core ;
734741 char_t const *overlapping_end =
735- std::min (optimal_start + bytes_per_core + max_needle_length - 1 , haystack_begin + haystack_length);
742+ std::min (optimal_start + bytes_for_core + max_needle_length - 1 , haystack_begin + haystack_length);
736743
737744 // Iterate through the matches in the overlapping region
738745 size_t count_matches_found_on_this_core = 0 ;
@@ -748,6 +755,7 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
748755 }
749756 }
750757
758+ // Aggregate the results
751759 matches_count = 0 ;
752760 for (size_t count : counts) matches_count += count;
753761 return status_t ::success_k;
@@ -763,31 +771,33 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
763771 * overlapping matches.
764772 */
765773 template <typename char_type_>
766- size_t count_matches_in_one_part (span<char_type_ const > haystack, size_t core_index,
767- size_t cores_total ) const noexcept {
774+ size_t count_matches_in_one_part (span<char_type_ const > haystack, size_t core_index, size_t cores_total,
775+ size_t cache_line_width ) const noexcept {
768776
769777 using char_t = char_type_;
770778 char_t const *haystack_begin = haystack.data ();
771779 size_t const haystack_length = haystack.size ();
772780 char_t const *const haystack_end = haystack_begin + haystack_length;
773- size_t const bytes_per_core_optimal = haystack_length / cores_total;
774- size_t const max_needle_length = dict_.max_needle_length ();
781+ size_t const bytes_per_core_optimal =
782+ round_up_to_multiple (divide_round_up (haystack_length, cores_total), cache_line_width);
783+
784+ // We shouldn't even consider needles longer than the haystack
785+ size_t const max_needle_length = std::min (dict_.max_needle_length (), haystack_length);
775786
776787 // The last core may have a smaller slice, so we need to be careful
777- size_t const bytes_per_core =
788+ size_t const bytes_for_core =
778789 std::min (bytes_per_core_optimal, haystack_length - core_index * bytes_per_core_optimal);
779790
780791 // First, each core will process its own slice excluding the overlapping regions
781792 char_t const *optimal_start = haystack_begin + core_index * bytes_per_core_optimal;
782- char_t const *optimal_end = optimal_start + bytes_per_core ;
793+ char_t const *optimal_end = optimal_start + bytes_for_core ;
783794 size_t const count_matches_non_overlapping = dict_.count ({optimal_start, optimal_end});
784795
785796 // Now, each thread will take care of the subsequent overlapping regions,
786797 // but we must be careful for cases when the core-specific slice is shorter
787798 // than the longest needle! It's a very unlikely case in practice, but we
788799 // still may want an optimization for it down the road.
789- char_t const *overlapping_start =
790- std::min (optimal_start + bytes_per_core - max_needle_length + 1 , haystack_end);
800+ char_t const *overlapping_start = std::max (optimal_end - max_needle_length + 1 , haystack_begin);
791801 char_t const *overlapping_end = std::min (optimal_end + max_needle_length - 1 , haystack_end);
792802 size_t count_matches_overlapping = 0 ;
793803 dict_.find ({overlapping_start, overlapping_end}, [&](match_t match) noexcept {
0 commit comments