1616# along with this program. If not, see <http://www.gnu.org/licenses/>.
1717
1818import cupy
19- from pyscf .dft import uks
19+ from pyscf .dft import uks as uks_cpu
2020from pyscf import lib
2121from gpu4pyscf .lib import logger
2222from gpu4pyscf .dft import numint , gen_grid , rks
2323from gpu4pyscf .scf import hf , uhf
2424from gpu4pyscf .lib .cupy_helper import tag_array
25+ from gpu4pyscf .lib import utils
2526
2627
2728def get_veff (ks , mol = None , dm = None , dm_last = 0 , vhf_last = 0 , hermi = 1 ):
@@ -30,8 +31,9 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
3031 '''
3132 if mol is None : mol = ks .mol
3233 if dm is None : dm = ks .make_rdm1 ()
34+ assert dm .ndim == 3
3335 t0 = logger .init_timer (ks )
34- rks .initialize_grids (ks , mol , dm )
36+ rks .initialize_grids (ks , mol , cupy . asarray ( dm [ 0 ] + dm [ 1 ]) )
3537
3638 if hasattr (ks , 'screen_tol' ) and ks .screen_tol is not None :
3739 ks .direct_scf_tol = ks .screen_tol
@@ -42,7 +44,7 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
4244 n , exc , vxc = (0 ,0 ), 0 , 0
4345 else :
4446 max_memory = ks .max_memory - lib .current_memory ()[0 ]
45- n , exc , vxc = ni .nr_uks (mol , ks .grids , ks .xc , dm , max_memory = max_memory )
47+ n , exc , vxc = ni .nr_uks (mol , ks .grids , ks .xc , dm . view ( cupy . ndarray ) , max_memory = max_memory )
4648 logger .debug (ks , 'nelec by numeric integration = %s' , n )
4749 if ks .do_nlc ():
4850 if ni .libxc .is_nlc (ks .xc ):
@@ -61,7 +63,10 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
6163 vk = None
6264 if (ks ._eri is None and ks .direct_scf and
6365 getattr (vhf_last , 'vj' , None ) is not None ):
64- ddm = cupy .asarray (dm ) - cupy .asarray (dm_last )
66+ dm_last = cupy .asarray (dm_last )
67+ dm = cupy .asarray (dm )
68+ assert dm_last .ndim == 0 or dm_last .ndim == dm .ndim
69+ ddm = dm - dm_last
6570 vj = ks .get_j (mol , ddm [0 ]+ ddm [1 ], hermi )
6671 vj += vhf_last .vj
6772 else :
@@ -71,7 +76,10 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
7176 omega , alpha , hyb = ni .rsh_and_hybrid_coeff (ks .xc , spin = mol .spin )
7277 if (ks ._eri is None and ks .direct_scf and
7378 getattr (vhf_last , 'vk' , None ) is not None ):
74- ddm = cupy .asarray (dm ) - cupy .asarray (dm_last )
79+ dm_last = cupy .asarray (dm_last )
80+ dm = cupy .asarray (dm )
81+ assert dm_last .ndim == 0 or dm_last .ndim == dm .ndim
82+ ddm = dm - dm_last
7583 vj , vk = ks .get_jk (mol , ddm , hermi )
7684 vk *= hyb
7785 if abs (omega ) > 1e-10 : # For range separated Coulomb operator
@@ -113,19 +121,15 @@ def energy_elec(ks, dm=None, h1e=None, vhf=None):
113121
114122
115123class UKS (rks .KohnShamDFT , uhf .UHF ):
116- from gpu4pyscf .lib .utils import to_gpu , device
117- _keys = {'disp' , 'screen_tol' }
118-
119- def __init__ (self , mol , xc = 'LDA,VWN' , disp = None ):
124+ def __init__ (self , mol , xc = 'LDA,VWN' ):
120125 uhf .UHF .__init__ (self , mol )
121126 rks .KohnShamDFT .__init__ (self , xc )
122- self .disp = disp
123127
124128 get_veff = get_veff
125- get_vasp = uks .get_vsap
129+ get_vasp = uks_cpu .get_vsap
126130 energy_elec = energy_elec
127131 energy_tot = hf .RHF .energy_tot
128- init_guess_by_vsap = uks .UKS .init_guess_by_vsap
132+ init_guess_by_vsap = uks_cpu .UKS .init_guess_by_vsap
129133
130134 to_hf = NotImplemented
131135
@@ -141,9 +145,10 @@ def nuc_grad_method(self):
141145 from gpu4pyscf .grad import uks as uks_grad
142146 return uks_grad .Gradients (self )
143147
148+ to_gpu = utils .to_gpu
149+ device = utils .device
150+
144151 def to_cpu (self ):
145- from gpu4pyscf .lib import utils
146- mf = uks .UKS (self .mol , xc = self .xc )
147- mf .disp = self .disp
152+ mf = uks_cpu .UKS (self .mol , xc = self .xc )
148153 utils .to_cpu (self , mf )
149154 return mf
0 commit comments