Skip to content

Commit 305bbf6

Browse files
Fix: Fix crash in Debug build with multi-GPU due to forced cudaSetDevice(0) (#6498)
Signed-off-by: Tianxiang Wang <[email protected]>, Contributed under MetaX Integrated Circuits (Shanghai) Co., Ltd.
1 parent a7255f2 commit 305bbf6

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

source/source_base/module_device/device.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ std::string get_device_flag(const std::string& device,
6565
/**
6666
* @brief Get the rank of current node
6767
* Note that a GPU can only be bound to a CPU in the same node
68-
*
69-
* @return int
68+
*
69+
* @return int
7070
*/
7171
int get_node_rank();
7272
int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD);
@@ -91,6 +91,14 @@ void record_device_memory(const Device* dev, std::ofstream& ofs_device, std::str
9191
return;
9292
}
9393

94+
#if defined(__CUDA) || defined(__ROCM)
95+
template <>
96+
void print_device_info<base_device::DEVICE_GPU>(const base_device::DEVICE_GPU *ctx, std::ofstream &ofs_device);
97+
98+
template <>
99+
void record_device_memory<base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* dev, std::ofstream& ofs_device, std::string str, size_t size);
100+
#endif
101+
94102
} // end of namespace information
95103
} // end of namespace base_device
96104

source/source_base/module_device/output_device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ void print_device_info<base_device::DEVICE_GPU>(
190190
ofs_device << "Detected " << deviceCount << " CUDA Capable device(s)\n";
191191
}
192192
int dev = 0, driverVersion = 0, runtimeVersion = 0;
193-
cudaErrcheck(cudaSetDevice(dev));
193+
cudaErrcheck(cudaGetDevice(&dev));
194194
cudaDeviceProp deviceProp;
195195
cudaErrcheck(cudaGetDeviceProperties(&deviceProp, dev));
196196
ofs_device << "\nDevice " << dev << ":\t " << deviceProp.name << std::endl;
@@ -429,7 +429,7 @@ void print_device_info<base_device::DEVICE_GPU>(
429429
ofs_device << "Detected " << deviceCount << " CUDA Capable device(s)\n";
430430
}
431431
int dev = 0, driverVersion = 0, runtimeVersion = 0;
432-
hipErrcheck(hipSetDevice(dev));
432+
hipErrcheck(hipGetDevice(&dev));
433433
hipDeviceProp_t deviceProp;
434434
hipErrcheck(hipGetDeviceProperties(&deviceProp, dev));
435435
ofs_device << "\nDevice " << dev << ":\t " << deviceProp.name << std::endl;

0 commit comments

Comments
 (0)