
Commit de9a851

Add diagonal_term in kernel.

1 parent 2447f18

2 files changed: +130 −0 lines changed

qmb/_hamiltonian.cpp

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ TORCH_LIBRARY_FRAGMENT(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), m) {
     m.def("apply_within(Tensor configs_i, Tensor psi_i, Tensor configs_j, Tensor site, Tensor kind, Tensor coef) -> Tensor");
     m.def("find_relative(Tensor configs_i, Tensor psi_i, int count_selected, Tensor site, Tensor kind, Tensor coef, Tensor configs_exclude) -> Tensor"
     );
+    m.def("diagonal_term(Tensor configs, Tensor site, Tensor kind, Tensor coef) -> Tensor");
     m.def("single_relative(Tensor configs, Tensor site, Tensor kind, Tensor coef) -> Tensor");
 }
 #undef QMB_LIBRARY
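
Aside: once this schema is registered, the operator is reachable through the PyTorch dispatcher like any custom op. A minimal call-site sketch; "qmb_2_0" is a hypothetical stand-in for whatever namespace QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT) actually expands to, since the expansion is not part of this diff:

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>

// Hypothetical call site; "qmb_2_0" stands in for the real macro expansion.
torch::Tensor call_diagonal_term(
    const torch::Tensor& configs,
    const torch::Tensor& site,
    const torch::Tensor& kind,
    const torch::Tensor& coef
) {
    // Look up the registered schema once and cache the typed handle.
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("qmb_2_0::diagonal_term", "")
        .typed<torch::Tensor(const torch::Tensor&, const torch::Tensor&, const torch::Tensor&, const torch::Tensor&)>();
    return op.call(configs, site, kind, coef);
}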

qmb/_hamiltonian_cuda.cu

Lines changed: 129 additions & 0 deletions
@@ -674,6 +674,134 @@ auto find_relative_interface(
     return unique_nonzero_result_config;
 }

+template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+__device__ void diagonal_term_kernel(
+    std::int64_t term_index,
+    std::int64_t batch_index,
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* configs, // batch_size
+    std::array<double, 2>* result_psi // batch_size
+) {
+    std::array<std::uint8_t, n_qubytes> current_configs = configs[batch_index];
+    auto [success, parity] = hamiltonian_apply_kernel<max_op_number, n_qubytes, particle_cut>(
+        /*current_configs=*/current_configs,
+        /*term_index=*/term_index,
+        /*batch_index=*/batch_index,
+        /*site=*/site,
+        /*kind=*/kind
+    );
+
+    if (!success) {
+        return;
+    }
+    auto less = array_less<std::uint8_t, n_qubytes>();
+    if (less(current_configs, configs[batch_index]) || less(configs[batch_index], current_configs)) {
+        return; // Off-diagonal: the term maps this configuration to a different one, so it contributes nothing to the diagonal.
+    }
+    std::int8_t sign = parity ? -1 : +1;
+    atomicAdd(&result_psi[batch_index][0], sign * coef[term_index][0]);
+    atomicAdd(&result_psi[batch_index][1], sign * coef[term_index][1]);
+}
+
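Note: each (term, batch) thread accumulates one diagonal matrix element, so after the launch the output holds, per configuration $c_b$ and Hamiltonian terms $h_t$ (a sketch of the intended quantity, not text from the diff):

$$\psi_b \;=\; \sum_t \langle c_b | h_t | c_b \rangle \;=\; \sum_{t\,:\, h_t|c_b\rangle = \pm|c_b\rangle} (\pm 1)\,\mathrm{coef}_t,$$

with the real and imaginary parts of $\mathrm{coef}_t$ accumulated separately by the two atomicAdd calls.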
+template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+__global__ void diagonal_term_kernel_interface(
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* configs, // batch_size
+    std::array<double, 2>* result_psi // batch_size
+) {
+    std::int64_t term_index = blockIdx.x * blockDim.x + threadIdx.x;
+    std::int64_t batch_index = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (term_index < term_number && batch_index < batch_size) {
+        diagonal_term_kernel<max_op_number, n_qubytes, particle_cut>(
+            /*term_index=*/term_index,
+            /*batch_index=*/batch_index,
+            /*term_number=*/term_number,
+            /*batch_size=*/batch_size,
+            /*site=*/site,
+            /*kind=*/kind,
+            /*coef=*/coef,
+            /*configs=*/configs,
+            /*result_psi=*/result_psi
+        );
+    }
+}
+
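For intuition, the 2D launch above parallelizes the following double loop over (term, batch) pairs; this serial form is illustrative only (diagonal_term_kernel is a __device__ function and the pointers are device pointers, so this is not runnable host code from the commit):

for (std::int64_t t = 0; t < term_number; ++t) {
    for (std::int64_t b = 0; b < batch_size; ++b) {
        // One diagonal contribution per (term, configuration) pair.
        diagonal_term_kernel<max_op_number, n_qubytes, particle_cut>(t, b, term_number, batch_size, site, kind, coef, configs, result_psi);
    }
}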
+template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+auto diagonal_term_interface(const torch::Tensor& configs, const torch::Tensor& site, const torch::Tensor& kind, const torch::Tensor& coef)
+    -> torch::Tensor {
+    std::int64_t device_id = configs.device().index();
+    std::int64_t batch_size = configs.size(0);
+    std::int64_t term_number = site.size(0);
+    at::cuda::CUDAGuard cuda_device_guard(device_id);
+
+    TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.");
+    TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
+    TORCH_CHECK(configs.is_contiguous(), "configs must be contiguous.");
+    TORCH_CHECK(configs.dtype() == torch::kUInt8, "configs must be uint8.");
+    TORCH_CHECK(configs.dim() == 2, "configs must be 2D.");
+    TORCH_CHECK(configs.size(1) == n_qubytes, "configs must have n_qubytes bytes per configuration.");
+
+    TORCH_CHECK(site.device().type() == torch::kCUDA, "site must be on CUDA.");
+    TORCH_CHECK(site.device().index() == device_id, "site must be on the same device as others.");
+    TORCH_CHECK(site.is_contiguous(), "site must be contiguous.");
+    TORCH_CHECK(site.dtype() == torch::kInt16, "site must be int16.");
+    TORCH_CHECK(site.dim() == 2, "site must be 2D.");
+    TORCH_CHECK(site.size(1) == max_op_number, "site must have max_op_number columns.");
+
+    TORCH_CHECK(kind.device().type() == torch::kCUDA, "kind must be on CUDA.");
+    TORCH_CHECK(kind.device().index() == device_id, "kind must be on the same device as others.");
+    TORCH_CHECK(kind.is_contiguous(), "kind must be contiguous.");
+    TORCH_CHECK(kind.dtype() == torch::kUInt8, "kind must be uint8.");
+    TORCH_CHECK(kind.dim() == 2, "kind must be 2D.");
+    TORCH_CHECK(kind.size(0) == term_number, "kind size must match term_number.");
+    TORCH_CHECK(kind.size(1) == max_op_number, "kind must have max_op_number columns.");
+
+    TORCH_CHECK(coef.device().type() == torch::kCUDA, "coef must be on CUDA.");
+    TORCH_CHECK(coef.device().index() == device_id, "coef must be on the same device as others.");
+    TORCH_CHECK(coef.is_contiguous(), "coef must be contiguous.");
+    TORCH_CHECK(coef.dtype() == torch::kFloat64, "coef must be float64.");
+    TORCH_CHECK(coef.dim() == 2, "coef must be 2D.");
+    TORCH_CHECK(coef.size(0) == term_number, "coef size must match term_number.");
+    TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements (real and imaginary parts) for each term.");
+
+    auto stream = at::cuda::getCurrentCUDAStream(device_id);
+
+    cudaDeviceProp prop;
+    AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
+    std::int64_t max_threads_per_block = prop.maxThreadsPerBlock;
+
+    auto result_psi = torch::zeros({batch_size, 2}, torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCUDA, device_id));
+
+    // Halving maxThreadsPerBlock avoids launch errors; the exact cause is unclear (possibly per-block resource limits).
+    auto threads_per_block = dim3(1, static_cast<unsigned int>(max_threads_per_block / 2));
+    auto num_blocks = dim3(
+        static_cast<unsigned int>((term_number + threads_per_block.x - 1) / threads_per_block.x),
+        static_cast<unsigned int>((batch_size + threads_per_block.y - 1) / threads_per_block.y)
+    );
+
+    diagonal_term_kernel_interface<max_op_number, n_qubytes, particle_cut><<<num_blocks, threads_per_block, 0, stream>>>(
+        /*term_number=*/term_number,
+        /*batch_size=*/batch_size,
+        /*site=*/reinterpret_cast<const std::array<std::int16_t, max_op_number>*>(site.data_ptr()),
+        /*kind=*/reinterpret_cast<const std::array<std::uint8_t, max_op_number>*>(kind.data_ptr()),
+        /*coef=*/reinterpret_cast<const std::array<double, 2>*>(coef.data_ptr()),
+        /*configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(configs.data_ptr()),
+        /*result_psi=*/reinterpret_cast<std::array<double, 2>*>(result_psi.data_ptr())
+    );
+    AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    return result_psi;
+}
+
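One possible alternative to the hard-coded halving of maxThreadsPerBlock above is to ask the runtime for a block size the kernel can actually launch with. A hedged sketch using the CUDA occupancy API, untested against this codebase:

// Let the occupancy calculator pick a feasible block size for this
// kernel instantiation instead of guessing from prop.maxThreadsPerBlock.
int min_grid_size = 0;
int block_size = 0;
AT_CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(
    &min_grid_size,
    &block_size,
    diagonal_term_kernel_interface<max_op_number, n_qubytes, particle_cut>
));
auto threads_per_block = dim3(1, static_cast<unsigned int>(block_size));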
 template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
 __device__ void single_relative_kernel(
     std::int64_t term_index,
@@ -880,6 +1008,7 @@ auto single_relative_interface(const torch::Tensor& configs, const torch::Tensor
 TORCH_LIBRARY_IMPL(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), CUDA, m) {
     m.impl("apply_within", apply_within_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
     m.impl("find_relative", find_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
+    m.impl("diagonal_term", diagonal_term_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
     m.impl("single_relative", single_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
 }
 #undef QMB_LIBRARY
