Skip to content

Commit fc413a8

Browse files
committed
Add max pool size limit for the rldiag script.
1 parent d37d09e commit fc413a8

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

qmb/rldiag.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ class RldiagConfig:
7878

7979
# The initial configuration for the first step, which is usually the Hatree-Fock state for quantum chemistry system
8080
initial_config: typing.Annotated[typing.Optional[str], tyro.conf.arg(aliases=["-i"])] = None
81+
# The maximum size of the configuration pool
82+
max_pool_size: typing.Annotated[int, tyro.conf.arg(aliases=["-n"])] = 32768
8183
# The learning rate for the local optimizer
8284
learning_rate: typing.Annotated[float, tyro.conf.arg(aliases=["-r"])] = 1e-3
8385
# The step of lanczos iteration for calculating the energy
@@ -158,9 +160,12 @@ def main(self) -> None:
158160
# | pruned | remained | expanded |
159161
# | new config pool |
160162
action = score.real >= -self.alpha
161-
action[0] = True
162163
_, topk = torch.topk(score.real, k=score.size(0) // 2, dim=0)
163164
action[topk] = True
165+
if score.size(0) > self.max_pool_size:
166+
_, topk = torch.topk(-score.real, k=score.size(0) - self.max_pool_size)
167+
action[topk] = False
168+
action[0] = True
164169
remained_configs = configs[action]
165170
pruned_configs = configs[torch.logical_not(action)]
166171
expanded_configs = model.single_relative(remained_configs) # There are duplicated config here

0 commit comments

Comments
 (0)