@@ -64,6 +64,11 @@ struct find_many_match_t {
6464 span<char const > needle {};
6565 size_t haystack_index {};
6666 size_t needle_index {};
67+
68+ inline static bool less_globally (find_many_match_t const &lhs, find_many_match_t const &rhs) noexcept {
69+ return lhs.needle .data () < rhs.needle .data () ||
70+ (lhs.needle .data () == rhs.needle .data () && lhs.needle .end () < rhs.needle .end ());
71+ }
6772};
6873
6974template <typename value_type_>
@@ -287,6 +292,7 @@ struct aho_corasick_dictionary {
287292 * @retval `status_t::success_k` The needle was successfully added.
288293 * @retval `status_t::bad_alloc_k` Memory allocation failed.
289294 * @retval `status_t::overflow_risk_k` Too many needles for the current state ID type.
295+ * @retval `status_t::contains_duplicates_k` The needle is already in the vocabulary.
290296 */
291297 status_t try_insert (span<char const > needle) noexcept {
292298 if (!needle.size ()) return status_t ::success_k; // Don't care about empty needles.
@@ -315,6 +321,10 @@ struct aho_corasick_dictionary {
315321 current_state = current_row[symbol];
316322 }
317323
324+ // If the terminal state's output is already set, the needle already exists.
325+ if (outputs_[current_state] != invalid_state_k) return status_t ::contains_duplicates_k;
326+
327+ // Populate the new state.
318328 outputs_[current_state] = needle_id;
319329 needles_lengths_.try_push_back (needle.size ()); // ? Can't fail due to `try_reserve` above
320330 outputs_counts_[current_state] = 1 ; // ? This will snowball in `try_build` if needles have shared suffixes
@@ -466,18 +476,25 @@ struct find_many {
466476 using match_t = typename dictionary_t ::match_t ;
467477
468478 find_many (allocator_t alloc = allocator_t ()) noexcept : dict_(alloc) {}
479+ void reset () noexcept { dict_.reset (); }
469480
481+ /* *
482+ * @brief Indexes all of the @p needles strings into the FSM.
483+ * @retval `status_t::success_k` The needle was successfully added.
484+ * @retval `status_t::bad_alloc_k` Memory allocation failed.
485+ * @retval `status_t::overflow_risk_k` Too many needles for the current state ID type.
486+ * @retval `status_t::contains_duplicates_k` The needle is already in the vocabulary.
487+ * @note Before reusing, please `reset` the FSM.
488+ */
470489 template <typename needles_type_>
471- status_t try_build (needles_type_ &&needles_strings ) noexcept {
472- for (auto const &needle : needles_strings ) {
490+ status_t try_build (needles_type_ &&needles ) noexcept {
491+ for (auto const &needle : needles ) {
473492 status_t status = dict_.try_insert (needle);
474493 if (status != status_t ::success_k) return status;
475494 }
476495 return dict_.try_build ();
477496 }
478497
479- void reset () noexcept { dict_.reset (); }
480-
481498 /* *
482499 * @brief Counts the number of occurrences of all needles in all @p haystacks. Relevant for filtering and ranking.
483500 * @param[in] haystacks The input strings to search in.
@@ -539,15 +556,24 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
539556
540557 using size_allocator_t = typename std::allocator_traits<allocator_t >::template rebind_alloc<size_t >;
541558
559+ find_many (allocator_t alloc = allocator_t ()) noexcept : dict_(alloc) {}
560+ void reset () noexcept { dict_.reset (); }
561+
562+ /* *
563+ * @brief Indexes all of the @p needles strings into the FSM.
564+ * @retval `status_t::success_k` The needle was successfully added.
565+ * @retval `status_t::bad_alloc_k` Memory allocation failed.
566+ * @retval `status_t::overflow_risk_k` Too many needles for the current state ID type.
567+ * @retval `status_t::contains_duplicates_k` The needle is already in the vocabulary.
568+ * @note Before reusing, please `reset` the FSM.
569+ */
542570 template <typename needles_type_>
543- status_t try_build (needles_type_ &&needles_strings ) noexcept {
544- for (auto const &needle : needles_strings )
571+ status_t try_build (needles_type_ &&needles ) noexcept {
572+ for (auto const &needle : needles )
545573 if (status_t status = dict_.try_insert (needle); status != status_t ::success_k) return status;
546574 return dict_.try_build ();
547575 }
548576
549- void reset () noexcept { dict_.reset (); }
550-
551577 /* *
552578 * @brief Counts the number of occurrences of all needles in all @p haystacks. Relevant for filtering and ranking.
553579 * @param[in] haystacks The input strings to search in.
@@ -560,11 +586,9 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
560586
561587 _sz_assert (counts.size () == haystacks.size ());
562588 size_t const cores_total = specs.cores_total ();
563- size_t const max_needle_length = dict_.max_needle_length ();
564589
565590 using haystacks_t = typename std::remove_reference_t <haystacks_type_>;
566591 using haystack_t = typename haystacks_t ::value_type;
567- using char_t = typename haystack_t ::value_type;
568592
569593 // On small strings, individually compute the counts
570594#pragma omp parallel for schedule(dynamic, 1) num_threads(cores_total)
@@ -582,37 +606,14 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
582606 // The shorter strings have already been processed
583607 if (haystack_length <= specs.l2_bytes ) continue ;
584608
585- // First, each core will process its own slice excluding the overlapping regions
586- char_t const *haystack_begin = haystack.data ();
587- char_t const *const haystack_end = haystack_begin + haystack_length;
588- size_t const bytes_per_core_optimal = haystack_length / cores_total;
589609 size_t count_matches_across_cores = 0 ;
590610#pragma omp parallel for reduction(+ : count_matches_across_cores) schedule(static, 1) num_threads(cores_total)
591- for (size_t j = 0 ; j < cores_total; ++j) {
592- size_t const bytes_per_core =
593- std::min (bytes_per_core_optimal, haystack_length - j * bytes_per_core_optimal);
594- char_t const *optimal_start = haystack_begin + j * bytes_per_core_optimal;
595- char_t const *optimal_end = optimal_start + bytes_per_core;
596- size_t const count_matches_non_overlapping = dict_.count ({optimal_start, optimal_end});
597-
598- // Now, each thread will take care of the subsequent overlapping regions,
599- // but we must be careful for cases when the core-specific slice is shorter
600- // than the longest needle! It's a very unlikely case in practice, but we
601- // still may want an optimization for it down the road.
602- char_t const *overlapping_start =
603- std::min (optimal_start + bytes_per_core - max_needle_length + 1 , haystack_end);
604- char_t const *overlapping_end = std::min (optimal_end + max_needle_length - 1 , haystack_end);
605- size_t count_matches_overlapping = 0 ;
606- dict_.find ({overlapping_start, overlapping_end}, [&](match_t match) noexcept {
607- bool is_boundary = match.needle .begin () < optimal_end && match.needle .end () >= optimal_end;
608- count_matches_overlapping += is_boundary;
609- return true ;
610- });
611-
612- // Now, finally, aggregate the results
613- count_matches_across_cores += count_matches_non_overlapping;
614- count_matches_across_cores += count_matches_overlapping;
611+ for (size_t core_index = 0 ; core_index < cores_total; ++core_index) {
612+ size_t count_matches_on_one_core = count_matches_in_one_part (haystack, core_index, cores_total);
613+ count_matches_across_cores += count_matches_on_one_core;
615614 }
615+
616+ counts[i] = count_matches_across_cores;
616617 }
617618
618619 return status_t ::success_k;
@@ -661,11 +662,11 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
661662
662663 _sz_assert (counts.size () == haystacks.size ());
663664 size_t const cores_total = specs.cores_total ();
664- // size_t const max_needle_length = dict_.max_needle_length();
665+ size_t const max_needle_length = dict_.max_needle_length ();
665666
666667 using haystacks_t = typename std::remove_reference_t <haystacks_type_>;
667668 using haystack_t = typename haystacks_t ::value_type;
668- // using char_t = typename std::iterator_traits< haystack_t> ::value_type;
669+ using char_t = typename haystack_t ::value_type;
669670
670671 // Calculate the exclusive prefix sum of the counts to navigate into the `matches` array
671672 safe_vector<size_t , size_allocator_t > offsets_per_haystack (dict_.allocator ());
@@ -691,7 +692,61 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
691692 }
692693
693694 // On longer strings, throw all cores on each haystack, but between the threads we need additional
694- // memory to track the number of matches within a core-specific slice of the haystack
695+ // memory to track the number of matches within a core-specific slice of the haystack.
696+ safe_vector<size_t , size_allocator_t > counts_per_core (dict_.allocator ());
697+ if (counts_per_core.try_resize (cores_total) != status_t ::success_k) return status_t ::bad_alloc_k;
698+ for (size_t i = 0 ; i < counts.size (); ++i) {
699+ haystack_t const &haystack = haystacks[i];
700+ char_t const *haystack_begin = haystack.data ();
701+ size_t const haystack_length = haystack.size ();
702+ // The shorter strings have already been processed
703+ if (haystack_length <= specs.l2_bytes ) continue ;
704+
705+ // First, on each core, estimate the number of matches in the haystack
706+ #pragma omp parallel for schedule(static, 1) num_threads(cores_total)
707+ for (size_t core_index = 0 ; core_index < cores_total; ++core_index)
708+ counts_per_core[core_index] = count_matches_in_one_part (haystack, core_index, cores_total);
709+
710+ // Now that we know the number of matches to expect per slice, we can convert the counts
711+ // into offsets using inclusive prefix sum
712+ #pragma omp barrier
713+ #pragma omp single
714+ {
715+ for (size_t core_index = 1 ; core_index < cores_total; ++core_index)
716+ counts_per_core[core_index] += counts_per_core[core_index - 1 ];
717+ }
718+
719+ // On each core, pick an overlapping slice and go through all of the matches in it,
720+ // that start before the end of the private slice.
721+ size_t const bytes_per_core_optimal = haystack_length / cores_total;
722+ size_t const count_matches_before_this_haystack = offsets_per_haystack[i];
723+ #pragma omp parallel for schedule(static, 1) num_threads(cores_total)
724+ for (size_t core_index = 0 ; core_index < cores_total; ++core_index) {
725+ size_t const count_matches_before_this_core = core_index ? counts_per_core[core_index - 1 ] : 0 ;
726+ size_t const count_matches_expected_on_this_core =
727+ counts_per_core[core_index] - count_matches_before_this_core;
728+
729+ // The last core may have a smaller slice, so we need to be careful
730+ size_t const bytes_per_core =
731+ std::min (bytes_per_core_optimal, haystack_length - core_index * bytes_per_core_optimal);
732+ char_t const *optimal_start = haystack_begin + core_index * bytes_per_core_optimal;
733+ char_t const *optimal_end = optimal_start + bytes_per_core;
734+ char_t const *overlapping_end =
735+ std::min (optimal_start + bytes_per_core + max_needle_length - 1 , haystack_begin + haystack_length);
736+
737+ // Iterate through the matches in the overlapping region
738+ size_t count_matches_found_on_this_core = 0 ;
739+ dict_.find ({optimal_start, overlapping_end}, [&](match_t match) noexcept {
740+ bool blongs_to_this_core = match.needle .begin () < optimal_end;
741+ if (!blongs_to_this_core) return true ;
742+ matches[count_matches_before_this_haystack + count_matches_before_this_core +
743+ count_matches_found_on_this_core] = match;
744+ count_matches_found_on_this_core++;
745+ return true ;
746+ });
747+ _sz_assert (count_matches_found_on_this_core == count_matches_expected_on_this_core);
748+ }
749+ }
695750
696751 matches_count = 0 ;
697752 for (size_t count : counts) matches_count += count;
@@ -700,6 +755,50 @@ struct find_many<state_id_type_, allocator_type_, sz_caps_sp_k, enable_> {
700755
701756 private:
702757 dictionary_t dict_;
758+
759+ /* *
760+ * @brief Helper method implementing the core logic of the parallel `try_count` and part of `try_find`.
761+ * For a given single input haystack, assumes all of the cores are processing it in parallel,
762+ * and this method is called from each core with its own index to count the number of potentially
763+ * overlapping matches.
764+ */
765+ template <typename char_type_>
766+ size_t count_matches_in_one_part (span<char_type_ const > haystack, size_t core_index,
767+ size_t cores_total) const noexcept {
768+
769+ using char_t = char_type_;
770+ char_t const *haystack_begin = haystack.data ();
771+ size_t const haystack_length = haystack.size ();
772+ char_t const *const haystack_end = haystack_begin + haystack_length;
773+ size_t const bytes_per_core_optimal = haystack_length / cores_total;
774+ size_t const max_needle_length = dict_.max_needle_length ();
775+
776+ // The last core may have a smaller slice, so we need to be careful
777+ size_t const bytes_per_core =
778+ std::min (bytes_per_core_optimal, haystack_length - core_index * bytes_per_core_optimal);
779+
780+ // First, each core will process its own slice excluding the overlapping regions
781+ char_t const *optimal_start = haystack_begin + core_index * bytes_per_core_optimal;
782+ char_t const *optimal_end = optimal_start + bytes_per_core;
783+ size_t const count_matches_non_overlapping = dict_.count ({optimal_start, optimal_end});
784+
785+ // Now, each thread will take care of the subsequent overlapping regions,
786+ // but we must be careful for cases when the core-specific slice is shorter
787+ // than the longest needle! It's a very unlikely case in practice, but we
788+ // still may want an optimization for it down the road.
789+ char_t const *overlapping_start =
790+ std::min (optimal_start + bytes_per_core - max_needle_length + 1 , haystack_end);
791+ char_t const *overlapping_end = std::min (optimal_end + max_needle_length - 1 , haystack_end);
792+ size_t count_matches_overlapping = 0 ;
793+ dict_.find ({overlapping_start, overlapping_end}, [&](match_t match) noexcept {
794+ bool is_boundary = match.needle .begin () < optimal_end && match.needle .end () >= optimal_end;
795+ count_matches_overlapping += is_boundary;
796+ return true ;
797+ });
798+
799+ // Now, finally, aggregate the results
800+ return count_matches_non_overlapping + count_matches_overlapping;
801+ }
703802};
704803
705804#pragma endregion // Parallel OpenMP Backend
0 commit comments