Skip to content

Commit aa7dd73

Browse files
committed
Add device guards in pytorch cuda kernels.
Closes: USTC-KnowledgeComputingLab/qmb#54 PR: USTC-KnowledgeComputingLab/qmb#60 Signed-off-by: Hao Zhang <[email protected]>
2 parents 73867d5 + 14d7ac3 commit aa7dd73

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

qmb/_hamiltonian_cuda.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <ATen/cuda/Exceptions.h>
2+
#include <c10/cuda/CUDAGuard.h>
23
#include <c10/cuda/CUDAStream.h>
34
#include <cuda_runtime.h>
45
#include <curand_kernel.h>
@@ -180,6 +181,7 @@ auto apply_within_interface(
180181
std::int64_t batch_size = configs.size(0);
181182
std::int64_t result_batch_size = result_configs.size(0);
182183
std::int64_t term_number = site.size(0);
184+
at::cuda::CUDAGuard cuda_device_guard(device_id);
183185

184186
TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.")
185187
TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
@@ -560,6 +562,7 @@ auto find_relative_interface(
560562
std::int64_t batch_size = configs.size(0);
561563
std::int64_t term_number = site.size(0);
562564
std::int64_t exclude_size = exclude_configs.size(0);
565+
at::cuda::CUDAGuard cuda_device_guard(device_id);
563566

564567
TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.")
565568
TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
@@ -779,6 +782,7 @@ auto single_relative_interface(const torch::Tensor& configs, const torch::Tensor
779782
std::int64_t device_id = configs.device().index();
780783
std::int64_t batch_size = configs.size(0);
781784
std::int64_t term_number = site.size(0);
785+
at::cuda::CUDAGuard cuda_device_guard(device_id);
782786

783787
TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.")
784788
TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");

0 commit comments

Comments
 (0)