Skip to content

Commit c8dc314

Browse files
committed
Add: Parallel multi-needle search with OpenMP
1 parent cf53b9b commit c8dc314

File tree

2 files changed

+312
-106
lines changed

2 files changed

+312
-106
lines changed

include/stringcuzilla/find_many.hpp

Lines changed: 140 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ struct find_many_match_t {
6464
span<char const> needle {};
6565
size_t haystack_index {};
6666
size_t needle_index {};
67+
68+
/// Defines a global ordering of matches by their location in memory: first by the
/// address where the match starts, and for equal starts by where it ends, so that
/// shorter matches at the same position sort first.
inline static bool less_globally(find_many_match_t const &lhs, find_many_match_t const &rhs) noexcept {
    auto const *lhs_begin = lhs.needle.data();
    auto const *rhs_begin = rhs.needle.data();
    if (lhs_begin != rhs_begin) return lhs_begin < rhs_begin;
    return lhs.needle.end() < rhs.needle.end();
}
6772
};
6873

6974
template <typename value_type_>
@@ -287,6 +292,7 @@ struct aho_corasick_dictionary {
287292
* @retval `status_t::success_k` The needle was successfully added.
288293
* @retval `status_t::bad_alloc_k` Memory allocation failed.
289294
* @retval `status_t::overflow_risk_k` Too many needles for the current state ID type.
295+
* @retval `status_t::contains_duplicates_k` The needle is already in the vocabulary.
290296
*/
291297
status_t try_insert(span<char const> needle) noexcept {
292298
if (!needle.size()) return status_t::success_k; // Don't care about empty needles.
@@ -315,6 +321,10 @@ struct aho_corasick_dictionary {
315321
current_state = current_row[symbol];
316322
}
317323

324+
// If the terminal state's output is already set, the needle already exists.
325+
if (outputs_[current_state] != invalid_state_k) return status_t::contains_duplicates_k;
326+
327+
// Populate the new state.
318328
outputs_[current_state] = needle_id;
319329
needles_lengths_.try_push_back(needle.size()); // ? Can't fail due to `try_reserve` above
320330
outputs_counts_[current_state] = 1; // ? This will snowball in `try_build` if needles have shared suffixes
@@ -466,18 +476,25 @@ struct find_many {
466476
using match_t = typename dictionary_t::match_t;
467477

468478
find_many(allocator_t alloc = allocator_t()) noexcept : dict_(alloc) {}
479+
void reset() noexcept { dict_.reset(); }
469480

481+
/**
482+
* @brief Indexes all of the @p needles strings into the FSM.
483+
* @retval `status_t::success_k` The needle was successfully added.
484+
* @retval `status_t::bad_alloc_k` Memory allocation failed.
485+
* @retval `status_t::overflow_risk_k` Too many needles for the current state ID type.
486+
* @retval `status_t::contains_duplicates_k` The needle is already in the vocabulary.
487+
* @note Before reusing, please `reset` the FSM.
488+
*/
470489
template <typename needles_type_>
471-
status_t try_build(needles_type_ &&needles_strings) noexcept {
472-
for (auto const &needle : needles_strings) {
490+
status_t try_build(needles_type_ &&needles) noexcept {
491+
for (auto const &needle : needles) {
473492
status_t status = dict_.try_insert(needle);
474493
if (status != status_t::success_k) return status;
475494
}
476495
return dict_.try_build();
477496
}
478497

479-
void reset() noexcept { dict_.reset(); }
480-
481498
/**
482499
* @brief Counts the number of occurrences of all needles in all @p haystacks. Relevant for filtering and ranking.
483500
* @param[in] haystacks The input strings to search in.
@@ -539,15 +556,24 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
539556

540557
using size_allocator_t = typename std::allocator_traits<allocator_t>::template rebind_alloc<size_t>;
541558

559+
find_many(allocator_t alloc = allocator_t()) noexcept : dict_(alloc) {}
560+
void reset() noexcept { dict_.reset(); }
561+
562+
/**
563+
* @brief Indexes all of the @p needles strings into the FSM.
564+
* @retval `status_t::success_k` The needle was successfully added.
565+
* @retval `status_t::bad_alloc_k` Memory allocation failed.
566+
* @retval `status_t::overflow_risk_k` Too many needles for the current state ID type.
567+
* @retval `status_t::contains_duplicates_k` The needle is already in the vocabulary.
568+
* @note Before reusing, please `reset` the FSM.
569+
*/
542570
template <typename needles_type_>
543-
status_t try_build(needles_type_ &&needles_strings) noexcept {
544-
for (auto const &needle : needles_strings)
571+
status_t try_build(needles_type_ &&needles) noexcept {
572+
for (auto const &needle : needles)
545573
if (status_t status = dict_.try_insert(needle); status != status_t::success_k) return status;
546574
return dict_.try_build();
547575
}
548576

549-
void reset() noexcept { dict_.reset(); }
550-
551577
/**
552578
* @brief Counts the number of occurrences of all needles in all @p haystacks. Relevant for filtering and ranking.
553579
* @param[in] haystacks The input strings to search in.
@@ -560,11 +586,9 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
560586

561587
_sz_assert(counts.size() == haystacks.size());
562588
size_t const cores_total = specs.cores_total();
563-
size_t const max_needle_length = dict_.max_needle_length();
564589

565590
using haystacks_t = typename std::remove_reference_t<haystacks_type_>;
566591
using haystack_t = typename haystacks_t::value_type;
567-
using char_t = typename haystack_t::value_type;
568592

569593
// On small strings, individually compute the counts
570594
#pragma omp parallel for schedule(dynamic, 1) num_threads(cores_total)
@@ -582,37 +606,14 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
582606
// The shorter strings have already been processed
583607
if (haystack_length <= specs.l2_bytes) continue;
584608

585-
// First, each core will process its own slice excluding the overlapping regions
586-
char_t const *haystack_begin = haystack.data();
587-
char_t const *const haystack_end = haystack_begin + haystack_length;
588-
size_t const bytes_per_core_optimal = haystack_length / cores_total;
589609
size_t count_matches_across_cores = 0;
590610
#pragma omp parallel for reduction(+ : count_matches_across_cores) schedule(static, 1) num_threads(cores_total)
591-
for (size_t j = 0; j < cores_total; ++j) {
592-
size_t const bytes_per_core =
593-
std::min(bytes_per_core_optimal, haystack_length - j * bytes_per_core_optimal);
594-
char_t const *optimal_start = haystack_begin + j * bytes_per_core_optimal;
595-
char_t const *optimal_end = optimal_start + bytes_per_core;
596-
size_t const count_matches_non_overlapping = dict_.count({optimal_start, optimal_end});
597-
598-
// Now, each thread will take care of the subsequent overlapping regions,
599-
// but we must be careful for cases when the core-specific slice is shorter
600-
// than the longest needle! It's a very unlikely case in practice, but we
601-
// still may want an optimization for it down the road.
602-
char_t const *overlapping_start =
603-
std::min(optimal_start + bytes_per_core - max_needle_length + 1, haystack_end);
604-
char_t const *overlapping_end = std::min(optimal_end + max_needle_length - 1, haystack_end);
605-
size_t count_matches_overlapping = 0;
606-
dict_.find({overlapping_start, overlapping_end}, [&](match_t match) noexcept {
607-
bool is_boundary = match.needle.begin() < optimal_end && match.needle.end() >= optimal_end;
608-
count_matches_overlapping += is_boundary;
609-
return true;
610-
});
611-
612-
// Now, finally, aggregate the results
613-
count_matches_across_cores += count_matches_non_overlapping;
614-
count_matches_across_cores += count_matches_overlapping;
611+
for (size_t core_index = 0; core_index < cores_total; ++core_index) {
612+
size_t count_matches_on_one_core = count_matches_in_one_part(haystack, core_index, cores_total);
613+
count_matches_across_cores += count_matches_on_one_core;
615614
}
615+
616+
counts[i] = count_matches_across_cores;
616617
}
617618

618619
return status_t::success_k;
@@ -661,11 +662,11 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
661662

662663
_sz_assert(counts.size() == haystacks.size());
663664
size_t const cores_total = specs.cores_total();
664-
// size_t const max_needle_length = dict_.max_needle_length();
665+
size_t const max_needle_length = dict_.max_needle_length();
665666

666667
using haystacks_t = typename std::remove_reference_t<haystacks_type_>;
667668
using haystack_t = typename haystacks_t::value_type;
668-
// using char_t = typename std::iterator_traits<haystack_t>::value_type;
669+
using char_t = typename haystack_t::value_type;
669670

670671
// Calculate the exclusive prefix sum of the counts to navigate into the `matches` array
671672
safe_vector<size_t, size_allocator_t> offsets_per_haystack(dict_.allocator());
@@ -691,7 +692,61 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
691692
}
692693

693694
// On longer strings, throw all cores on each haystack, but between the threads we need additional
694-
// memory to track the number of matches within a core-specific slice of the haystack
695+
// memory to track the number of matches within a core-specific slice of the haystack.
696+
safe_vector<size_t, size_allocator_t> counts_per_core(dict_.allocator());
697+
if (counts_per_core.try_resize(cores_total) != status_t::success_k) return status_t::bad_alloc_k;
698+
for (size_t i = 0; i < counts.size(); ++i) {
699+
haystack_t const &haystack = haystacks[i];
700+
char_t const *haystack_begin = haystack.data();
701+
size_t const haystack_length = haystack.size();
702+
// The shorter strings have already been processed
703+
if (haystack_length <= specs.l2_bytes) continue;
704+
705+
// First, on each core, estimate the number of matches in the haystack
706+
#pragma omp parallel for schedule(static, 1) num_threads(cores_total)
707+
for (size_t core_index = 0; core_index < cores_total; ++core_index)
708+
counts_per_core[core_index] = count_matches_in_one_part(haystack, core_index, cores_total);
709+
710+
// Now that we know the number of matches to expect per slice, we can convert the counts
711+
// into offsets using inclusive prefix sum
712+
#pragma omp barrier
713+
#pragma omp single
714+
{
715+
for (size_t core_index = 1; core_index < cores_total; ++core_index)
716+
counts_per_core[core_index] += counts_per_core[core_index - 1];
717+
}
718+
719+
// On each core, pick an overlapping slice and go through all of the matches in it,
720+
// that start before the end of the private slice.
721+
size_t const bytes_per_core_optimal = haystack_length / cores_total;
722+
size_t const count_matches_before_this_haystack = offsets_per_haystack[i];
723+
#pragma omp parallel for schedule(static, 1) num_threads(cores_total)
724+
for (size_t core_index = 0; core_index < cores_total; ++core_index) {
725+
size_t const count_matches_before_this_core = core_index ? counts_per_core[core_index - 1] : 0;
726+
size_t const count_matches_expected_on_this_core =
727+
counts_per_core[core_index] - count_matches_before_this_core;
728+
729+
// The last core may have a smaller slice, so we need to be careful
730+
size_t const bytes_per_core =
731+
std::min(bytes_per_core_optimal, haystack_length - core_index * bytes_per_core_optimal);
732+
char_t const *optimal_start = haystack_begin + core_index * bytes_per_core_optimal;
733+
char_t const *optimal_end = optimal_start + bytes_per_core;
734+
char_t const *overlapping_end =
735+
std::min(optimal_start + bytes_per_core + max_needle_length - 1, haystack_begin + haystack_length);
736+
737+
// Iterate through the matches in the overlapping region
738+
size_t count_matches_found_on_this_core = 0;
739+
dict_.find({optimal_start, overlapping_end}, [&](match_t match) noexcept {
740+
bool blongs_to_this_core = match.needle.begin() < optimal_end;
741+
if (!blongs_to_this_core) return true;
742+
matches[count_matches_before_this_haystack + count_matches_before_this_core +
743+
count_matches_found_on_this_core] = match;
744+
count_matches_found_on_this_core++;
745+
return true;
746+
});
747+
_sz_assert(count_matches_found_on_this_core == count_matches_expected_on_this_core);
748+
}
749+
}
695750

696751
matches_count = 0;
697752
for (size_t count : counts) matches_count += count;
@@ -700,6 +755,50 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
700755

701756
private:
702757
dictionary_t dict_;
758+
759+
/**
760+
* @brief Helper method implementing the core logic of the parallel `try_count` and part of `try_find`.
761+
* For a given single input haystack, assumes all of the cores are processing it in parallel,
762+
* and this method is called from each core with its own index to count the number of potentially
763+
* overlapping matches.
* @param[in] haystack The single haystack being cooperatively processed by all cores.
* @param[in] core_index Zero-based index of the calling core, in `[0, cores_total)`.
* @param[in] cores_total The total number of cores splitting this haystack.
* @return The number of matches attributable to this core's slice, including boundary-crossing ones.
764+
*/
765+
template <typename char_type_>
766+
size_t count_matches_in_one_part(span<char_type_ const> haystack, size_t core_index,
767+
size_t cores_total) const noexcept {
768+
769+
using char_t = char_type_;
770+
char_t const *haystack_begin = haystack.data();
771+
size_t const haystack_length = haystack.size();
772+
char_t const *const haystack_end = haystack_begin + haystack_length;
773+
// NOTE(review): floor division - when `haystack_length % cores_total != 0`, the trailing
// remainder bytes appear to fall outside every core's "optimal" slice (the `std::min`
// below never widens the last core's slice), so a match starting inside that tail seems
// to be counted only if it crosses a slice boundary. Confirm the callers guarantee
// divisibility, otherwise tail matches may be silently dropped.
size_t const bytes_per_core_optimal = haystack_length / cores_total;
774+
size_t const max_needle_length = dict_.max_needle_length();
775+
776+
// The last core may have a smaller slice, so we need to be careful
777+
size_t const bytes_per_core =
778+
std::min(bytes_per_core_optimal, haystack_length - core_index * bytes_per_core_optimal);
779+
780+
// First, each core will process its own slice excluding the overlapping regions
781+
char_t const *optimal_start = haystack_begin + core_index * bytes_per_core_optimal;
782+
char_t const *optimal_end = optimal_start + bytes_per_core;
783+
size_t const count_matches_non_overlapping = dict_.count({optimal_start, optimal_end});
784+
785+
// Now, each thread will take care of the subsequent overlapping regions,
786+
// but we must be careful for cases when the core-specific slice is shorter
787+
// than the longest needle! It's a very unlikely case in practice, but we
788+
// still may want an optimization for it down the road.
789+
// NOTE(review): for core 0 with `bytes_per_core < max_needle_length - 1`, the expression
// below forms a pointer before `haystack_begin` *prior* to the `std::min` clamp - that is
// out-of-bounds pointer arithmetic (UB per [expr.add]); consider clamping via indices.
char_t const *overlapping_start =
790+
std::min(optimal_start + bytes_per_core - max_needle_length + 1, haystack_end);
791+
char_t const *overlapping_end = std::min(optimal_end + max_needle_length - 1, haystack_end);
792+
size_t count_matches_overlapping = 0;
793+
// Attribute only the matches straddling this core's right boundary, so the next core's
// non-overlapping pass does not also claim them.
// NOTE(review): a match whose `end()` lands exactly on `optimal_end` looks like it is
// counted both by `dict_.count` above and by this `end() >= optimal_end` test - confirm
// `dict_`'s span semantics, or whether the comparison should be strict (`>`).
dict_.find({overlapping_start, overlapping_end}, [&](match_t match) noexcept {
794+
bool is_boundary = match.needle.begin() < optimal_end && match.needle.end() >= optimal_end;
795+
count_matches_overlapping += is_boundary;
796+
return true;
797+
});
798+
799+
// Now, finally, aggregate the results
800+
return count_matches_non_overlapping + count_matches_overlapping;
801+
}
703802
};
704803

705804
#pragma endregion // Parallel OpenMP Backend

0 commit comments

Comments
 (0)