2
2
This file implements the reinforcement learning based subspace diagonalization algorithm.
3
3
"""
4
4
5
- import sys
6
5
import logging
7
6
import typing
8
7
import dataclasses
@@ -78,6 +77,8 @@ class RldiagConfig:
78
77
79
78
# The initial configuration for the first step, which is usually the Hartree-Fock state for quantum chemistry systems
80
79
initial_config : typing .Annotated [typing .Optional [str ], tyro .conf .arg (aliases = ["-i" ])] = None
80
+ # The maximum size of the configuration pool
81
+ max_pool_size : typing .Annotated [int , tyro .conf .arg (aliases = ["-n" ])] = 32768
81
82
# The learning rate for the local optimizer
82
83
learning_rate : typing .Annotated [float , tyro .conf .arg (aliases = ["-r" ])] = 1e-3
83
84
# The step of Lanczos iteration for calculating the energy
@@ -158,9 +159,12 @@ def main(self) -> None:
158
159
# | pruned | remained | expanded |
159
160
# | new config pool |
160
161
action = score .real >= - self .alpha
161
- action [0 ] = True
162
162
_ , topk = torch .topk (score .real , k = score .size (0 ) // 2 , dim = 0 )
163
163
action [topk ] = True
164
+ if score .size (0 ) > self .max_pool_size :
165
+ _ , topk = torch .topk (- score .real , k = score .size (0 ) - self .max_pool_size )
166
+ action [topk ] = False
167
+ action [0 ] = True
164
168
remained_configs = configs [action ]
165
169
pruned_configs = configs [torch .logical_not (action )]
166
170
expanded_configs = model .single_relative (remained_configs ) # There are duplicated config here
@@ -174,9 +178,6 @@ def main(self) -> None:
174
178
logging .info ("Configuration pool size: %d" , configs_size )
175
179
writer .add_scalar ("rldiag/configs/global" , configs_size , data ["rldiag" ]["global" ]) # type: ignore[no-untyped-call]
176
180
writer .add_scalar ("rldiag/configs/local" , configs_size , data ["rldiag" ]["local" ]) # type: ignore[no-untyped-call]
177
- if configs_size == 0 :
178
- logging .info ("All configurations has been pruned, please start a new configuration pool state" )
179
- sys .exit (0 )
180
181
181
182
if last_state is not None :
182
183
old_state = last_state [torch .cat ([action .nonzero ()[:, 0 ], torch .logical_not (action ).nonzero ()[:, 0 ]])]
0 commit comments