
Commit b30242f

Add diagonal_term in kernel.
PR: USTC-KnowledgeComputingLab/qmb#69
Signed-off-by: Hao Zhang <[email protected]>
2 parents b9a8dd1 + 62810f6 commit b30242f

File tree

10 files changed: +244 -0 lines changed

qmb/__main__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from . import precompile as _  # type: ignore[no-redef]
 from . import list_loss as _  # type: ignore[no-redef]
 from . import chop_imag as _  # type: ignore[no-redef]
+from . import pert as _  # type: ignore[no-redef]
 from . import run as _  # type: ignore[no-redef]
 from .subcommand_dict import subcommand_dict
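For context, the one-line import above is what activates the new subcommand: qmb modules register themselves in subcommand_dict when imported. A minimal sketch of that pattern, assuming subcommand_dict is a plain dict shared across modules (the run_subcommand driver below is an illustrative name, not part of the commit):

# Hypothetical sketch of the import-for-side-effect registration used by qmb;
# run_subcommand() is an illustrative driver, not part of the commit.
subcommand_dict: dict[str, type] = {}

class PerturbationConfig:
    def main(self) -> None:
        print("run the perturbation estimator")

# qmb/pert.py performs this assignment at import time (see the new file below),
# so `from . import pert as _` in __main__.py is enough to expose the subcommand.
subcommand_dict["pert"] = PerturbationConfig

def run_subcommand(name: str) -> None:
    subcommand_dict[name]().main()  # look up the registered config class and run it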

qmb/_hamiltonian.cpp

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ TORCH_LIBRARY_FRAGMENT(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), m) {
     m.def("apply_within(Tensor configs_i, Tensor psi_i, Tensor configs_j, Tensor site, Tensor kind, Tensor coef) -> Tensor");
     m.def("find_relative(Tensor configs_i, Tensor psi_i, int count_selected, Tensor site, Tensor kind, Tensor coef, Tensor configs_exclude) -> Tensor"
     );
+    m.def("diagonal_term(Tensor configs, Tensor site, Tensor kind, Tensor coef) -> Tensor");
     m.def("single_relative(Tensor configs, Tensor site, Tensor kind, Tensor coef) -> Tensor");
 }
 #undef QMB_LIBRARY
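The new schema uses the same term encoding as the existing operators; its dtype and shape requirements are spelled out by the TORCH_CHECKs in the CUDA implementation below. A small sketch of tensors that would satisfy those checks (the sizes here are illustrative placeholders, not values from the commit; only max_op_number = 4 comes from the registrations below):

import torch

# Illustrative sizes; max_op_number is 4 in the registrations below, the rest are placeholders.
batch_size, n_qubytes, term_number, max_op_number = 32, 8, 10, 4

configs = torch.zeros(batch_size, n_qubytes, dtype=torch.uint8, device="cuda")     # input configurations
site = torch.zeros(term_number, max_op_number, dtype=torch.int16, device="cuda")   # operator sites per term
kind = torch.zeros(term_number, max_op_number, dtype=torch.uint8, device="cuda")   # operator kinds per term
coef = torch.zeros(term_number, 2, dtype=torch.float64, device="cuda")             # (real, imag) coefficient per term
# diagonal_term(configs, site, kind, coef) returns a float64 tensor of shape [batch_size, 2].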

qmb/_hamiltonian_cuda.cu

Lines changed: 129 additions & 0 deletions
@@ -674,6 +674,134 @@ auto find_relative_interface(
     return unique_nonzero_result_config;
 }
 
+template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+__device__ void diagonal_term_kernel(
+    std::int64_t term_index,
+    std::int64_t batch_index,
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* configs, // batch_size
+    std::array<double, 2>* result_psi
+) {
+    std::array<std::uint8_t, n_qubytes> current_configs = configs[batch_index];
+    auto [success, parity] = hamiltonian_apply_kernel<max_op_number, n_qubytes, particle_cut>(
+        /*current_configs=*/current_configs,
+        /*term_index=*/term_index,
+        /*batch_index=*/batch_index,
+        /*site=*/site,
+        /*kind=*/kind
+    );
+
+    if (!success) {
+        return;
+    }
+    auto less = array_less<std::uint8_t, n_qubytes>();
+    if (less(current_configs, configs[batch_index]) || less(configs[batch_index], current_configs)) {
+        return; // The term does not apply to the current configuration
+    }
+    std::int8_t sign = parity ? -1 : +1;
+    atomicAdd(&result_psi[batch_index][0], sign * coef[term_index][0]);
+    atomicAdd(&result_psi[batch_index][1], sign * coef[term_index][1]);
+}
+
+template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+__global__ void diagonal_term_kernel_interface(
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* configs, // batch_size
+    std::array<double, 2>* result_psi
+) {
+    std::int64_t term_index = blockIdx.x * blockDim.x + threadIdx.x;
+    std::int64_t batch_index = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (term_index < term_number && batch_index < batch_size) {
+        diagonal_term_kernel<max_op_number, n_qubytes, particle_cut>(
+            /*term_index=*/term_index,
+            /*batch_index=*/batch_index,
+            /*term_number=*/term_number,
+            /*batch_size=*/batch_size,
+            /*site=*/site,
+            /*kind=*/kind,
+            /*coef=*/coef,
+            /*configs=*/configs,
+            /*result_psi=*/result_psi
+        );
+    }
+}
+
+template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+auto diagonal_term_interface(const torch::Tensor& configs, const torch::Tensor& site, const torch::Tensor& kind, const torch::Tensor& coef)
+    -> torch::Tensor {
+    std::int64_t device_id = configs.device().index();
+    std::int64_t batch_size = configs.size(0);
+    std::int64_t term_number = site.size(0);
+    at::cuda::CUDAGuard cuda_device_guard(device_id);
+
+    TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.")
+    TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
+    TORCH_CHECK(configs.is_contiguous(), "configs must be contiguous.")
+    TORCH_CHECK(configs.dtype() == torch::kUInt8, "configs must be uint8.")
+    TORCH_CHECK(configs.dim() == 2, "configs must be 2D.")
+    TORCH_CHECK(configs.size(0) == batch_size, "configs batch size must match the provided batch_size.");
+    TORCH_CHECK(configs.size(1) == n_qubytes, "configs must have the same number of qubits as the provided n_qubytes.");
+
+    TORCH_CHECK(site.device().type() == torch::kCUDA, "site must be on CUDA.")
+    TORCH_CHECK(site.device().index() == device_id, "site must be on the same device as others.");
+    TORCH_CHECK(site.is_contiguous(), "site must be contiguous.")
+    TORCH_CHECK(site.dtype() == torch::kInt16, "site must be int16.")
+    TORCH_CHECK(site.dim() == 2, "site must be 2D.")
+    TORCH_CHECK(site.size(0) == term_number, "site size must match the provided term_number.");
+    TORCH_CHECK(site.size(1) == max_op_number, "site must match the provided max_op_number.");
+
+    TORCH_CHECK(kind.device().type() == torch::kCUDA, "kind must be on CUDA.")
+    TORCH_CHECK(kind.device().index() == device_id, "kind must be on the same device as others.");
+    TORCH_CHECK(kind.is_contiguous(), "kind must be contiguous.")
+    TORCH_CHECK(kind.dtype() == torch::kUInt8, "kind must be uint8.")
+    TORCH_CHECK(kind.dim() == 2, "kind must be 2D.")
+    TORCH_CHECK(kind.size(0) == term_number, "kind size must match the provided term_number.");
+    TORCH_CHECK(kind.size(1) == max_op_number, "kind must match the provided max_op_number.");
+
+    TORCH_CHECK(coef.device().type() == torch::kCUDA, "coef must be on CUDA.")
+    TORCH_CHECK(coef.device().index() == device_id, "coef must be on the same device as others.");
+    TORCH_CHECK(coef.is_contiguous(), "coef must be contiguous.")
+    TORCH_CHECK(coef.dtype() == torch::kFloat64, "coef must be float64.")
+    TORCH_CHECK(coef.dim() == 2, "coef must be 2D.")
+    TORCH_CHECK(coef.size(0) == term_number, "coef size must match the provided term_number.");
+    TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements for each term.");
+
+    auto stream = at::cuda::getCurrentCUDAStream(device_id);
+    auto policy = thrust::device.on(stream);
+
+    cudaDeviceProp prop;
+    AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
+    std::int64_t max_threads_per_block = prop.maxThreadsPerBlock;
+
+    auto result_psi = torch::zeros({batch_size, 2}, torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCUDA, device_id));
+
+    auto threads_per_block = dim3{1, max_threads_per_block >> 1}; // I don't know why, but need to divide by 2 to avoid errors
+    auto num_blocks =
+        dim3{(term_number + threads_per_block.x - 1) / threads_per_block.x, (batch_size + threads_per_block.y - 1) / threads_per_block.y};
+
+    diagonal_term_kernel_interface<max_op_number, n_qubytes, particle_cut><<<num_blocks, threads_per_block, 0, stream>>>(
+        /*term_number=*/term_number,
+        /*batch_size=*/batch_size,
+        /*site=*/reinterpret_cast<const std::array<std::int16_t, max_op_number>*>(site.data_ptr()),
+        /*kind=*/reinterpret_cast<const std::array<std::uint8_t, max_op_number>*>(kind.data_ptr()),
+        /*coef=*/reinterpret_cast<const std::array<double, 2>*>(coef.data_ptr()),
+        /*configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(configs.data_ptr()),
+        /*result_psi=*/reinterpret_cast<std::array<double, 2>*>(result_psi.data_ptr())
+    );
+    AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    return result_psi;
+}
+
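To summarize what the kernel above computes, here is a hedged serial Python sketch of the same work decomposition: one unit of work per (term, batch) pair, accumulating into result_psi where the kernel uses atomicAdd. The apply_term callable is a hypothetical stand-in for hamiltonian_apply_kernel, whose implementation is not part of this diff:

# Serial sketch of the work decomposition used by the CUDA grid above: one unit of
# work per (term_index, batch_index) pair, accumulating into result_psi[batch_index]
# exactly where the kernel uses atomicAdd.
# `apply_term` is a hypothetical stand-in for hamiltonian_apply_kernel; it is assumed
# to return (success, parity, new_config) for one term applied to one configuration.
def diagonal_term_reference(configs, site, kind, coef, apply_term):
    batch_size = len(configs)
    term_number = len(site)
    result_psi = [[0.0, 0.0] for _ in range(batch_size)]  # (real, imag) per configuration
    for term_index in range(term_number):
        for batch_index in range(batch_size):
            success, parity, new_config = apply_term(configs[batch_index], site[term_index], kind[term_index])
            if not success:
                continue  # the term annihilates this configuration
            if new_config != configs[batch_index]:
                continue  # off-diagonal: the term maps the configuration to a different one
            sign = -1 if parity else +1
            result_psi[batch_index][0] += sign * coef[term_index][0]
            result_psi[batch_index][1] += sign * coef[term_index][1]
    return result_psi  # per-configuration diagonal matrix elements as (real, imag) pairs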
 template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
 __device__ void single_relative_kernel(
     std::int64_t term_index,
@@ -880,6 +1008,7 @@ auto single_relative_interface(const torch::Tensor& configs, const torch::Tensor
 TORCH_LIBRARY_IMPL(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), CUDA, m) {
     m.impl("apply_within", apply_within_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
     m.impl("find_relative", find_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
+    m.impl("diagonal_term", diagonal_term_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
     m.impl("single_relative", single_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
 }
 #undef QMB_LIBRARY

qmb/fcidump.py

Lines changed: 3 additions & 0 deletions
@@ -192,6 +192,9 @@ def apply_within(self, configs_i: torch.Tensor, psi_i: torch.Tensor, configs_j:
     def find_relative(self, configs_i: torch.Tensor, psi_i: torch.Tensor, count_selected: int, configs_exclude: torch.Tensor | None = None) -> torch.Tensor:
         return self.hamiltonian.find_relative(configs_i, psi_i, count_selected, configs_exclude)
 
+    def diagonal_term(self, configs: torch.Tensor) -> torch.Tensor:
+        return self.hamiltonian.diagonal_term(configs)
+
     def single_relative(self, configs: torch.Tensor) -> torch.Tensor:
         return self.hamiltonian.single_relative(configs)

qmb/hamiltonian.py

Lines changed: 20 additions & 0 deletions
@@ -165,6 +165,26 @@ def find_relative(
         configs_j = _find_relative(configs_i, torch.view_as_real(psi_i), count_selected, self.site, self.kind, self.coef, configs_exclude)
         return configs_j
 
+    def diagonal_term(self, configs: torch.Tensor) -> torch.Tensor:
+        """
+        Get the diagonal term of the Hamiltonian for the given configurations.
+
+        Parameters
+        ----------
+        configs : torch.Tensor
+            A uint8 tensor of shape [batch_size, n_qubytes] representing the input configurations.
+
+        Returns
+        -------
+        torch.Tensor
+            A complex128 tensor of shape [batch_size] representing the diagonal term of the Hamiltonian for the given configurations.
+        """
+        self._prepare_data(configs.device)
+        _diagonal_term: typing.Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+        _diagonal_term = getattr(self._load_module(configs.device.type, configs.size(1), self.particle_cut), "diagonal_term")
+        psi_result = torch.view_as_complex(_diagonal_term(configs, self.site, self.kind, self.coef))
+        return psi_result
+
     def single_relative(self, configs: torch.Tensor) -> torch.Tensor:
         """
         Find a single relative configuration for each configurations.
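A hedged usage sketch of the new method (hamiltonian and configs below are assumed objects, not from the commit): the compiled op returns a float64 tensor of shape [batch_size, 2], and torch.view_as_complex folds the (real, imag) pairs into a complex vector of shape [batch_size].

# Hypothetical usage; `hamiltonian` is a Hamiltonian object from this module and
# `configs` is a uint8 tensor of shape [batch_size, n_qubytes].
diag = hamiltonian.diagonal_term(configs)  # complex tensor of shape [batch_size]
# diag[b] is the diagonal matrix element <c_b|H|c_b> for configuration c_b,
# e.g. the quantity subtracted from the current energy in qmb/pert.py below.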

qmb/hubbard.py

Lines changed: 3 additions & 0 deletions
@@ -112,6 +112,9 @@ def apply_within(self, configs_i: torch.Tensor, psi_i: torch.Tensor, configs_j:
     def find_relative(self, configs_i: torch.Tensor, psi_i: torch.Tensor, count_selected: int, configs_exclude: torch.Tensor | None = None) -> torch.Tensor:
         return self.hamiltonian.find_relative(configs_i, psi_i, count_selected, configs_exclude)
 
+    def diagonal_term(self, configs: torch.Tensor) -> torch.Tensor:
+        return self.hamiltonian.diagonal_term(configs)
+
     def single_relative(self, configs: torch.Tensor) -> torch.Tensor:
         return self.hamiltonian.single_relative(configs)

qmb/ising.py

Lines changed: 3 additions & 0 deletions
@@ -221,6 +221,9 @@ def apply_within(self, configs_i: torch.Tensor, psi_i: torch.Tensor, configs_j:
     def find_relative(self, configs_i: torch.Tensor, psi_i: torch.Tensor, count_selected: int, configs_exclude: torch.Tensor | None = None) -> torch.Tensor:
         return self.hamiltonian.find_relative(configs_i, psi_i, count_selected, configs_exclude)
 
+    def diagonal_term(self, configs: torch.Tensor) -> torch.Tensor:
+        return self.hamiltonian.diagonal_term(configs)
+
     def single_relative(self, configs: torch.Tensor) -> torch.Tensor:
         return self.hamiltonian.single_relative(configs)

qmb/model_dict.py

Lines changed: 15 additions & 0 deletions
@@ -173,6 +173,21 @@ def find_relative(self, configs_i: torch.Tensor, psi_i: torch.Tensor, count_sele
             The relative configurations.
         """
 
+    def diagonal_term(self, configs: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate the diagonal term for the given configurations.
+
+        Parameters
+        ----------
+        configs : torch.Tensor
+            The configurations to calculate the diagonal term for.
+
+        Returns
+        -------
+        torch.Tensor
+            The diagonal term of the configurations.
+        """
+
     def single_relative(self, configs: torch.Tensor) -> torch.Tensor:
         """
         Find a single relative configuration for each configurations.

qmb/openfermion.py

Lines changed: 3 additions & 0 deletions
@@ -86,6 +86,9 @@ def apply_within(self, configs_i: torch.Tensor, psi_i: torch.Tensor, configs_j:
     def find_relative(self, configs_i: torch.Tensor, psi_i: torch.Tensor, count_selected: int, configs_exclude: torch.Tensor | None = None) -> torch.Tensor:
         return self.hamiltonian.find_relative(configs_i, psi_i, count_selected, configs_exclude)
 
+    def diagonal_term(self, configs: torch.Tensor) -> torch.Tensor:
+        return self.hamiltonian.diagonal_term(configs)
+
     def single_relative(self, configs: torch.Tensor) -> torch.Tensor:
         return self.hamiltonian.single_relative(configs)

qmb/pert.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+"""
+This file implements a perturbation estimator from haar.
+"""
+
+import logging
+import typing
+import dataclasses
+import tyro
+from .common import CommonConfig
+from .subcommand_dict import subcommand_dict
+
+
+@dataclasses.dataclass
+class PerturbationConfig:
+    """
+    The perturbation estimator from haar.
+    """
+
+    common: typing.Annotated[CommonConfig, tyro.conf.OmitArgPrefixes]
+
+    def main(self, *, model_param: typing.Any = None, network_param: typing.Any = None) -> None:
+        """
+        The main function of the perturbation estimator.
+        """
+        # pylint: disable=too-many-locals
+        # pylint: disable=too-many-statements
+        # pylint: disable=too-many-branches
+
+        model, _, data = self.common.main(model_param=model_param, network_param=network_param)
+
+        if "haar" not in data and "imag" in data:
+            data["haar"] = data.pop("imag")
+        configs, psi = data["haar"]["pool"]
+        configs = configs.to(self.common.device)
+        psi = psi.to(self.common.device)
+
+        energy0_num = psi.conj() @ model.apply_within(configs, psi, configs)
+        energy0_den = psi.conj() @ psi
+        energy0 = (energy0_num / energy0_den).real.item()
+        logging.info("Current energy is %.8f", energy0)
+        logging.info("Reference energy is %.8f", model.ref_energy)
+
+        number = configs.size(0)
+        last_result_number = 0
+        current_target_number = number
+        logging.info("Starting finding relative configurations with %d.", number)
+        while True:
+            other_configs = model.find_relative(configs, psi, current_target_number, configs)
+            current_result_number = other_configs.size(0)
+            logging.info("Found %d relative configurations.", current_result_number)
+            if current_result_number == last_result_number:
+                logging.info("No new configurations found, stopping at %d.", current_result_number)
+                break
+            current_target_number = current_target_number * 2
+            logging.info("Doubling target number to %d.", current_target_number)
+            break
+
+        hamiltonian_psi = model.apply_within(configs, psi, other_configs)
+        energy2_num = (hamiltonian_psi.conj() @ hamiltonian_psi).real / (psi.conj() @ psi).real
+        energy2_den = energy0 - model.diagonal_term(other_configs).real
+        energy2 = (energy2_num / energy2_den).sum().item()
+        logging.info("Correct energy is %.8f", energy2)
+        logging.info("Error is reduced from %.8f to %.8f", energy0 - model.ref_energy, energy2 - model.ref_energy)
+
+
+subcommand_dict["pert"] = PerturbationConfig
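For reference, a hedged reading of the energy2 computation above: it follows the shape of a second-order perturbative correction in which the new diagonal_term supplies the diagonal matrix elements in the denominator. With $|\psi\rangle$ the pool state, $E_0$ the variational energy computed above, and $|j\rangle$ running over the relative configurations returned by find_relative, the standard expression is

$$\Delta E^{(2)} \;=\; \sum_{j} \frac{\left|\langle j | H | \psi \rangle\right|^{2}}{\langle \psi | \psi \rangle \,\left(E_{0} - \langle j | H | j \rangle\right)}$$

where each denominator term $E_{0} - \langle j | H | j \rangle$ corresponds to energy0 - model.diagonal_term(other_configs) in the code.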
