diff --git a/source/source_base/module_device/device.h b/source/source_base/module_device/device.h index ed560affdf..a073bdab91 100644 --- a/source/source_base/module_device/device.h +++ b/source/source_base/module_device/device.h @@ -65,8 +65,8 @@ std::string get_device_flag(const std::string& device, /** * @brief Get the rank of current node * Note that GPU can only be binded with CPU in the same node - * - * @return int + * + * @return int */ int get_node_rank(); int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD); @@ -91,6 +91,14 @@ void record_device_memory(const Device* dev, std::ofstream& ofs_device, std::str return; } +#if defined(__CUDA) || defined(__ROCM) +template <> +void print_device_info(const base_device::DEVICE_GPU *ctx, std::ofstream &ofs_device); + +template <> +void record_device_memory(const base_device::DEVICE_GPU* dev, std::ofstream& ofs_device, std::string str, size_t size); +#endif + } // end of namespace information } // end of namespace base_device diff --git a/source/source_base/module_device/output_device.cpp b/source/source_base/module_device/output_device.cpp index a0cf817844..1d0f018814 100644 --- a/source/source_base/module_device/output_device.cpp +++ b/source/source_base/module_device/output_device.cpp @@ -190,7 +190,7 @@ void print_device_info( ofs_device << "Detected " << deviceCount << " CUDA Capable device(s)\n"; } int dev = 0, driverVersion = 0, runtimeVersion = 0; - cudaErrcheck(cudaSetDevice(dev)); + cudaErrcheck(cudaGetDevice(&dev)); cudaDeviceProp deviceProp; cudaErrcheck(cudaGetDeviceProperties(&deviceProp, dev)); ofs_device << "\nDevice " << dev << ":\t " << deviceProp.name << std::endl; @@ -429,7 +429,7 @@ void print_device_info( ofs_device << "Detected " << deviceCount << " CUDA Capable device(s)\n"; } int dev = 0, driverVersion = 0, runtimeVersion = 0; - hipErrcheck(hipSetDevice(dev)); + hipErrcheck(hipGetDevice(&dev)); hipDeviceProp_t deviceProp; hipErrcheck(hipGetDeviceProperties(&deviceProp, dev)); ofs_device << "\nDevice " << dev << ":\t " << deviceProp.name << std::endl;