Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions source/source_base/module_device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ std::string get_device_flag(const std::string& device,
/**
* @brief Get the rank of current node
* Note that GPU can only be binded with CPU in the same node
*
* @return int
*
* @return int
*/
int get_node_rank();
int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD);
Expand All @@ -91,6 +91,14 @@ void record_device_memory(const Device* dev, std::ofstream& ofs_device, std::str
return;
}

#if defined(__CUDA) || defined(__ROCM)
template <>
void print_device_info<base_device::DEVICE_GPU>(const base_device::DEVICE_GPU *ctx, std::ofstream &ofs_device);

template <>
void record_device_memory<base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* dev, std::ofstream& ofs_device, std::string str, size_t size);
#endif

} // end of namespace information
} // end of namespace base_device

Expand Down
4 changes: 2 additions & 2 deletions source/source_base/module_device/output_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ void print_device_info<base_device::DEVICE_GPU>(
ofs_device << "Detected " << deviceCount << " CUDA Capable device(s)\n";
}
int dev = 0, driverVersion = 0, runtimeVersion = 0;
cudaErrcheck(cudaSetDevice(dev));
cudaErrcheck(cudaGetDevice(&dev));
cudaDeviceProp deviceProp;
cudaErrcheck(cudaGetDeviceProperties(&deviceProp, dev));
ofs_device << "\nDevice " << dev << ":\t " << deviceProp.name << std::endl;
Expand Down Expand Up @@ -429,7 +429,7 @@ void print_device_info<base_device::DEVICE_GPU>(
ofs_device << "Detected " << deviceCount << " CUDA Capable device(s)\n";
}
int dev = 0, driverVersion = 0, runtimeVersion = 0;
hipErrcheck(hipSetDevice(dev));
hipErrcheck(hipGetDevice(&dev));
hipDeviceProp_t deviceProp;
hipErrcheck(hipGetDeviceProperties(&deviceProp, dev));
ofs_device << "\nDevice " << dev << ":\t " << deviceProp.name << std::endl;
Expand Down
Loading