@@ -640,7 +640,6 @@ void test_similarity_scores_memory_usage() {
640640}
641641
642642struct find_many_baselines_t {
643- using state_id_t = sz_u32_t ;
644643 using match_t = find_many_match_t ;
645644
646645 arrow_strings_tape_t needles_;
@@ -652,39 +651,65 @@ struct find_many_baselines_t {
652651
653652 void reset () noexcept { needles_.reset (); }
654653
654+ template <typename haystack_type_, typename needle_type_, typename match_callback_type_>
655+ void one_pair (haystack_type_ const &haystack, needle_type_ const &needle,
656+ match_callback_type_ &&callback) const noexcept {
657+
658+ // Define iterators for the current haystack and the needle.
659+ auto haystack_begin = haystack.begin ();
660+ auto haystack_end = haystack.end ();
661+ auto needle_begin = needle.begin ();
662+ auto needle_end = needle.end ();
663+
664+ // Use `std::search` to find all occurrences of needle in haystack.
665+ while (true ) {
666+ auto it = std::search (haystack_begin, haystack_end, needle_begin, needle_end);
667+ if (it == haystack_end) break ;
668+
669+ // Compute the starting index of the found occurrence.
670+ std::size_t found = static_cast <std::size_t >(std::distance (haystack.begin (), it));
671+
672+ // Construct a match record.
673+ match_t match;
674+ match.haystack = {haystack.data (), haystack.size ()};
675+ match.needle = {haystack.data () + found, needle.size ()};
676+
677+ // Invoke the callback. If it returns false, abort all further processing.
678+ if (!callback (match)) return ;
679+
680+ // Advance the starting iterator for the next search.
681+ haystack_begin = it + 1 ;
682+ }
683+ }
684+
655685 template <typename haystacks_type_, typename needles_type_, typename match_callback_type_>
656- void iterate_through_unsorted_matches (haystacks_type_ &&haystacks, needles_type_ &&needles,
657- match_callback_type_ &&callback) const noexcept {
658- for (std::size_t i = 0 ; i != haystacks.size (); ++i) {
659- auto const &haystack = haystacks[i];
660- for (std::size_t j = 0 ; j != needles.size (); ++j) {
686+ void all_pairs (haystacks_type_ &&haystacks, needles_type_ &&needles,
687+ match_callback_type_ &&callback) const noexcept {
688+
689+ // A wise man once said, `omp parallel for collapse(2) schedule(dynamic, 1)`...
690+ // But the compiler wasn't listening, and won't compile the cancellation point!
691+ std::size_t const total_tasks = haystacks.size () * needles.size ();
692+ #pragma omp parallel
693+ {
694+ #pragma omp for schedule(dynamic, 1)
695+ for (std::size_t task = 0 ; task != total_tasks; ++task) {
696+ #pragma omp cancellation point for
697+ std::size_t const i = task / needles.size ();
698+ std::size_t const j = task % needles.size ();
699+
700+ auto const &haystack = haystacks[i];
661701 auto const &needle = needles[j];
662- // Define iterators for the current haystack and the needle.
663- auto haystack_begin = haystack.begin ();
664- auto haystack_end = haystack.end ();
665- auto needle_begin = needle.begin ();
666- auto needle_end = needle.end ();
667-
668- // Use `std::search` to find all occurrences of needle in haystack.
669- while (true ) {
670- auto it = std::search (haystack_begin, haystack_end, needle_begin, needle_end);
671- if (it == haystack_end) break ;
672-
673- // Compute the starting index of the found occurrence.
674- std::size_t found = static_cast <std::size_t >(std::distance (haystack.begin (), it));
675-
676- // Construct a match record.
677- match_t match;
702+ bool keep_going = true ;
703+ one_pair (haystack, needle, [&](match_t match) {
678704 match.haystack_index = i;
679705 match.needle_index = j;
680- match.haystack = {haystack.data (), haystack.size ()};
681- match.needle = {haystack.data () + found, needle.size ()};
682-
683- // Invoke the callback. If it returns false, abort all further processing.
684- if (!callback (match)) return ;
685-
686- // Advance the starting iterator for the next search.
687- haystack_begin = it + 1 ;
706+ #pragma omp critical
707+ { keep_going = callback (match); }
708+ return keep_going;
709+ });
710+ // Quit the outer loop if the callback returns false
711+ if (!keep_going) {
712+ #pragma omp cancel for
688713 }
689714 }
690715 }
@@ -693,7 +718,7 @@ struct find_many_baselines_t {
693718 template <typename haystacks_type_>
694719 status_t try_count (haystacks_type_ &&haystacks, span<size_t > counts) const noexcept {
695720 for (size_t &count : counts) count = 0 ;
696- iterate_through_unsorted_matches (haystacks, needles_, [&](match_t const &match) {
721+ all_pairs (haystacks, needles_, [&](match_t const &match) noexcept {
697722 counts[match.haystack_index ] += 1 ;
698723 return true ;
699724 });
@@ -703,8 +728,9 @@ struct find_many_baselines_t {
703728 template <typename haystacks_type_, typename output_matches_type_>
704729 status_t try_find (haystacks_type_ &&haystacks, output_matches_type_ &&matches,
705730 size_t &matches_total) const noexcept {
731+
706732 size_t count_found = 0 , count_allowed = matches.size ();
707- iterate_through_unsorted_matches (haystacks, needles_, [&](match_t const &match) {
733+ all_pairs (haystacks, needles_, [&](match_t const &match) noexcept {
708734 matches[count_found] = match;
709735 count_found += 1 ;
710736 return count_found < count_allowed;
@@ -783,6 +809,8 @@ void test_find_many_fixed(base_operator_ &&base_operator, simd_operator_ &&simd_
783809 _sz_assert (total_found_simd == matches_simd.size ());
784810
785811 // Check the contents and order of the matches
812+ std::sort (matches_base.begin (), matches_base.end (), match_t ::less_globally);
813+ std::sort (matches_simd.begin (), matches_simd.end (), match_t ::less_globally);
786814 for (std::size_t i = 0 ; i != matches_base.size (); ++i) {
787815 _sz_assert (matches_base[i].haystack .data () == matches_simd[i].haystack .data ());
788816 _sz_assert (matches_base[i].needle .data () == matches_simd[i].needle .data ());
@@ -818,6 +846,8 @@ void test_find_many_fixed(base_operator_ &&base_operator, simd_operator_ &&simd_
818846 _sz_assert (total_found_simd == matches_simd.size ());
819847
820848 // Check the contents and order of the matches
849+ std::sort (matches_base.begin (), matches_base.end (), match_t ::less_globally);
850+ std::sort (matches_simd.begin (), matches_simd.end (), match_t ::less_globally);
821851 for (std::size_t i = 0 ; i != matches_base.size (); ++i) {
822852 _sz_assert (matches_base[i].haystack .data () == matches_simd[i].haystack .data ());
823853 _sz_assert (matches_base[i].needle .data () == matches_simd[i].needle .data ());
@@ -830,60 +860,98 @@ void test_find_many_fixed(base_operator_ &&base_operator, simd_operator_ &&simd_
830860 * @brief Fuzzy test for multi-pattern exact search algorithms using randomly-generated haystacks and needles.
831861 */
832862template <typename base_operator_, typename simd_operator_, typename ... extra_args_>
833- void test_find_many_fuzzy (base_operator_ &&base_operator, simd_operator_ &&simd_operator,
834- fuzzy_config_t needles_config = {}, fuzzy_config_t haystacks_config = {} ,
835- std:: size_t iterations = 10 , extra_args_ &&...extra_args) {
863+ void test_find_many (base_operator_ &&base_operator, simd_operator_ &&simd_operator,
864+ arrow_strings_tape_t const &haystacks_tape, arrow_strings_tape_t const &needles_tape ,
865+ extra_args_ &&...extra_args) {
836866
837867 using match_t = find_many_match_t ;
838868 unified_vector<match_t > results_base, results_simd;
839869 unified_vector<size_t > counts_base, counts_simd;
870+
871+ counts_base.resize (haystacks_tape.size ());
872+ counts_simd.resize (haystacks_tape.size ());
873+
874+ // Build the matchers
875+ _sz_assert (base_operator.try_build (needles_tape.view ()) == status_t ::success_k);
876+ _sz_assert (simd_operator.try_build (needles_tape.view ()) == status_t ::success_k);
877+
878+ // Count the number of matches with both backends
879+ span<size_t > counts_base_span {counts_base.data (), counts_base.size ()};
880+ span<size_t > counts_simd_span {counts_simd.data (), counts_simd.size ()};
881+ status_t status_count_base = base_operator.try_count (haystacks_tape.view (), counts_base_span);
882+ status_t status_count_simd = simd_operator.try_count (haystacks_tape.view (), counts_simd_span, extra_args...);
883+ _sz_assert (status_count_base == status_t ::success_k);
884+ _sz_assert (status_count_simd == status_t ::success_k);
885+ size_t total_count_base = std::accumulate (counts_base.begin (), counts_base.end (), 0 );
886+ size_t total_count_simd = std::accumulate (counts_simd.begin (), counts_simd.end (), 0 );
887+ _sz_assert (total_count_base == total_count_simd);
888+ _sz_assert (std::equal (counts_base.begin (), counts_base.end (), counts_simd.begin ()));
889+
890+ // Compute with both backends
891+ results_base.resize (total_count_base);
892+ results_simd.resize (total_count_simd);
893+ size_t count_base = 0 , count_simd = 0 ;
894+ status_t status_base = base_operator.try_find (haystacks_tape.view (), results_base, count_base);
895+ status_t status_simd = simd_operator.try_find (haystacks_tape.view (), results_simd, count_simd, extra_args...);
896+ _sz_assert (status_base == status_t ::success_k);
897+ _sz_assert (status_simd == status_t ::success_k);
898+ _sz_assert (count_base == count_simd);
899+
900+ // Individually log the failed results
901+ std::sort (results_base.begin (), results_base.end (), match_t ::less_globally);
902+ std::sort (results_simd.begin (), results_simd.end (), match_t ::less_globally);
903+ for (std::size_t i = 0 ; i != results_base.size (); ++i) {
904+ _sz_assert (results_base[i].haystack_index == results_simd[i].haystack_index );
905+ _sz_assert (results_base[i].needle_index == results_simd[i].needle_index );
906+ _sz_assert (results_base[i].needle .data () == results_simd[i].needle .data ());
907+ }
908+
909+ base_operator.reset ();
910+ simd_operator.reset ();
911+ }
912+
913+ /* *
914+ * @brief Fuzzy test for multi-pattern exact search algorithms using randomly-generated haystacks and needles.
915+ */
916+ template <typename base_operator_, typename simd_operator_, typename ... extra_args_>
917+ void test_find_many_fuzzy (base_operator_ &&base_operator, simd_operator_ &&simd_operator,
918+ fuzzy_config_t needles_config = {}, fuzzy_config_t haystacks_config = {},
919+ std::size_t iterations = 10 , extra_args_ &&...extra_args) {
920+
840921 std::vector<std::string> haystacks_array, needles_array;
841922 arrow_strings_tape_t haystacks_tape, needles_tape;
842923
843924 // Generate some random strings, using a small alphabet
844925 for (std::size_t iteration_idx = 0 ; iteration_idx < iterations; ++iteration_idx) {
845926 randomize_strings (haystacks_config, haystacks_array, haystacks_tape);
846927 randomize_strings (needles_config, needles_array, needles_tape, true );
847- counts_base.resize (haystacks_array.size ());
848- counts_simd.resize (haystacks_array.size ());
928+ test_find_many (base_operator, simd_operator, haystacks_tape, needles_tape, extra_args...);
929+ }
930+ }
849931
850- // Build the matchers
851- _sz_assert (base_operator.try_build (needles_tape.view ()) == status_t ::success_k);
852- _sz_assert (simd_operator.try_build (needles_tape.view ()) == status_t ::success_k);
932+ /* *
933+ * @brief Fuzzy test for multi-pattern exact search algorithms using randomly-generated haystacks,
934+ * and using incrementally longer potentially-overlapping substrings as needles.
935+ */
936+ template <typename base_operator_, typename simd_operator_, typename ... extra_args_>
937+ void test_find_many_prefixes (base_operator_ &&base_operator, simd_operator_ &&simd_operator,
938+ fuzzy_config_t haystacks_config, std::size_t needle_length_limit,
939+ std::size_t iterations = 10 , extra_args_ &&...extra_args) {
853940
854- // Count the number of matches with both backends
855- span<size_t > counts_base_span {counts_base.data (), counts_base.size ()};
856- span<size_t > counts_simd_span {counts_simd.data (), counts_simd.size ()};
857- status_t status_count_base = base_operator.try_count (haystacks_tape.view (), counts_base_span);
858- status_t status_count_simd = simd_operator.try_count (haystacks_tape.view (), counts_simd_span, extra_args...);
859- _sz_assert (status_count_base == status_t ::success_k);
860- _sz_assert (status_count_simd == status_t ::success_k);
861- size_t total_count_base = std::accumulate (counts_base.begin (), counts_base.end (), 0 );
862- size_t total_count_simd = std::accumulate (counts_simd.begin (), counts_simd.end (), 0 );
863- _sz_assert (total_count_base == total_count_simd);
864- _sz_assert (std::equal (counts_base.begin (), counts_base.end (), counts_simd.begin ()));
941+ std::vector<std::string> haystacks_array;
942+ std::vector<std::string_view> needles_array;
943+ arrow_strings_tape_t haystacks_tape, needles_tape;
865944
866- // Compute with both backends
867- results_base.resize (total_count_base);
868- results_simd.resize (total_count_simd);
869- size_t count_base = 0 , count_simd = 0 ;
870- status_t status_base = base_operator.try_find (haystacks_tape.view (), results_base, count_base);
871- status_t status_simd = simd_operator.try_find (haystacks_tape.view (), results_simd, count_simd, extra_args...);
872- _sz_assert (status_base == status_t ::success_k);
873- _sz_assert (status_simd == status_t ::success_k);
874- _sz_assert (count_base == count_simd);
945+ for (std::size_t iteration_idx = 0 ; iteration_idx < iterations; ++iteration_idx) {
946+ randomize_strings (haystacks_config, haystacks_array, haystacks_tape);
875947
876- // Individually log the failed results
877- std::sort (results_base.begin (), results_base.end (), match_t ::less_globally);
878- std::sort (results_simd.begin (), results_simd.end (), match_t ::less_globally);
879- for (std::size_t i = 0 ; i != results_base.size (); ++i) {
880- _sz_assert (results_base[i].haystack_index == results_simd[i].haystack_index );
881- _sz_assert (results_base[i].needle_index == results_simd[i].needle_index );
882- _sz_assert (results_base[i].needle .data () == results_simd[i].needle .data ());
883- }
948+ // Pick various substrings as needles from the first haystack
949+ needles_array.resize (std::min (haystacks_array[0 ].size (), needle_length_limit));
950+ for (std::size_t i = 0 ; i != needles_array.size (); ++i)
951+ needles_array[i] = std::string_view (haystacks_array[0 ]).substr (0 , i + 1 );
952+ needles_tape.try_assign (needles_array.data (), needles_array.data () + needles_array.size ());
884953
885- base_operator.reset ();
886- simd_operator.reset ();
954+ test_find_many (base_operator, simd_operator, haystacks_tape, needles_tape, extra_args...);
887955 }
888956}
889957
@@ -894,33 +962,31 @@ void test_find_many_fuzzy(base_operator_ &&base_operator, simd_operator_ &&simd_
894962void test_find_many_equivalence () {
895963
896964 cpu_specs_t default_cpu_specs;
897- fuzzy_config_t needles_short_config, needles_long_config, haystacks_long_config ;
898- haystacks_long_config .batch_size = default_cpu_specs.cores_total () * 4 ;
899- haystacks_long_config .max_string_length = default_cpu_specs.l3_bytes ;
965+ fuzzy_config_t needles_short_config, needles_long_config, haystacks_config ;
966+ haystacks_config .batch_size = default_cpu_specs.cores_total () * 4 ;
967+ haystacks_config .max_string_length = default_cpu_specs.l3_bytes ;
900968
901969 needles_long_config.min_string_length = 8 ;
902970 needles_long_config.max_string_length = 10 ;
903971 needles_long_config.batch_size =
904972 std::pow (needles_long_config.alphabet .size (), needles_long_config.max_string_length );
905973
906- needles_long_config .min_string_length = 1 ;
907- needles_long_config .max_string_length = 6 ;
908- needles_long_config .batch_size =
909- std::pow (needles_long_config .alphabet .size (), needles_long_config .max_string_length );
974+ needles_short_config .min_string_length = 1 ;
975+ needles_short_config .max_string_length = 6 ;
976+ needles_short_config .batch_size =
977+ std::pow (needles_short_config .alphabet .size (), needles_short_config .max_string_length );
910978
911- // Single-threaded serial Levenshtein distance implementation
979+ // Single-threaded serial Aho-Corasick implementation
912980 test_find_many_fixed (find_many_baselines_t {}, find_many_serial_t {});
913- test_find_many_fuzzy (find_many_baselines_t {}, find_many_serial_t {}, needles_short_config, haystacks_long_config,
914- 1 );
915- test_find_many_fuzzy (find_many_baselines_t {}, find_many_serial_t {}, needles_long_config, haystacks_long_config,
916- 1 );
981+ test_find_many_fuzzy (find_many_baselines_t {}, find_many_serial_t {}, needles_short_config, haystacks_config, 1 );
982+ test_find_many_fuzzy (find_many_baselines_t {}, find_many_serial_t {}, needles_long_config, haystacks_config, 1 );
983+ test_find_many_prefixes (find_many_baselines_t {}, find_many_serial_t {}, haystacks_config, 1024 , 1 );
917984
918- // Multi-threaded parallel Levenshtein distance implementation
985+ // Multi-threaded parallel Aho-Corasick implementation
919986 test_find_many_fixed (find_many_baselines_t {}, find_many_parallel_t {});
920- test_find_many_fuzzy (find_many_baselines_t {}, find_many_parallel_t {}, needles_short_config, haystacks_long_config,
921- 10 );
922- test_find_many_fuzzy (find_many_baselines_t {}, find_many_parallel_t {}, needles_long_config, haystacks_long_config,
923- 10 );
987+ test_find_many_fuzzy (find_many_baselines_t {}, find_many_parallel_t {}, needles_short_config, haystacks_config, 10 );
988+ test_find_many_fuzzy (find_many_baselines_t {}, find_many_parallel_t {}, needles_long_config, haystacks_config, 10 );
989+ test_find_many_prefixes (find_many_baselines_t {}, find_many_parallel_t {}, haystacks_config, 1024 , 10 );
924990}
925991
926992} // namespace scripts
0 commit comments