|
11 | 11 | #ifndef STRINGZILLA_TYPES_CUH_ |
12 | 12 | #define STRINGZILLA_TYPES_CUH_ |
13 | 13 |
|
14 | | -#include <cuda_runtime.h> // `cudaMallocManaged`, `cudaFree`, `cudaSuccess`, `cudaGetErrorString` |
15 | | - |
16 | 14 | #include "stringzilla/types.hpp" |
17 | 15 |
|
| 16 | +#include <cuda_runtime.h> // `cudaMallocManaged`, `cudaFree`, `cudaSuccess`, `cudaGetErrorString` |
| 17 | +#include <optional> // `std::optional` |
| 18 | + |
18 | 19 | #if !defined(SZ_USE_HOPPER) |
19 | 20 | #if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ < 11) |
20 | 21 | #define SZ_USE_HOPPER (1) |
@@ -78,6 +79,35 @@ struct unified_alloc { |
78 | 79 | } |
79 | 80 | }; |
80 | 81 |
|
| 82 | +inline std::optional<gpu_specs_t> gpu_specs(int device = 0) noexcept { |
| 83 | + gpu_specs_t specs; |
| 84 | + cudaDeviceProp prop; |
| 85 | + cudaError_t cuda_error = cudaGetDeviceProperties(&prop, device); |
| 86 | + if (cuda_error != cudaSuccess) return std::nullopt; // ! Failed to get device properties |
| 87 | + |
| 88 | + // Set the GPU specs |
| 89 | + specs.streaming_multiprocessors = prop.multiProcessorCount; |
| 90 | + specs.constant_memory_bytes = prop.totalConstMem; |
| 91 | + specs.vram_bytes = prop.totalGlobalMem; |
| 92 | + |
| 93 | + // Infer other global settings, that CUDA doesn't expose directly |
| 94 | + specs.shared_memory_bytes = prop.sharedMemPerMultiprocessor * prop.multiProcessorCount; |
| 95 | + specs.cuda_cores = gpu_specs_t::cores_per_multiprocessor(prop.major, prop.minor) * specs.streaming_multiprocessors; |
| 96 | + |
| 97 | + // Scheduling-related constants |
| 98 | + specs.max_blocks_per_multiprocessor = prop.maxBlocksPerMultiProcessor; |
| 99 | + specs.reserved_memory_per_block = prop.reservedSharedMemPerBlock; |
| 100 | + return specs; |
| 101 | +} |
| 102 | + |
| 103 | +struct cuda_status_t { |
| 104 | + status_t status = status_t::success_k; |
| 105 | + cudaError_t cuda_error = cudaSuccess; |
| 106 | + float elapsed_milliseconds = 0.0; |
| 107 | + |
| 108 | + inline operator status_t() const noexcept { return status; } |
| 109 | +}; |
| 110 | + |
81 | 111 | } // namespace stringzilla |
82 | 112 | } // namespace ashvardanian |
83 | 113 |
|
|
0 commit comments