5151#include " bench.hpp"
5252#include " test.hpp" // `levenshtein_baseline`, `unary_substitution_costs`
5353
54- #include " stringzilla/similarities.hpp"
54+ #if SZ_USE_CUDA
55+ #include < stringzilla/similarity.cuh> // Parallel string processing on CUDA or OpenMP
56+ #endif
57+
58+ #if SZ_USE_OPENMP
59+ #include < stringzilla/similarity.hpp> // OpenMP templates for string similarity measures
60+ #endif
5561
5662using namespace ashvardanian ::stringzilla::scripts;
5763
@@ -65,7 +71,7 @@ struct hamming_from_sz {
6571 sz_size_t bound = SZ_SIZE_MAX;
6672
6773 inline call_result_t operator ()(std::size_t token_index) const noexcept {
68- return operator ()(env. tokens [token_index], env. tokens [env.tokens .size () - 1 - token_index]);
74+ return operator ()(env[token_index], env[env.tokens .size () - 1 - token_index]);
6975 }
7076
7177 inline call_result_t operator ()(std::string_view a, std::string_view b) const noexcept {
@@ -100,7 +106,7 @@ struct levenshtein_from_sz {
100106 sz_size_t bound = SZ_SIZE_MAX;
101107
102108 inline call_result_t operator ()(std::size_t token_index) const noexcept {
103- return operator ()(env. tokens [token_index], env. tokens [env.tokens .size () - 1 - token_index]);
109+ return operator ()(env[token_index], env[env.tokens .size () - 1 - token_index]);
104110 }
105111
106112 inline call_result_t operator ()(std::string_view a, std::string_view b) const noexcept {
@@ -125,7 +131,7 @@ struct alignment_score_from_sz {
125131 error_costs_256x256_t costs = unary_substitution_costs();
126132
127133 inline call_result_t operator ()(std::size_t token_index) const noexcept {
128- return operator ()(env. tokens [token_index], env. tokens [env.tokens .size () - 1 - token_index]);
134+ return operator ()(env[token_index], env[env.tokens .size () - 1 - token_index]);
129135 }
130136
131137 inline call_result_t operator ()(std::string_view a, std::string_view b) const noexcept {
@@ -143,33 +149,83 @@ struct alignment_score_from_sz {
143149 }
144150};
145151
152+ #if SZ_USE_OPENMP
153+
146154/* * @brief Wraps a hardware-specific Levenshtein-distance backend into something @b `bench_unary`-compatible . */
147155struct levenshtein_from_sz_openmp {
148156
149157 environment_t const &env;
150158 sz_size_t bound = SZ_SIZE_MAX;
151159
152160 inline call_result_t operator ()(std::size_t token_index) const noexcept {
153- return operator ()(env. tokens [token_index], env. tokens [env.tokens .size () - 1 - token_index]);
161+ return operator ()(env[token_index], env[env.tokens .size () - 1 - token_index]);
154162 }
155163
156164 inline call_result_t operator ()(std::string_view a, std::string_view b) const noexcept (false ) {
157- sz_size_t result_distance = sz::openmp::levenshtein_distance (a, b);
165+ sz_size_t result_distance = sz::openmp::levenshtein_distance (a, b, std::allocator< char >() );
158166 do_not_optimize (result_distance);
159167 std::size_t bytes_passed = std::min (a.size (), b.size ());
160168 std::size_t cells_passed = a.size () * b.size ();
161169 return {bytes_passed, static_cast <check_value_t >(result_distance), cells_passed};
162170 }
163171};
164172
173+ #endif
174+
175+ #if SZ_USE_CUDA
176+
177+ /* * @brief Wraps a hardware-specific Levenshtein-distance backend into something @b `bench_unary`-compatible . */
178+ struct levenshtein_from_sz_cuda {
179+
180+ environment_t const &env;
181+ std::vector<sz_size_t , sz::cuda::unified_alloc<sz_size_t >> results;
182+ sz_size_t bound = SZ_SIZE_MAX;
183+
184+ levenshtein_from_sz_cuda (environment_t const &env, sz_size_t batch_size) : env(env), results(batch_size) {
185+ if (env.tokens .size () <= batch_size) throw std::runtime_error (" Batch size is too large." );
186+ }
187+
188+ inline call_result_t operator ()(std::size_t batch_index) noexcept (false ) {
189+ std::size_t const batch_size = results.size ();
190+ std::size_t const forward_token_index = (batch_index * batch_size) % (env.tokens .size () - batch_size);
191+ std::size_t const backward_token_index = env.tokens .size () - forward_token_index - batch_size;
192+
193+ return operator ()({env.tokens .data () + forward_token_index, batch_size},
194+ {env.tokens .data () + backward_token_index, batch_size});
195+ }
196+
197+ inline call_result_t operator ()(std::span<token_view_t const > a, std::span<token_view_t const > b) noexcept (false ) {
198+ sz::status_t status = sz::cuda::levenshtein_distances (a, b, results.data ());
199+ if (status != sz::status_t ::success_k) throw std::runtime_error (cudaGetErrorString (cudaGetLastError ()));
200+ do_not_optimize (results);
201+ std::size_t bytes_passed = 0 , cells_passed = 0 ;
202+ for (std::size_t i = 0 ; i < results.size (); ++i) {
203+ bytes_passed += std::min (a[i].size (), b[i].size ());
204+ cells_passed += a[i].size () * b[i].size ();
205+ }
206+ call_result_t call_result;
207+ call_result.bytes_passed = bytes_passed;
208+ call_result.operations = cells_passed;
209+ call_result.inputs_processed = results.size ();
210+ return call_result;
211+ }
212+ };
213+
214+ #endif
215+
165216void bench_edits (environment_t const &env) {
166217 auto base_call = levenshtein_from_sz<sz_levenshtein_distance_serial>(env);
167218 bench_result_t base = bench_unary (env, " sz_levenshtein_distance_serial" , base_call).log ();
168219 auto base_utf8_call = levenshtein_from_sz<sz_levenshtein_distance_utf8_serial>(env);
169220 bench_result_t base_utf8 = bench_unary (env, " sz_levenshtein_distance_utf8_serial" , base_utf8_call).log (base);
170221 sz_unused (base_utf8);
171222
223+ #if SZ_USE_OPENMP
172224 bench_unary (env, " sz::openmp::levenshtein_distance" , levenshtein_from_sz_openmp (env)).log (base);
225+ #endif
226+ #if SZ_USE_CUDA
227+ bench_unary (env, " sz::cuda::levenshtein_distances(x1024)" , levenshtein_from_sz_cuda (env, 1024 )).log (base);
228+ #endif
173229
174230#if SZ_USE_ICE
175231 auto ice_call = levenshtein_from_sz<sz_levenshtein_distance_ice>(env);
@@ -185,16 +241,22 @@ void bench_edits(environment_t const &env) {
185241int main (int argc, char const **argv) {
186242 std::printf (" Welcome to StringZilla!\n " );
187243
188- std::printf (" Building up the environment...\n " );
189- environment_t env = build_environment ( //
190- argc, argv, //
191- " xlsum.csv" , // Preferred for UTF-8 content
192- environment_t ::tokenization_t ::words_k);
244+ try {
245+ std::printf (" Building up the environment...\n " );
246+ environment_t env = build_environment ( //
247+ argc, argv, //
248+ " xlsum.csv" , // Preferred for UTF-8 content
249+ environment_t ::tokenization_t ::lines_k);
193250
194- std::printf (" Starting string similarity benchmarks...\n " );
195- bench_hamming (env);
196- bench_edits (env);
251+ std::printf (" Starting string similarity benchmarks...\n " );
252+ bench_hamming (env);
253+ bench_edits (env);
254+ }
255+ catch (std::exception const &e) {
256+ std::fprintf (stderr, " Failed with: %s\n " , e.what ());
257+ return 1 ;
258+ }
197259
198- std::printf (" All benchmarks passed .\n " );
260+ std::printf (" All benchmarks finished .\n " );
199261 return 0 ;
200262}
0 commit comments