Skip to content

Commit 89bbc5b

Browse files
Get device index from local rank if multi-client, otherwise use the current device. (#6405)
* Fix random generator * Get device index from local rank if multi-client, otherwise use current device. Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
1 parent 6273773 commit 89bbc5b

File tree

3 files changed

+28
-12
lines changed

3 files changed

+28
-12
lines changed

oneflow/core/framework/random_generator.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Maybe<Generator> DefaultCUDAGenerator(int device_index) {
7070
static std::vector<std::once_flag> init_flags(device_count);
7171
static std::vector<std::shared_ptr<Generator>> default_cuda_generator(device_count);
7272

73-
if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }
73+
if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); }
7474
CHECK_OR_RETURN(device_index >= 0 && device_index < device_count)
7575
<< "Invalid device index " << device_index;
7676
std::call_once(init_flags[device_index], [&]() {
@@ -91,7 +91,7 @@ Maybe<Generator> MakeCPUGenerator() {
9191

9292
#ifdef WITH_CUDA
9393
Maybe<Generator> MakeCUDAGenerator(int device_index) {
94-
if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }
94+
if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); }
9595
CHECK_OR_RETURN(device_index >= 0 && device_index < detail::GetCudaDeviceCount())
9696
<< "Invalid device index " << device_index;
9797
return std::make_shared<Generator>(

oneflow/core/framework/random_generator_impl.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,9 @@ int GetThreadNum(const cudaDeviceProp& prop) {
145145
}
146146
}
147147

148-
Maybe<void> CUDASynchronize(int device_index) {
148+
Maybe<void> CUDASynchronize() {
149149
// Synchronize cuda device to avoid state been modified in random kernels.
150150
JUST(CPUSynchronize());
151-
OF_CUDA_CHECK(cudaSetDevice(device_index));
152151
OF_CUDA_CHECK(cudaDeviceSynchronize());
153152
return Maybe<void>::Ok();
154153
}
@@ -161,25 +160,29 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(uint64_t seed, int device_index)
161160
OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_index));
162161
max_block_num_ = prop.multiProcessorCount;
163162
max_thread_num_ = GetThreadNum(prop);
164-
OF_CUDA_CHECK(cudaSetDevice(device_index));
163+
164+
CudaCurrentDeviceGuard dev_guard(device_index);
165165
OF_CUDA_CHECK(
166166
cudaMalloc(&curand_states_, max_block_num_ * max_thread_num_ * sizeof(curandState)));
167167
detail::InitCurandStates(seed, max_block_num_, max_thread_num_, curand_states_);
168168
}
169169

170170
CUDAGeneratorImpl::~CUDAGeneratorImpl() {
171-
CHECK_JUST(CUDASynchronize(this->device_index()));
171+
CudaCurrentDeviceGuard dev_guard(this->device_index());
172+
CHECK_JUST(CUDASynchronize());
172173
OF_CUDA_CHECK(cudaFree(curand_states_));
173174
}
174175

175176
void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
176-
CHECK_JUST(CUDASynchronize(this->device_index()));
177+
CudaCurrentDeviceGuard dev_guard(this->device_index());
178+
CHECK_JUST(CUDASynchronize());
177179
seed_ = seed;
178180
detail::InitCurandStates(seed_, max_block_num_, max_thread_num_, curand_states_);
179181
}
180182

181183
Maybe<Tensor> CUDAGeneratorImpl::GetState() const {
182-
JUST(CUDASynchronize(this->device_index()));
184+
CudaCurrentDeviceGuard dev_guard(this->device_index());
185+
JUST(CUDASynchronize());
183186
int64_t state_size = max_block_num_ * max_thread_num_ * sizeof(curandState);
184187
int64_t total_size = state_size + sizeof(int64_t);
185188
const auto& device = JUST(Device::New("cpu"));
@@ -207,7 +210,8 @@ Maybe<void> CUDAGeneratorImpl::SetState(const std::shared_ptr<Tensor>& tensor_st
207210
<< total_size << ", but got " << tensor_state->shape()->elem_cnt();
208211
}
209212

210-
JUST(CUDASynchronize(this->device_index()));
213+
CudaCurrentDeviceGuard dev_guard(this->device_index());
214+
JUST(CUDASynchronize());
211215
const auto& callback = std::make_shared<std::function<void(uint64_t)>>([&](uint64_t of_blob_ptr) {
212216
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
213217
const int8_t* data = of_blob->blob().dptr<int8_t>();
@@ -398,16 +402,27 @@ Maybe<CPUGeneratorImpl> MakeGeneratorImpl<CPUGeneratorImpl>(uint64_t seed, int d
398402
}
399403

400404
#ifdef WITH_CUDA
405+
406+
int GetCudaDeviceIndex() {
407+
int cuda_device_index = 0;
408+
if (CHECK_JUST(GlobalMultiClientEnv())) {
409+
cuda_device_index = GlobalProcessCtx::LocalRank();
410+
} else {
411+
OF_CUDA_CHECK(cudaGetDevice(&cuda_device_index));
412+
}
413+
return cuda_device_index;
414+
}
415+
401416
int GetCudaDeviceCount() {
402-
/* static */ int cuda_device_count;
403-
OF_CUDA_CHECK(cudaSetDevice(GlobalProcessCtx::LocalRank()));
417+
/* static */ int cuda_device_count = 0;
418+
CudaCurrentDeviceGuard dev_guard(detail::GetCudaDeviceIndex());
404419
OF_CUDA_CHECK(cudaGetDeviceCount(&cuda_device_count));
405420
return cuda_device_count;
406421
}
407422

408423
template<>
409424
DeviceKey MakeDeviceKey<CUDAGeneratorImpl>(int device_index) {
410-
if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }
425+
if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); }
411426
DeviceKey device_key;
412427
device_key.device_type = DeviceType::kGPU;
413428
device_key.device_index = device_index;

oneflow/core/framework/random_generator_impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class CUDAGeneratorImpl : public DeviceGeneratorImpl {
137137

138138
namespace detail {
139139

140+
int GetCudaDeviceIndex();
140141
int GetCudaDeviceCount();
141142

142143
void InitCurandStates(uint64_t seed, int32_t block_num, int32_t thread_num, curandState* states);

0 commit comments

Comments
 (0)