Skip to content

Commit 6ebc7b0

Browse files
committed
Improve: Parallel baseline for substring search
1 parent 83bc966 commit 6ebc7b0

File tree

1 file changed

+153
-87
lines changed

1 file changed

+153
-87
lines changed

scripts/test_stringcuzilla.cuh

Lines changed: 153 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,6 @@ void test_similarity_scores_memory_usage() {
640640
}
641641

642642
struct find_many_baselines_t {
643-
using state_id_t = sz_u32_t;
644643
using match_t = find_many_match_t;
645644

646645
arrow_strings_tape_t needles_;
@@ -652,39 +651,65 @@ struct find_many_baselines_t {
652651

653652
/** @brief Resets the baseline matcher by forwarding to `needles_.reset()`. */
void reset() noexcept {
    needles_.reset();
}
654653

654+
template <typename haystack_type_, typename needle_type_, typename match_callback_type_>
655+
void one_pair(haystack_type_ const &haystack, needle_type_ const &needle,
656+
match_callback_type_ &&callback) const noexcept {
657+
658+
// Define iterators for the current haystack and the needle.
659+
auto haystack_begin = haystack.begin();
660+
auto haystack_end = haystack.end();
661+
auto needle_begin = needle.begin();
662+
auto needle_end = needle.end();
663+
664+
// Use `std::search` to find all occurrences of needle in haystack.
665+
while (true) {
666+
auto it = std::search(haystack_begin, haystack_end, needle_begin, needle_end);
667+
if (it == haystack_end) break;
668+
669+
// Compute the starting index of the found occurrence.
670+
std::size_t found = static_cast<std::size_t>(std::distance(haystack.begin(), it));
671+
672+
// Construct a match record.
673+
match_t match;
674+
match.haystack = {haystack.data(), haystack.size()};
675+
match.needle = {haystack.data() + found, needle.size()};
676+
677+
// Invoke the callback. If it returns false, abort all further processing.
678+
if (!callback(match)) return;
679+
680+
// Advance the starting iterator for the next search.
681+
haystack_begin = it + 1;
682+
}
683+
}
684+
655685
template <typename haystacks_type_, typename needles_type_, typename match_callback_type_>
656-
void iterate_through_unsorted_matches(haystacks_type_ &&haystacks, needles_type_ &&needles,
657-
match_callback_type_ &&callback) const noexcept {
658-
for (std::size_t i = 0; i != haystacks.size(); ++i) {
659-
auto const &haystack = haystacks[i];
660-
for (std::size_t j = 0; j != needles.size(); ++j) {
686+
void all_pairs(haystacks_type_ &&haystacks, needles_type_ &&needles,
687+
match_callback_type_ &&callback) const noexcept {
688+
689+
// A wise man once said, `omp parallel for collapse(2) schedule(dynamic, 1)`...
690+
// But the compiler wasn't listening, and won't compile the cancellation point!
691+
std::size_t const total_tasks = haystacks.size() * needles.size();
692+
#pragma omp parallel
693+
{
694+
#pragma omp for schedule(dynamic, 1)
695+
for (std::size_t task = 0; task != total_tasks; ++task) {
696+
#pragma omp cancellation point for
697+
std::size_t const i = task / needles.size();
698+
std::size_t const j = task % needles.size();
699+
700+
auto const &haystack = haystacks[i];
661701
auto const &needle = needles[j];
662-
// Define iterators for the current haystack and the needle.
663-
auto haystack_begin = haystack.begin();
664-
auto haystack_end = haystack.end();
665-
auto needle_begin = needle.begin();
666-
auto needle_end = needle.end();
667-
668-
// Use `std::search` to find all occurrences of needle in haystack.
669-
while (true) {
670-
auto it = std::search(haystack_begin, haystack_end, needle_begin, needle_end);
671-
if (it == haystack_end) break;
672-
673-
// Compute the starting index of the found occurrence.
674-
std::size_t found = static_cast<std::size_t>(std::distance(haystack.begin(), it));
675-
676-
// Construct a match record.
677-
match_t match;
702+
bool keep_going = true;
703+
one_pair(haystack, needle, [&](match_t match) {
678704
match.haystack_index = i;
679705
match.needle_index = j;
680-
match.haystack = {haystack.data(), haystack.size()};
681-
match.needle = {haystack.data() + found, needle.size()};
682-
683-
// Invoke the callback. If it returns false, abort all further processing.
684-
if (!callback(match)) return;
685-
686-
// Advance the starting iterator for the next search.
687-
haystack_begin = it + 1;
706+
#pragma omp critical
707+
{ keep_going = callback(match); }
708+
return keep_going;
709+
});
710+
// Quit the outer loop if the callback returns false
711+
if (!keep_going) {
712+
#pragma omp cancel for
688713
}
689714
}
690715
}
@@ -693,7 +718,7 @@ struct find_many_baselines_t {
693718
template <typename haystacks_type_>
694719
status_t try_count(haystacks_type_ &&haystacks, span<size_t> counts) const noexcept {
695720
for (size_t &count : counts) count = 0;
696-
iterate_through_unsorted_matches(haystacks, needles_, [&](match_t const &match) {
721+
all_pairs(haystacks, needles_, [&](match_t const &match) noexcept {
697722
counts[match.haystack_index] += 1;
698723
return true;
699724
});
@@ -703,8 +728,9 @@ struct find_many_baselines_t {
703728
template <typename haystacks_type_, typename output_matches_type_>
704729
status_t try_find(haystacks_type_ &&haystacks, output_matches_type_ &&matches,
705730
size_t &matches_total) const noexcept {
731+
706732
size_t count_found = 0, count_allowed = matches.size();
707-
iterate_through_unsorted_matches(haystacks, needles_, [&](match_t const &match) {
733+
all_pairs(haystacks, needles_, [&](match_t const &match) noexcept {
708734
matches[count_found] = match;
709735
count_found += 1;
710736
return count_found < count_allowed;
@@ -783,6 +809,8 @@ void test_find_many_fixed(base_operator_ &&base_operator, simd_operator_ &&simd_
783809
_sz_assert(total_found_simd == matches_simd.size());
784810

785811
// Check the contents and order of the matches
812+
std::sort(matches_base.begin(), matches_base.end(), match_t::less_globally);
813+
std::sort(matches_simd.begin(), matches_simd.end(), match_t::less_globally);
786814
for (std::size_t i = 0; i != matches_base.size(); ++i) {
787815
_sz_assert(matches_base[i].haystack.data() == matches_simd[i].haystack.data());
788816
_sz_assert(matches_base[i].needle.data() == matches_simd[i].needle.data());
@@ -818,6 +846,8 @@ void test_find_many_fixed(base_operator_ &&base_operator, simd_operator_ &&simd_
818846
_sz_assert(total_found_simd == matches_simd.size());
819847

820848
// Check the contents and order of the matches
849+
std::sort(matches_base.begin(), matches_base.end(), match_t::less_globally);
850+
std::sort(matches_simd.begin(), matches_simd.end(), match_t::less_globally);
821851
for (std::size_t i = 0; i != matches_base.size(); ++i) {
822852
_sz_assert(matches_base[i].haystack.data() == matches_simd[i].haystack.data());
823853
_sz_assert(matches_base[i].needle.data() == matches_simd[i].needle.data());
@@ -830,60 +860,98 @@ void test_find_many_fixed(base_operator_ &&base_operator, simd_operator_ &&simd_
830860
* @brief Fuzzy test for multi-pattern exact search algorithms using randomly-generated haystacks and needles.
831861
*/
832862
template <typename base_operator_, typename simd_operator_, typename... extra_args_>
833-
void test_find_many_fuzzy(base_operator_ &&base_operator, simd_operator_ &&simd_operator,
834-
fuzzy_config_t needles_config = {}, fuzzy_config_t haystacks_config = {},
835-
std::size_t iterations = 10, extra_args_ &&...extra_args) {
863+
void test_find_many(base_operator_ &&base_operator, simd_operator_ &&simd_operator,
864+
arrow_strings_tape_t const &haystacks_tape, arrow_strings_tape_t const &needles_tape,
865+
extra_args_ &&...extra_args) {
836866

837867
using match_t = find_many_match_t;
838868
unified_vector<match_t> results_base, results_simd;
839869
unified_vector<size_t> counts_base, counts_simd;
870+
871+
counts_base.resize(haystacks_tape.size());
872+
counts_simd.resize(haystacks_tape.size());
873+
874+
// Build the matchers
875+
_sz_assert(base_operator.try_build(needles_tape.view()) == status_t::success_k);
876+
_sz_assert(simd_operator.try_build(needles_tape.view()) == status_t::success_k);
877+
878+
// Count the number of matches with both backends
879+
span<size_t> counts_base_span {counts_base.data(), counts_base.size()};
880+
span<size_t> counts_simd_span {counts_simd.data(), counts_simd.size()};
881+
status_t status_count_base = base_operator.try_count(haystacks_tape.view(), counts_base_span);
882+
status_t status_count_simd = simd_operator.try_count(haystacks_tape.view(), counts_simd_span, extra_args...);
883+
_sz_assert(status_count_base == status_t::success_k);
884+
_sz_assert(status_count_simd == status_t::success_k);
885+
size_t total_count_base = std::accumulate(counts_base.begin(), counts_base.end(), 0);
886+
size_t total_count_simd = std::accumulate(counts_simd.begin(), counts_simd.end(), 0);
887+
_sz_assert(total_count_base == total_count_simd);
888+
_sz_assert(std::equal(counts_base.begin(), counts_base.end(), counts_simd.begin()));
889+
890+
// Compute with both backends
891+
results_base.resize(total_count_base);
892+
results_simd.resize(total_count_simd);
893+
size_t count_base = 0, count_simd = 0;
894+
status_t status_base = base_operator.try_find(haystacks_tape.view(), results_base, count_base);
895+
status_t status_simd = simd_operator.try_find(haystacks_tape.view(), results_simd, count_simd, extra_args...);
896+
_sz_assert(status_base == status_t::success_k);
897+
_sz_assert(status_simd == status_t::success_k);
898+
_sz_assert(count_base == count_simd);
899+
900+
// Individually log the failed results
901+
std::sort(results_base.begin(), results_base.end(), match_t::less_globally);
902+
std::sort(results_simd.begin(), results_simd.end(), match_t::less_globally);
903+
for (std::size_t i = 0; i != results_base.size(); ++i) {
904+
_sz_assert(results_base[i].haystack_index == results_simd[i].haystack_index);
905+
_sz_assert(results_base[i].needle_index == results_simd[i].needle_index);
906+
_sz_assert(results_base[i].needle.data() == results_simd[i].needle.data());
907+
}
908+
909+
base_operator.reset();
910+
simd_operator.reset();
911+
}
912+
913+
/**
914+
* @brief Fuzzy test for multi-pattern exact search algorithms using randomly-generated haystacks and needles.
915+
*/
916+
template <typename base_operator_, typename simd_operator_, typename... extra_args_>
917+
void test_find_many_fuzzy(base_operator_ &&base_operator, simd_operator_ &&simd_operator,
918+
fuzzy_config_t needles_config = {}, fuzzy_config_t haystacks_config = {},
919+
std::size_t iterations = 10, extra_args_ &&...extra_args) {
920+
840921
std::vector<std::string> haystacks_array, needles_array;
841922
arrow_strings_tape_t haystacks_tape, needles_tape;
842923

843924
// Generate some random strings, using a small alphabet
844925
for (std::size_t iteration_idx = 0; iteration_idx < iterations; ++iteration_idx) {
845926
randomize_strings(haystacks_config, haystacks_array, haystacks_tape);
846927
randomize_strings(needles_config, needles_array, needles_tape, true);
847-
counts_base.resize(haystacks_array.size());
848-
counts_simd.resize(haystacks_array.size());
928+
test_find_many(base_operator, simd_operator, haystacks_tape, needles_tape, extra_args...);
929+
}
930+
}
849931

850-
// Build the matchers
851-
_sz_assert(base_operator.try_build(needles_tape.view()) == status_t::success_k);
852-
_sz_assert(simd_operator.try_build(needles_tape.view()) == status_t::success_k);
932+
/**
933+
* @brief Fuzzy test for multi-pattern exact search algorithms using randomly-generated haystacks,
934+
* and using incrementally longer potentially-overlapping substrings as needles.
935+
*/
936+
template <typename base_operator_, typename simd_operator_, typename... extra_args_>
937+
void test_find_many_prefixes(base_operator_ &&base_operator, simd_operator_ &&simd_operator,
938+
fuzzy_config_t haystacks_config, std::size_t needle_length_limit,
939+
std::size_t iterations = 10, extra_args_ &&...extra_args) {
853940

854-
// Count the number of matches with both backends
855-
span<size_t> counts_base_span {counts_base.data(), counts_base.size()};
856-
span<size_t> counts_simd_span {counts_simd.data(), counts_simd.size()};
857-
status_t status_count_base = base_operator.try_count(haystacks_tape.view(), counts_base_span);
858-
status_t status_count_simd = simd_operator.try_count(haystacks_tape.view(), counts_simd_span, extra_args...);
859-
_sz_assert(status_count_base == status_t::success_k);
860-
_sz_assert(status_count_simd == status_t::success_k);
861-
size_t total_count_base = std::accumulate(counts_base.begin(), counts_base.end(), 0);
862-
size_t total_count_simd = std::accumulate(counts_simd.begin(), counts_simd.end(), 0);
863-
_sz_assert(total_count_base == total_count_simd);
864-
_sz_assert(std::equal(counts_base.begin(), counts_base.end(), counts_simd.begin()));
941+
std::vector<std::string> haystacks_array;
942+
std::vector<std::string_view> needles_array;
943+
arrow_strings_tape_t haystacks_tape, needles_tape;
865944

866-
// Compute with both backends
867-
results_base.resize(total_count_base);
868-
results_simd.resize(total_count_simd);
869-
size_t count_base = 0, count_simd = 0;
870-
status_t status_base = base_operator.try_find(haystacks_tape.view(), results_base, count_base);
871-
status_t status_simd = simd_operator.try_find(haystacks_tape.view(), results_simd, count_simd, extra_args...);
872-
_sz_assert(status_base == status_t::success_k);
873-
_sz_assert(status_simd == status_t::success_k);
874-
_sz_assert(count_base == count_simd);
945+
for (std::size_t iteration_idx = 0; iteration_idx < iterations; ++iteration_idx) {
946+
randomize_strings(haystacks_config, haystacks_array, haystacks_tape);
875947

876-
// Individually log the failed results
877-
std::sort(results_base.begin(), results_base.end(), match_t::less_globally);
878-
std::sort(results_simd.begin(), results_simd.end(), match_t::less_globally);
879-
for (std::size_t i = 0; i != results_base.size(); ++i) {
880-
_sz_assert(results_base[i].haystack_index == results_simd[i].haystack_index);
881-
_sz_assert(results_base[i].needle_index == results_simd[i].needle_index);
882-
_sz_assert(results_base[i].needle.data() == results_simd[i].needle.data());
883-
}
948+
// Pick various substrings as needles from the first haystack
949+
needles_array.resize(std::min(haystacks_array[0].size(), needle_length_limit));
950+
for (std::size_t i = 0; i != needles_array.size(); ++i)
951+
needles_array[i] = std::string_view(haystacks_array[0]).substr(0, i + 1);
952+
needles_tape.try_assign(needles_array.data(), needles_array.data() + needles_array.size());
884953

885-
base_operator.reset();
886-
simd_operator.reset();
954+
test_find_many(base_operator, simd_operator, haystacks_tape, needles_tape, extra_args...);
887955
}
888956
}
889957

@@ -894,33 +962,31 @@ void test_find_many_fuzzy(base_operator_ &&base_operator, simd_operator_ &&simd_
894962
void test_find_many_equivalence() {
895963

896964
cpu_specs_t default_cpu_specs;
897-
fuzzy_config_t needles_short_config, needles_long_config, haystacks_long_config;
898-
haystacks_long_config.batch_size = default_cpu_specs.cores_total() * 4;
899-
haystacks_long_config.max_string_length = default_cpu_specs.l3_bytes;
965+
fuzzy_config_t needles_short_config, needles_long_config, haystacks_config;
966+
haystacks_config.batch_size = default_cpu_specs.cores_total() * 4;
967+
haystacks_config.max_string_length = default_cpu_specs.l3_bytes;
900968

901969
needles_long_config.min_string_length = 8;
902970
needles_long_config.max_string_length = 10;
903971
needles_long_config.batch_size =
904972
std::pow(needles_long_config.alphabet.size(), needles_long_config.max_string_length);
905973

906-
needles_long_config.min_string_length = 1;
907-
needles_long_config.max_string_length = 6;
908-
needles_long_config.batch_size =
909-
std::pow(needles_long_config.alphabet.size(), needles_long_config.max_string_length);
974+
needles_short_config.min_string_length = 1;
975+
needles_short_config.max_string_length = 6;
976+
needles_short_config.batch_size =
977+
std::pow(needles_short_config.alphabet.size(), needles_short_config.max_string_length);
910978

911-
// Single-threaded serial Levenshtein distance implementation
979+
// Single-threaded serial Aho-Corasick implementation
912980
test_find_many_fixed(find_many_baselines_t {}, find_many_serial_t {});
913-
test_find_many_fuzzy(find_many_baselines_t {}, find_many_serial_t {}, needles_short_config, haystacks_long_config,
914-
1);
915-
test_find_many_fuzzy(find_many_baselines_t {}, find_many_serial_t {}, needles_long_config, haystacks_long_config,
916-
1);
981+
test_find_many_fuzzy(find_many_baselines_t {}, find_many_serial_t {}, needles_short_config, haystacks_config, 1);
982+
test_find_many_fuzzy(find_many_baselines_t {}, find_many_serial_t {}, needles_long_config, haystacks_config, 1);
983+
test_find_many_prefixes(find_many_baselines_t {}, find_many_serial_t {}, haystacks_config, 1024, 1);
917984

918-
// Multi-threaded parallel Levenshtein distance implementation
985+
// Multi-threaded parallel Aho-Corasick implementation
919986
test_find_many_fixed(find_many_baselines_t {}, find_many_parallel_t {});
920-
test_find_many_fuzzy(find_many_baselines_t {}, find_many_parallel_t {}, needles_short_config, haystacks_long_config,
921-
10);
922-
test_find_many_fuzzy(find_many_baselines_t {}, find_many_parallel_t {}, needles_long_config, haystacks_long_config,
923-
10);
987+
test_find_many_fuzzy(find_many_baselines_t {}, find_many_parallel_t {}, needles_short_config, haystacks_config, 10);
988+
test_find_many_fuzzy(find_many_baselines_t {}, find_many_parallel_t {}, needles_long_config, haystacks_config, 10);
989+
test_find_many_prefixes(find_many_baselines_t {}, find_many_parallel_t {}, haystacks_config, 1024, 10);
924990
}
925991

926992
} // namespace scripts

0 commit comments

Comments
 (0)