@@ -145,10 +145,9 @@ int GetThreadNum(const cudaDeviceProp& prop) {
145145 }
146146}
147147
148- Maybe<void > CUDASynchronize (int device_index ) {
148+ Maybe<void > CUDASynchronize () {
149149 // Synchronize cuda device to avoid state been modified in random kernels.
150150 JUST (CPUSynchronize ());
151- OF_CUDA_CHECK (cudaSetDevice (device_index));
152151 OF_CUDA_CHECK (cudaDeviceSynchronize ());
153152 return Maybe<void >::Ok ();
154153}
@@ -161,25 +160,29 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(uint64_t seed, int device_index)
161160 OF_CUDA_CHECK (cudaGetDeviceProperties (&prop, device_index));
162161 max_block_num_ = prop.multiProcessorCount ;
163162 max_thread_num_ = GetThreadNum (prop);
164- OF_CUDA_CHECK (cudaSetDevice (device_index));
163+
164+ CudaCurrentDeviceGuard dev_guard (device_index);
165165 OF_CUDA_CHECK (
166166 cudaMalloc (&curand_states_, max_block_num_ * max_thread_num_ * sizeof (curandState)));
167167 detail::InitCurandStates (seed, max_block_num_, max_thread_num_, curand_states_);
168168}
169169
170170CUDAGeneratorImpl::~CUDAGeneratorImpl () {
171- CHECK_JUST (CUDASynchronize (this ->device_index ()));
171+ CudaCurrentDeviceGuard dev_guard (this ->device_index ());
172+ CHECK_JUST (CUDASynchronize ());
172173 OF_CUDA_CHECK (cudaFree (curand_states_));
173174}
174175
175176void CUDAGeneratorImpl::set_current_seed (uint64_t seed) {
176- CHECK_JUST (CUDASynchronize (this ->device_index ()));
177+ CudaCurrentDeviceGuard dev_guard (this ->device_index ());
178+ CHECK_JUST (CUDASynchronize ());
177179 seed_ = seed;
178180 detail::InitCurandStates (seed_, max_block_num_, max_thread_num_, curand_states_);
179181}
180182
181183Maybe<Tensor> CUDAGeneratorImpl::GetState () const {
182- JUST (CUDASynchronize (this ->device_index ()));
184+ CudaCurrentDeviceGuard dev_guard (this ->device_index ());
185+ JUST (CUDASynchronize ());
183186 int64_t state_size = max_block_num_ * max_thread_num_ * sizeof (curandState);
184187 int64_t total_size = state_size + sizeof (int64_t );
185188 const auto & device = JUST (Device::New (" cpu" ));
@@ -207,7 +210,8 @@ Maybe<void> CUDAGeneratorImpl::SetState(const std::shared_ptr<Tensor>& tensor_st
207210 << total_size << " , but got " << tensor_state->shape ()->elem_cnt ();
208211 }
209212
210- JUST (CUDASynchronize (this ->device_index ()));
213+ CudaCurrentDeviceGuard dev_guard (this ->device_index ());
214+ JUST (CUDASynchronize ());
211215 const auto & callback = std::make_shared<std::function<void (uint64_t )>>([&](uint64_t of_blob_ptr) {
212216 auto * of_blob = reinterpret_cast <OfBlob*>(of_blob_ptr);
213217 const int8_t * data = of_blob->blob ().dptr <int8_t >();
@@ -398,16 +402,27 @@ Maybe<CPUGeneratorImpl> MakeGeneratorImpl<CPUGeneratorImpl>(uint64_t seed, int d
398402}
399403
400404#ifdef WITH_CUDA
405+
406+ int GetCudaDeviceIndex () {
407+ int cuda_device_index = 0 ;
408+ if (CHECK_JUST (GlobalMultiClientEnv ())) {
409+ cuda_device_index = GlobalProcessCtx::LocalRank ();
410+ } else {
411+ OF_CUDA_CHECK (cudaGetDevice (&cuda_device_index));
412+ }
413+ return cuda_device_index;
414+ }
415+
401416int GetCudaDeviceCount () {
402- /* static */ int cuda_device_count;
403- OF_CUDA_CHECK ( cudaSetDevice ( GlobalProcessCtx::LocalRank () ));
417+ /* static */ int cuda_device_count = 0 ;
418+ CudaCurrentDeviceGuard dev_guard ( detail::GetCudaDeviceIndex ( ));
404419 OF_CUDA_CHECK (cudaGetDeviceCount (&cuda_device_count));
405420 return cuda_device_count;
406421}
407422
408423template <>
409424DeviceKey MakeDeviceKey<CUDAGeneratorImpl>(int device_index) {
410- if (device_index == -1 ) { device_index = GlobalProcessCtx::LocalRank (); }
425+ if (device_index == -1 ) { device_index = detail::GetCudaDeviceIndex (); }
411426 DeviceKey device_key;
412427 device_key.device_type = DeviceType::kGPU ;
413428 device_key.device_index = device_index;
0 commit comments