@@ -1,4 +1,5 @@
 #include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_runtime.h>
 #include <curand_kernel.h>
@@ -180,6 +181,7 @@ auto apply_within_interface(
     std::int64_t batch_size = configs.size(0);
     std::int64_t result_batch_size = result_configs.size(0);
     std::int64_t term_number = site.size(0);
+    at::cuda::CUDAGuard cuda_device_guard(device_id);
 
     TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.");
     TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
@@ -560,6 +562,7 @@ auto find_relative_interface(
     std::int64_t batch_size = configs.size(0);
     std::int64_t term_number = site.size(0);
     std::int64_t exclude_size = exclude_configs.size(0);
+    at::cuda::CUDAGuard cuda_device_guard(device_id);
 
     TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.");
     TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
@@ -779,6 +782,7 @@ auto single_relative_interface(const torch::Tensor& configs, const torch::Tensor
     std::int64_t device_id = configs.device().index();
     std::int64_t batch_size = configs.size(0);
     std::int64_t term_number = site.size(0);
+    at::cuda::CUDAGuard cuda_device_guard(device_id);
 
     TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.");
     TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
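For context on the change: at::cuda::CUDAGuard is an RAII guard that calls cudaSetDevice on construction and restores the previously current device on destruction, so every allocation, stream lookup, and kernel launch in the enclosing scope targets the device the input tensors live on, regardless of which device the calling thread had current before. A minimal sketch of the pattern follows; example_entry_point is a hypothetical stand-in for the real entry points patched above (apply_within_interface, find_relative_interface, single_relative_interface).

#include <ATen/cuda/CUDAGuard.h> // provides at::cuda::CUDAGuard
#include <torch/torch.h>

#include <cstdint>

void example_entry_point(const torch::Tensor& configs) {
    std::int64_t device_id = configs.device().index();
    // The constructor calls cudaSetDevice(device_id); the destructor
    // restores whichever device was current before the guard was created.
    at::cuda::CUDAGuard cuda_device_guard(device_id);

    // From here to the end of the scope, allocations and kernel launches
    // run on configs' device rather than the thread's previous device.
    torch::Tensor result = torch::empty_like(configs);
}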