# Patch summary
# -------------
# Adds an explicit per-thread cleanup hook for the CUDA random-number states.
#
#   * include/ctranslate2/devices.h, src/devices.cc
#       Declares and defines destroy_context(Device). When built with
#       CT2_WITH_CUDA and called with Device::CUDA it calls
#       cuda::free_curand_states(); in every other case it does nothing.
#   * include/ctranslate2/replica_pool.h
#       finalize() now calls destroy_context(_device) right after
#       _replica.reset().
#   * src/cuda/random.cu, src/cuda/random.h
#       The thread_local `states` holder moves from a function-local static in
#       get_curand_states() to file scope so that the new free_curand_states()
#       can reset it. Because the holder is thread_local, free_curand_states()
#       only releases the states owned by the calling thread.
#
# NOTE(review): this copy of the patch had its line breaks and indentation
# collapsed. Line breaks below were rebuilt from the hunk headers' line
# counts; indentation is a best guess and must be checked against the target
# files before applying.
# NOTE(review): some tokens look truncated by the same extraction (the
# template arguments after `std::unique_ptr` / `std::make_unique` /
# `std::vector`, and the header name after `# include` under
# CT2_WITH_TENSOR_PARALLEL). They are reproduced as found -- restore them from
# the upstream sources; TODO confirm.

diff --git a/include/ctranslate2/devices.h b/include/ctranslate2/devices.h
index 674713b8f..a87b02bab 100644
--- a/include/ctranslate2/devices.h
+++ b/include/ctranslate2/devices.h
@@ -26,6 +26,8 @@ namespace ctranslate2 {
   void synchronize_device(Device device, int index);
   void synchronize_stream(Device device);
 
+  void destroy_context(Device device);
+
   class ScopedDeviceSetter {
   public:
     ScopedDeviceSetter(Device device, int index)
diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h
index 8c8e15d8e..b79adf623 100644
--- a/include/ctranslate2/replica_pool.h
+++ b/include/ctranslate2/replica_pool.h
@@ -354,6 +354,8 @@ namespace ctranslate2 {
 
     void finalize() override {
       _replica.reset();
+
+      destroy_context(_device);
     }
 
   private:
diff --git a/src/cuda/random.cu b/src/cuda/random.cu
index f016bb447..a0ff684de 100644
--- a/src/cuda/random.cu
+++ b/src/cuda/random.cu
@@ -48,12 +48,17 @@ namespace ctranslate2 {
       curandState* _states;
     };
 
+    static thread_local std::unique_ptr> states;
+
     curandStatePhilox4_32_10_t* get_curand_states(size_t num_states) {
-      static thread_local std::unique_ptr> states;
       if (!states || num_states > states->num_states())
         states = std::make_unique>(num_states);
       return states->states();
     }
 
+    void free_curand_states() {
+      states.reset();
+    }
+
   }
 }
diff --git a/src/cuda/random.h b/src/cuda/random.h
index e12ae20f6..149ae46c9 100644
--- a/src/cuda/random.h
+++ b/src/cuda/random.h
@@ -6,6 +6,7 @@ namespace ctranslate2 {
   namespace cuda {
 
     curandStatePhilox4_32_10_t* get_curand_states(size_t num_states);
+    void free_curand_states();
 
   }
 }
diff --git a/src/devices.cc b/src/devices.cc
index a2936e0a6..6bb615ea5 100644
--- a/src/devices.cc
+++ b/src/devices.cc
@@ -2,6 +2,7 @@
 
 #ifdef CT2_WITH_CUDA
 # include "cuda/utils.h"
+# include "cuda/random.h"
 #endif
 #ifdef CT2_WITH_TENSOR_PARALLEL
 # include
@@ -118,6 +119,17 @@ namespace ctranslate2 {
     (void)device;
 #endif
   }
+
+  void destroy_context(Device device) {
+#ifdef CT2_WITH_CUDA
+    if (device == Device::CUDA) {
+      cuda::free_curand_states();
+    }
+#else
+    (void)device;
+#endif
+  }
+
   // Initialize the static member variable
 #ifdef CT2_WITH_TENSOR_PARALLEL
   std::vector ScopedMPISetter::_nccl_comms;