
Commit 402ba75

Add max pool size support for rldiag.
This PR adds an option to the rldiag script that controls the maximum pool size for the state. The PR also contains a commit that removes a code snippet which is no longer used in rldiag. PR tracking at: USTC-KnowledgeComputingLab/qmb#22
2 parents 0d80b7e + e84ea05 commit 402ba75
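
For illustration, a hypothetical invocation of the new option (the actual entry point depends on how qmb wires up tyro; --max-pool-size follows tyro's default field-name mapping, and -n is the alias added in this commit):

python -m qmb.rldiag --max-pool-size 65536
python -m qmb.rldiag -n 65536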

File tree

1 file changed (+6, -5 lines)


qmb/rldiag.py

Lines changed: 6 additions & 5 deletions
@@ -2,7 +2,6 @@
 This file implements the reinforcement learning based subspace diagonalization algorithm.
 """
 
-import sys
 import logging
 import typing
 import dataclasses
@@ -78,6 +77,8 @@ class RldiagConfig:
 
     # The initial configuration for the first step, which is usually the Hartree-Fock state for quantum chemistry system
     initial_config: typing.Annotated[typing.Optional[str], tyro.conf.arg(aliases=["-i"])] = None
+    # The maximum size of the configuration pool
+    max_pool_size: typing.Annotated[int, tyro.conf.arg(aliases=["-n"])] = 32768
     # The learning rate for the local optimizer
     learning_rate: typing.Annotated[float, tyro.conf.arg(aliases=["-r"])] = 1e-3
     # The step of lanczos iteration for calculating the energy
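
For readers unfamiliar with this config style, here is a minimal, self-contained sketch (not the project's actual RldiagConfig) of how a tyro-annotated dataclass field such as max_pool_size becomes a command-line option with a short alias; PoolConfig and the entry point below are made up for illustration:

import dataclasses
import typing

import tyro


@dataclasses.dataclass
class PoolConfig:
    # Illustrative stand-in for RldiagConfig: the maximum size of the configuration pool.
    # tyro exposes this field as --max-pool-size, with -n as a short alias.
    max_pool_size: typing.Annotated[int, tyro.conf.arg(aliases=["-n"])] = 32768


if __name__ == "__main__":
    config = tyro.cli(PoolConfig)
    print(config.max_pool_size)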
@@ -158,9 +159,12 @@ def main(self) -> None:
         # | pruned | remained | expanded |
         # | new config pool |
         action = score.real >= -self.alpha
-        action[0] = True
         _, topk = torch.topk(score.real, k=score.size(0) // 2, dim=0)
         action[topk] = True
+        if score.size(0) > self.max_pool_size:
+            _, topk = torch.topk(-score.real, k=score.size(0) - self.max_pool_size)
+            action[topk] = False
+        action[0] = True
         remained_configs = configs[action]
         pruned_configs = configs[torch.logical_not(action)]
         expanded_configs = model.single_relative(remained_configs)  # There are duplicated config here
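
The capping logic added above can be exercised in isolation. Below is a minimal sketch that mirrors the mask construction from the diff, with made-up values for the scores, the threshold alpha, and a small max_pool_size; in the real script these come from the model and RldiagConfig:

import torch

max_pool_size = 8   # illustrative cap; the real default is 32768
alpha = 0.1         # illustrative pruning threshold

score = torch.randn(16, dtype=torch.complex64)  # stand-in for the real scores
action = score.real >= -alpha                   # keep configurations above the threshold

# Always keep the better half of the current pool.
_, topk = torch.topk(score.real, k=score.size(0) // 2, dim=0)
action[topk] = True

# Cap the pool: mark the (size - max_pool_size) lowest-scoring entries for pruning.
if score.size(0) > max_pool_size:
    _, lowest = torch.topk(-score.real, k=score.size(0) - max_pool_size)
    action[lowest] = False

# The first configuration (e.g. the initial Hartree-Fock state) is always kept.
action[0] = True

print(int(action.sum()))  # number of configurations kept in the pool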
@@ -174,9 +178,6 @@ def main(self) -> None:
         logging.info("Configuration pool size: %d", configs_size)
         writer.add_scalar("rldiag/configs/global", configs_size, data["rldiag"]["global"])  # type: ignore[no-untyped-call]
         writer.add_scalar("rldiag/configs/local", configs_size, data["rldiag"]["local"])  # type: ignore[no-untyped-call]
-        if configs_size == 0:
-            logging.info("All configurations has been pruned, please start a new configuration pool state")
-            sys.exit(0)
 
         if last_state is not None:
             old_state = last_state[torch.cat([action.nonzero()[:, 0], torch.logical_not(action).nonzero()[:, 0]])]