@@ -151,24 +151,25 @@ void edit_distance_log_mismatch(std::string const &first, std::string const &sec
151151 */
152152template <typename score_type_, typename base_operator_, typename simd_operator_>
153153static void edit_distances_compare (base_operator_ &&base_operator, simd_operator_ &&simd_operator,
154- std::size_t batch_size = 1024 * 16 , std::size_t max_string_length = 512 ) {
154+ std::size_t batch_size = 1024 * 16 , std::size_t max_string_length = 512 ,
155+ std::string_view allowed_chars = {}) {
155156
156157 using score_t = score_type_;
157158
158159 std::vector<std::pair<std::string, std::string>> test_cases = {
159160 {" ABC" , " ABC" }, // same string; distance ~ 0
160- {" listen " , " silent " }, // distance ~ 4
161- {" atca " , " ctactcaccc " }, // distance ~ 6
161+ {" LISTEN " , " SILENT " }, // distance ~ 4
162+ {" ATCA " , " CTACTCACCC " }, // distance ~ 6
162163 {" A" , " =" }, // distance ~ 1
163- {" a " , " a " }, // distance ~ 0
164+ {" A " , " A " }, // distance ~ 0
164165 {" " , " " }, // distance ~ 0
165- {" " , " abc " }, // distance ~ 3
166- {" abc " , " " }, // distance ~ 3
167- {" abc " , " ac " }, // one deletion; distance ~ 1
168- {" abc " , " a_bc " }, // one insertion; distance ~ 1
166+ {" " , " ABC " }, // distance ~ 3
167+ {" ABC " , " " }, // distance ~ 3
168+ {" ABC " , " AC " }, // one deletion; distance ~ 1
169+ {" ABC " , " A_BC " }, // one insertion; distance ~ 1
169170 {" ggbuzgjux{}l" , " gbuzgjux{}l" }, // one (prepended) insertion; distance ~ 1
170- {" abc " , " adc " }, // one substitution; distance ~ 1
171- {" apple " , " aple " }, // distance ~ 1
171+ {" ABC " , " ADC " }, // one substitution; distance ~ 1
172+ {" APPLE " , " APLE " }, // distance ~ 1
172173 //
173174 // Unicode:
174175 {" αβγδ" , " αγδ" }, // Each Greek symbol is 2 bytes in size; 2 bytes, 1 runes diff.
@@ -188,8 +189,18 @@ static void edit_distances_compare(base_operator_ &&base_operator, simd_operator
188189 // First check with a batch-size of 1
189190 unified_vector<score_t > results_base (1 ), results_simd (1 );
190191 arrow_strings_tape_t first_tape, second_tape;
192+ bool contains_missing_in_any_case = false ;
191193 for (auto [first, second] : test_cases) {
192194
195+ // Check if the input strings fit into our allowed characters set
196+ if (!allowed_chars.empty ()) {
197+ bool contains_missing = false ;
198+ for (auto c : first) contains_missing |= allowed_chars.find (c) == std::string_view::npos;
199+ for (auto c : second) contains_missing |= allowed_chars.find (c) == std::string_view::npos;
200+ contains_missing_in_any_case |= contains_missing;
201+ if (contains_missing) continue ;
202+ }
203+
193204 // Reset the tapes and results
194205 results_base[0 ] = 0 , results_simd[0 ] = 0 ;
195206 first_tape.try_assign (&first, &first + 1 );
@@ -205,25 +216,27 @@ static void edit_distances_compare(base_operator_ &&base_operator, simd_operator
205216 }
206217
207218 // Unzip the test cases into two separate tapes and perform batch processing
208- results_base.resize (test_cases.size ());
209- results_simd.resize (test_cases.size ());
210- first_tape.reset ();
211- second_tape.reset ();
212- for (auto [first, second] : test_cases) {
213- _sz_assert (first_tape.try_append ({first.data (), first.size ()}) == sz::status_t ::success_k);
214- _sz_assert (second_tape.try_append ({second.data (), second.size ()}) == sz::status_t ::success_k);
215- }
219+ if (!contains_missing_in_any_case) {
220+ results_base.resize (test_cases.size ());
221+ results_simd.resize (test_cases.size ());
222+ first_tape.reset ();
223+ second_tape.reset ();
224+ for (auto [first, second] : test_cases) {
225+ _sz_assert (first_tape.try_append ({first.data (), first.size ()}) == sz::status_t ::success_k);
226+ _sz_assert (second_tape.try_append ({second.data (), second.size ()}) == sz::status_t ::success_k);
227+ }
216228
217- // Compute with both backends
218- sz::status_t status_base = base_operator (first_tape.view (), second_tape.view (), results_base.data ());
219- sz::status_t status_simd = simd_operator (first_tape.view (), second_tape.view (), results_simd.data ());
220- _sz_assert (status_base == sz::status_t ::success_k);
221- _sz_assert (status_simd == sz::status_t ::success_k);
229+ // Compute with both backends
230+ sz::status_t status_base = base_operator (first_tape.view (), second_tape.view (), results_base.data ());
231+ sz::status_t status_simd = simd_operator (first_tape.view (), second_tape.view (), results_simd.data ());
232+ _sz_assert (status_base == sz::status_t ::success_k);
233+ _sz_assert (status_simd == sz::status_t ::success_k);
222234
223- // Individually log the failed results
224- for (std::size_t i = 0 ; i != test_cases.size (); ++i) {
225- if (results_base[i] == results_simd[i]) continue ;
226- edit_distance_log_mismatch (test_cases[i].first , test_cases[i].second , results_base[i], results_simd[i]);
235+ // Individually log the failed results
236+ for (std::size_t i = 0 ; i != test_cases.size (); ++i) {
237+ if (results_base[i] == results_simd[i]) continue ;
238+ edit_distance_log_mismatch (test_cases[i].first , test_cases[i].second , results_base[i], results_simd[i]);
239+ }
227240 }
228241
229242 // Generate some random strings, using a small alphabet
@@ -232,8 +245,8 @@ static void edit_distances_compare(base_operator_ &&base_operator, simd_operator
232245 for (std::size_t i = 0 ; i != batch_size; ++i) {
233246 std::size_t first_length = 1u + std::rand () % max_string_length;
234247 std::size_t second_length = 1u + std::rand () % max_string_length;
235- first_array[i] = random_string (first_length, " abc " , 3 );
236- second_array[i] = random_string (second_length, " abc " , 3 );
248+ first_array[i] = random_string (first_length, " ABC " , 3 );
249+ second_array[i] = random_string (second_length, " ABC " , 3 );
237250 }
238251
239252 // Convert to a GPU-friendly layout
@@ -299,54 +312,58 @@ static void test_equivalence(std::size_t batch_size = 1024, std::size_t max_stri
299312 batch_size, max_string_length);
300313
301314 // Now let's take non-unary substitution costs, like BLOSUM62
302- constexpr error_t blosum62_gap_extension_cost = 4 ; // ? The inverted typical (-4) value
303- error_matrix_t blosum62 = sz::error_costs_26x26ascii_t::blosum62 ().decompressed ();
315+ constexpr error_t blosum62_gap_extension_cost = -4 ;
316+ error_mat_t blosum62_mat = sz::error_costs_26x26ascii_t::blosum62 ();
317+ error_matrix_t blosum62_matrix = blosum62_mat.decompressed ();
304318
319+ #if 0
305320 // Single-threaded serial NW implementation
306- edit_distances_compare<sz_ssize_t >( //
307- needleman_wunsch_baselines_t {blosum62 , blosum62_gap_extension_cost}, //
321+ edit_distances_compare<sz_ssize_t>( //
322+ needleman_wunsch_baselines_t {blosum62_matrix , blosum62_gap_extension_cost}, //
308323 sz::needleman_wunsch_scores<serial_k, char, error_matrix_t, std::allocator<char>> {
309- blosum62 , blosum62_gap_extension_cost}, //
324+ blosum62_matrix , blosum62_gap_extension_cost}, //
310325 batch_size, max_string_length);
311326
312327 // Multi-threaded parallel NW implementation
313- edit_distances_compare<sz_ssize_t >( //
314- needleman_wunsch_baselines_t {blosum62 , blosum62_gap_extension_cost}, //
328+ edit_distances_compare<sz_ssize_t>( //
329+ needleman_wunsch_baselines_t {blosum62_matrix , blosum62_gap_extension_cost}, //
315330 sz::needleman_wunsch_scores<parallel_k, char, error_matrix_t, std::allocator<char>> {
316- blosum62 , blosum62_gap_extension_cost}, //
331+ blosum62_matrix , blosum62_gap_extension_cost}, //
317332 batch_size, max_string_length);
318333
319334 // Single-threaded serial SW implementation
320- edit_distances_compare<sz_ssize_t >( //
321- smith_waterman_baselines_t {blosum62 , blosum62_gap_extension_cost}, //
335+ edit_distances_compare<sz_ssize_t>( //
336+ smith_waterman_baselines_t {blosum62_matrix , blosum62_gap_extension_cost}, //
322337 sz::smith_waterman_scores<serial_k, char, error_matrix_t, std::allocator<char>> {
323- blosum62 , blosum62_gap_extension_cost}, //
338+ blosum62_matrix , blosum62_gap_extension_cost}, //
324339 batch_size, max_string_length);
325340
326341 // Multi-threaded parallel SW implementation
327- edit_distances_compare<sz_ssize_t >( //
328- smith_waterman_baselines_t {blosum62 , blosum62_gap_extension_cost}, //
342+ edit_distances_compare<sz_ssize_t>( //
343+ smith_waterman_baselines_t {blosum62_matrix , blosum62_gap_extension_cost}, //
329344 sz::smith_waterman_scores<parallel_k, char, error_matrix_t, std::allocator<char>> {
330- blosum62 , blosum62_gap_extension_cost}, //
345+ blosum62_matrix , blosum62_gap_extension_cost}, //
331346 batch_size, max_string_length);
347+ #endif
332348
333349 // Switch to the GPU, using an identical matrix, but move it into unified memory
334- unified_vector<error_matrix_t > blosum62_unified (1 );
335- blosum62_unified[0 ] = blosum62 ;
350+ unified_vector<error_mat_t > blosum62_unified (1 );
351+ blosum62_unified[0 ] = blosum62_mat ;
336352
337353 // CUDA Levenshtein distance against Multi-threaded on CPU
338354 edit_distances_compare<sz_size_t >( //
339355 sz::levenshtein_distances<parallel_k, char , std::allocator<char >> {}, //
340356 sz::levenshtein_distances<cuda_k, char > {}, //
341357 batch_size, max_string_length);
342358
343- // CUDA Needleman-Wunsch distance against Multi-threaded on CPU
359+ // CUDA Needleman-Wunsch distance against Multi-threaded on CPU,
360+ // using a compressed smaller matrix to fit into GPU shared memory
361+ std::string_view ascii_alphabet = " ABCDEFGHIJKLMNOPQRSTUVWXYZ" ;
344362 edit_distances_compare<sz_ssize_t >( //
345363 sz::needleman_wunsch_scores<parallel_k, char , error_matrix_t , std::allocator<char >> {
346- blosum62, blosum62_gap_extension_cost}, //
347- sz::needleman_wunsch_scores<cuda_k, char , error_matrix_t *> {blosum62_unified.data (),
348- blosum62_gap_extension_cost},
349- batch_size, max_string_length);
364+ blosum62_matrix, blosum62_gap_extension_cost}, //
365+ sz::needleman_wunsch_scores<cuda_k, char , error_mat_t *> {blosum62_unified.data (), blosum62_gap_extension_cost},
366+ batch_size, max_string_length, ascii_alphabet);
350367};
351368
352369#if 0
0 commit comments