Skip to content

Commit 62810f6

Browse files
committed
Add a perturbation estimator for haar result.
1 parent 750c9ab commit 62810f6

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

qmb/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from . import precompile as _ # type: ignore[no-redef]
1616
from . import list_loss as _ # type: ignore[no-redef]
1717
from . import chop_imag as _ # type: ignore[no-redef]
18+
from . import pert as _ # type: ignore[no-redef]
1819
from . import run as _ # type: ignore[no-redef]
1920
from .subcommand_dict import subcommand_dict
2021

qmb/pert.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
This file implements a perturbation estimator from haar.
3+
"""
4+
5+
import logging
6+
import typing
7+
import dataclasses
8+
import tyro
9+
from .common import CommonConfig
10+
from .subcommand_dict import subcommand_dict
11+
12+
13+
@dataclasses.dataclass
14+
class PerturbationConfig:
15+
"""
16+
The perturbation estimator from haar.
17+
"""
18+
19+
common: typing.Annotated[CommonConfig, tyro.conf.OmitArgPrefixes]
20+
21+
def main(self, *, model_param: typing.Any = None, network_param: typing.Any = None) -> None:
22+
"""
23+
The main function of two-step optimization process based on imaginary time.
24+
"""
25+
# pylint: disable=too-many-locals
26+
# pylint: disable=too-many-statements
27+
# pylint: disable=too-many-branches
28+
29+
model, _, data = self.common.main(model_param=model_param, network_param=network_param)
30+
31+
if "haar" not in data and "imag" in data:
32+
data["haar"] = data.pop("imag")
33+
configs, psi = data["haar"]["pool"]
34+
configs = configs.to(self.common.device)
35+
psi = psi.to(self.common.device)
36+
37+
energy0_num = psi.conj() @ model.apply_within(configs, psi, configs)
38+
energy0_den = psi.conj() @ psi
39+
energy0 = (energy0_num / energy0_den).real.item()
40+
logging.info("Current energy is %.8f", energy0)
41+
logging.info("Reference energy is %.8f", model.ref_energy)
42+
43+
number = configs.size(0)
44+
last_result_number = 0
45+
current_target_number = number
46+
logging.info("Starting finding relative configurations with %d.", number)
47+
while True:
48+
other_configs = model.find_relative(configs, psi, current_target_number, configs)
49+
current_result_number = other_configs.size(0)
50+
logging.info("Found %d relative configurations.", current_result_number)
51+
if current_result_number == last_result_number:
52+
logging.info("No new configurations found, stopping at %d.", current_result_number)
53+
break
54+
current_target_number = current_target_number * 2
55+
logging.info("Doubling target number to %d.", current_target_number)
56+
break
57+
58+
hamiltonian_psi = model.apply_within(configs, psi, other_configs)
59+
energy2_num = (hamiltonian_psi.conj() @ hamiltonian_psi).real / (psi.conj() @ psi).real
60+
energy2_den = energy0 - model.diagonal_term(other_configs).real
61+
energy2 = (energy2_num / energy2_den).sum().item()
62+
logging.info("Correct energy is %.8f", energy2)
63+
logging.info("Error is reduced from %.8f to %.8f", energy0 - model.ref_energy, energy2 - model.ref_energy)
64+
65+
66+
subcommand_dict["pert"] = PerturbationConfig

0 commit comments

Comments
 (0)