Commit 6431900

Add: Ephemeral GPU executors if no device is passed
1 parent 24967c7 commit 6431900

4 files changed: +106 -25 lines changed

c/stringzillas.cuh

Lines changed: 86 additions & 1 deletion
@@ -218,6 +218,28 @@ struct gpu_scope_t {
 };
 szs::cuda_executor_t &get_executor(gpu_scope_t &scope) noexcept { return scope.executor; }
 sz::gpu_specs_t get_specs(gpu_scope_t const &scope) noexcept { return scope.specs; }
+
+/** Cached default GPU context (device 0) to avoid repeated scheduling boilerplate */
+struct default_gpu_context_t {
+    sz::status_t status = sz::status_t::unknown_k;
+    szs::cuda_executor_t executor;
+    sz::gpu_specs_t specs;
+};
+
+inline default_gpu_context_t &default_gpu_context() {
+    static default_gpu_context_t ctx = [] {
+        default_gpu_context_t result;
+        auto specs_status = szs::gpu_specs_fetch(result.specs, 0);
+        if (specs_status.status != sz::status_t::success_k) {
+            result.status = specs_status.status;
+            return result;
+        }
+        auto exec_status = result.executor.try_scheduling(0);
+        result.status = exec_status.status;
+        return result;
+    }();
+    return ctx;
+}
 #endif

 struct device_scope_t {
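
The accessor above leans on two standard C++ idioms: a function-local `static` (initialized exactly once and thread-safely since C++11) and an immediately-invoked lambda, which lets the one-time setup use early returns. A minimal self-contained sketch of the same pattern, with placeholder names instead of the library's types:

    #include <cstdio>

    struct context_t {
        bool ready = false; // did the one-time setup succeed?
        int device = -1;    // which device was probed
    };

    // Function-local static + immediately-invoked lambda: the initializer runs
    // once, on first call, under the compiler-generated thread-safe guard, and
    // every later call returns the cached result, success or failure alike.
    inline context_t &cached_context() {
        static context_t ctx = [] {
            context_t result;
            result.device = 0;   // stand-in for probing device 0
            result.ready = true; // stand-in for a successful scheduling call
            return result;
        }();
        return ctx;
    }

    int main() {
        std::printf("ready=%d device=%d\n", cached_context().ready, cached_context().device);
        return 0;
    }

Note that a failed initialization is also cached: if device 0 cannot be scheduled on the first call, every subsequent default-scope call reports the same stored status instead of retrying.
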
@@ -293,6 +315,16 @@ sz_status_t szs_levenshtein_distances_for_(
             get_executor(device_scope), get_specs(device_scope));
         result = static_cast<sz_status_t>(status);
     }
+    // Try ephemeral GPU on default scope (device 0)
+    else if (std::holds_alternative<default_scope_t>(device->variants)) {
+        auto &ctx = default_gpu_context();
+        if (ctx.status != sz::status_t::success_k) { result = static_cast<sz_status_t>(ctx.status); }
+        else {
+            sz::status_t status = engine_variant( //
+                a_container, b_container, results_strided, ctx.executor, ctx.specs);
+            result = static_cast<sz_status_t>(status);
+        }
+    }
     else { result = sz_device_code_mismatch_k; }
 #else
     result = sz_status_unknown_k; // GPU support is not enabled
@@ -452,6 +484,15 @@ sz_status_t szs_needleman_wunsch_scores_for_(
             get_executor(device_scope), get_specs(device_scope));
         result = static_cast<sz_status_t>(status);
     }
+    else if (std::holds_alternative<default_scope_t>(device->variants)) {
+        auto &ctx = default_gpu_context();
+        if (ctx.status != sz::status_t::success_k) { result = static_cast<sz_status_t>(ctx.status); }
+        else {
+            sz::status_t status = engine_variant( //
+                a_container, b_container, results_strided, ctx.executor, ctx.specs);
+            result = static_cast<sz_status_t>(status);
+        }
+    }
     else { result = sz_status_unknown_k; }
 #else
     result = sz_status_unknown_k; // GPU support is not enabled
@@ -540,6 +581,25 @@ sz_status_t szs_smith_waterman_scores_for_(
             get_executor(device_scope), get_specs(device_scope));
         result = static_cast<sz_status_t>(status);
     }
+    else if (std::holds_alternative<default_scope_t>(device->variants)) {
+        sz::gpu_specs_t specs;
+        auto specs_status = szs::gpu_specs_fetch(specs, 0);
+        if (specs_status.status != sz::status_t::success_k) {
+            result = static_cast<sz_status_t>(specs_status.status);
+        }
+        else {
+            szs::cuda_executor_t executor;
+            auto exec_status = executor.try_scheduling(0);
+            if (exec_status.status != sz::status_t::success_k) {
+                result = static_cast<sz_status_t>(exec_status.status);
+            }
+            else {
+                sz::status_t status = engine_variant( //
+                    a_container, b_container, results_strided, executor, specs);
+                result = static_cast<sz_status_t>(status);
+            }
+        }
+    }
     else { result = sz_status_unknown_k; }
 #else
     result = sz_status_unknown_k; // GPU support is not enabled
@@ -659,14 +719,23 @@ sz_status_t szs_fingerprints_for_( //
     auto const min_counts_rows = //
         strided_rows<sz_u32_t> {reinterpret_cast<sz_ptr_t>(min_counts), dims, min_counts_stride, texts_count};

-    // CPU fallback hashers can only work with CPU-compatible device scopes
+    // GPU fallback hashers can work with GPU scope, or default scope via an ephemeral GPU executor
     if (std::holds_alternative<gpu_scope_t>(device->variants)) {
         auto &device_scope = std::get<gpu_scope_t>(device->variants);
         sz::status_t status = fallback_hashers( //
             texts_container, min_hashes_rows, min_counts_rows, //
             get_executor(device_scope), get_specs(device_scope));
         result = static_cast<sz_status_t>(status);
     }
+    else if (std::holds_alternative<default_scope_t>(device->variants)) {
+        auto &ctx = default_gpu_context();
+        if (ctx.status != sz::status_t::success_k) { result = static_cast<sz_status_t>(ctx.status); }
+        else {
+            sz::status_t status = fallback_hashers( //
+                texts_container, min_hashes_rows, min_counts_rows, ctx.executor, ctx.specs);
+            result = static_cast<sz_status_t>(status);
+        }
+    }
     else { result = sz_status_unknown_k; }
 };
 #endif // SZ_USE_CUDA
@@ -704,6 +773,22 @@ sz_status_t szs_fingerprints_for_( //
             if (result != sz_success_k) break;
         }
     }
+    else if (std::holds_alternative<default_scope_t>(device->variants)) {
+        auto &ctx = default_gpu_context();
+        if (ctx.status != sz::status_t::success_k) { result = static_cast<sz_status_t>(ctx.status); }
+        else {
+            for (std::size_t i = 0; i < unrolled_hashers.size(); ++i) {
+                auto &engine_variant = unrolled_hashers[i];
+                sz::status_t status = engine_variant( //
+                    texts_container, //
+                    min_hashes_rows.template shifted<fingerprint_slice_k>(i * bytes_per_slice_k), //
+                    min_counts_rows.template shifted<fingerprint_slice_k>(i * bytes_per_slice_k), //
+                    ctx.executor, ctx.specs);
+                result = static_cast<sz_status_t>(status);
+                if (result != sz_success_k) break;
+            }
+        }
+    }
     else { result = sz_status_unknown_k; }
 #else
     result = sz_status_unknown_k; // GPU support is not enabled
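
Each of these entry points dispatches on the `std::variant` stored in `device->variants`: an explicit `gpu_scope_t` wins, the new `default_scope_t` branch falls back to the cached device-0 context, and anything else yields an error code. A reduced sketch of that control flow using only standard-library types (the scope names and return values here are illustrative, not the library's):

    #include <cstdio>
    #include <variant>

    struct cpu_scope_t {};                  // engine pinned to CPU threads
    struct gpu_scope_t { int device = 0; }; // caller picked a concrete GPU
    struct default_scope_t {};              // caller passed no device at all

    using scope_t = std::variant<cpu_scope_t, gpu_scope_t, default_scope_t>;

    // Mirrors the if / else-if / else chain in the wrappers above.
    int pick_device(scope_t const &scope) {
        if (std::holds_alternative<gpu_scope_t>(scope))
            return std::get<gpu_scope_t>(scope).device; // use the caller's GPU
        else if (std::holds_alternative<default_scope_t>(scope))
            return 0;  // fall back to an ephemeral executor on device 0
        else
            return -1; // device/engine mismatch
    }

    int main() {
        std::printf("%d %d %d\n", pick_device(gpu_scope_t {1}), pick_device(default_scope_t {}),
                    pick_device(cpu_scope_t {}));
        return 0;
    }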

python/stringzillas.c

Lines changed: 11 additions & 17 deletions
@@ -135,14 +135,8 @@ static inline sz_bool_t try_swap_to_unified_allocator(PyObject *strs_obj) {
  * @brief Helper function to determine if unified memory is required based on capabilities and device scope.
  * @param[in] capabilities The capabilities bitmask of the current engine.
  */
-static inline sz_bool_t requires_unified_memory(sz_capability_t capabilities, szs_device_scope_t device_handle) {
-    // Only relevant if CUDA capability is enabled
-    if ((capabilities & sz_cap_cuda_k) == 0) return sz_false_k;
-
-    // Check that the executor is a GPU device scope
-    sz_size_t gpu_device = 0;
-    if (szs_device_scope_get_gpu_device(device_handle, &gpu_device) == sz_success_k) return sz_true_k;
-    return sz_false_k;
+static inline sz_bool_t requires_unified_memory(sz_capability_t capabilities) {
+    return (capabilities & sz_cap_cuda_k) != 0;
 }

 #pragma endregion
@@ -452,8 +446,8 @@ static PyObject *LevenshteinDistances_call(LevenshteinDistances *self, PyObject
     sz_status_t (*kernel_punned)(szs_levenshtein_distances_t, szs_device_scope_t, void *, void *, sz_size_t *,
                                  sz_size_t) = NULL;

-    // Swap allocators only when using CUDA with a GPU device
-    if (requires_unified_memory(self->capabilities, device_handle))
+    // Swap allocators only when using CUDA with a GPU device (inputs must be unified)
+    if (requires_unified_memory(self->capabilities))
         if (!try_swap_to_unified_allocator(a_obj) || !try_swap_to_unified_allocator(b_obj)) return NULL;

     // Handle 32-bit tape inputs
@@ -748,8 +742,8 @@ static PyObject *LevenshteinDistancesUTF8_call(LevenshteinDistancesUTF8 *self, P
     sz_status_t (*kernel_punned)(szs_levenshtein_distances_t, szs_device_scope_t, void *, void *, sz_size_t *,
                                  sz_size_t) = NULL;

-    // Swap allocators only when using CUDA with a GPU device
-    if (requires_unified_memory(self->capabilities, device_handle))
+    // Swap allocators when engine supports CUDA
+    if (requires_unified_memory(self->capabilities))
         if (!try_swap_to_unified_allocator(a_obj) || !try_swap_to_unified_allocator(b_obj)) return NULL;

     // Handle 32-bit tape inputs
@@ -1079,8 +1073,8 @@ static PyObject *NeedlemanWunsch_call(NeedlemanWunsch *self, PyObject *args, PyO
     sz_status_t (*kernel_punned)(szs_needleman_wunsch_scores_t, szs_device_scope_t, void const *, void const *,
                                  sz_ssize_t *, sz_size_t) = NULL;

-    // Swap allocators only when using CUDA with a GPU device
-    if (requires_unified_memory(self->capabilities, device_handle))
+    // Swap allocators only when using CUDA with a GPU device (inputs must be unified)
+    if (requires_unified_memory(self->capabilities))
         if (!try_swap_to_unified_allocator(a_obj) || !try_swap_to_unified_allocator(b_obj)) return NULL;

     // Handle 32-bit tape inputs
@@ -1393,8 +1387,8 @@ static PyObject *SmithWaterman_call(SmithWaterman *self, PyObject *args, PyObjec
     sz_status_t (*kernel_punned)(szs_smith_waterman_scores_t, szs_device_scope_t, void const *, void const *,
                                  sz_ssize_t *, sz_size_t) = NULL;

-    // Swap allocators only when using CUDA with a GPU device
-    if (requires_unified_memory(self->capabilities, device_handle))
+    // Swap allocators only when using CUDA with a GPU device (inputs must be unified)
+    if (requires_unified_memory(self->capabilities))
         if (!try_swap_to_unified_allocator(a_obj) || !try_swap_to_unified_allocator(b_obj)) return NULL;

     // Handle 32-bit tape inputs
@@ -1744,7 +1738,7 @@ static PyObject *Fingerprints_call(Fingerprints *self, PyObject *args, PyObject
     }

     // Swap allocators only when using CUDA with a GPU device (inputs must be unified)
-    sz_bool_t need_unified = requires_unified_memory(self->capabilities, device_handle);
+    sz_bool_t need_unified = requires_unified_memory(self->capabilities);
     if (need_unified)
         if (!try_swap_to_unified_allocator(texts_obj)) return NULL;

scripts/test_stringzilla.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@

     uv pip install numpy pyarrow pytest pytest-repeat
     uv pip install -e . --force-reinstall --no-build-isolation
+    uv run --no-project python -m pytest scripts/test_stringzilla.py -s -x

 Recommended flags for better diagnostics:

scripts/test_stringzillas.py

Lines changed: 8 additions & 7 deletions
@@ -7,6 +7,14 @@
     uv pip install numpy pyarrow pytest pytest-repeat affine-gaps
     SZ_TARGET=stringzillas-cpus uv pip install -e . --force-reinstall --no-build-isolation
     uv run --no-project python -c "import stringzillas; print(stringzillas.__capabilities__)"
+    uv run --no-project python -m pytest scripts/test_stringzillas.py -s -x
+
+To run for the CUDA backend:
+
+    uv pip install numpy pyarrow pytest pytest-repeat affine-gaps
+    SZ_TARGET=stringzillas-cuda uv pip install -e . --force-reinstall --no-build-isolation
+    uv run --no-project python -c "import stringzillas; print(stringzillas.__capabilities__)"
+    uv run --no-project python -m pytest scripts/test_stringzillas.py -s -x

 Recommended flags for better diagnostics:
@@ -20,13 +28,6 @@
 Example:

     uv run --no-project python -X faulthandler -m pytest scripts/test_stringzillas.py -s -vv --maxfail=1 --full-trace
-
-To run for the CUDA backend:
-
-    uv pip install numpy pyarrow pytest pytest-repeat affine-gaps
-    SZ_TARGET=stringzillas-cuda uv pip install -e . --force-reinstall --no-build-isolation
-    uv run --no-project python -c "import stringzillas; print(stringzillas.__capabilities__)"
-    uv run --no-project python -m pytest scripts/test_stringzillas.py -s -x
 """

 from random import choice, randint, seed
