Skip to content

Commit fdaa9f4

Browse files
authored
Multi-GPU for DF Gradient and Hessian (#270)
* multi-gpu for df hessian * remove tmp file * revert unnecessary changes * bug in eval_rho2 * bugfix && synchronize * Resolve comments
1 parent 894cd95 commit fdaa9f4

File tree

19 files changed

+1008
-705
lines changed

19 files changed

+1008
-705
lines changed

examples/00-h2o.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@
3636
atom=atom, # water molecule
3737
basis='def2-tzvpp', # basis set
3838
output='./pyscf.log', # save log file
39-
verbose=6 # control the level of print info
39+
verbose=6 # control the level of print info
4040
)
4141

4242
mf_GPU = rks.RKS( # restricted Kohn-Sham DFT
4343
mol, # pyscf.gto.object
44-
xc='b3lyp' # xc funtionals, such as pbe0, wb97m-v, tpss,
44+
xc='b3lyp' # xc funtionals, such as pbe0, wb97m-v, tpss,
4545
).density_fit() # density fitting
4646

4747
mf_GPU.grids.atom_grid = (99,590) # (99,590) lebedev grids, (75,302) is often enough

gpu4pyscf/__config__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,11 @@
2424
mem_fraction = 0.9
2525
cupy.get_default_memory_pool().set_limit(fraction=mem_fraction)
2626

27+
# Check P2P data transfer is available
28+
_p2p_access = True
29+
if _num_devices > 1:
30+
for src in range(_num_devices):
31+
for dst in range(_num_devices):
32+
if src != dst:
33+
can_access_peer = cupy.cuda.runtime.deviceCanAccessPeer(src, dst)
34+
_p2p_access &= can_access_peer

gpu4pyscf/df/df.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from cupyx.scipy.linalg import solve_triangular
2222
from pyscf import lib
2323
from pyscf.df import df, addons, incore
24-
from gpu4pyscf.lib.cupy_helper import cholesky, tag_array, get_avail_mem, cart2sph
24+
from gpu4pyscf.lib.cupy_helper import cholesky, tag_array, get_avail_mem, cart2sph, p2p_transfer
2525
from gpu4pyscf.df import int3c2e, df_jk
2626
from gpu4pyscf.lib import logger
2727
from gpu4pyscf import __config__
@@ -177,10 +177,10 @@ def loop(self, blksize=None, unpack=True):
177177
yield buf2, buf.T
178178
if isinstance(cderi_sparse, np.ndarray):
179179
cupy.cuda.Device().synchronize()
180-
180+
181181
if buf_prefetch is not None:
182182
buf = buf_prefetch
183-
183+
184184
def reset(self, mol=None):
185185
'''Reset mol and clean up relevant attributes for scanner mode'''
186186
if mol is not None:
@@ -208,13 +208,14 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low,
208208
npairs = len(intopt.cderi_row)
209209
log = logger.new_logger(mol, mol.verbose)
210210

211-
# if the matrix exceeds the limit, store CDERI in CPU memory
212-
# TODO: better estimate of memory consumption for each device
211+
# Available memory on Device 0.
213212
avail_mem = get_avail_mem()
214213

215214
if use_gpu_memory:
216-
# If GPU memory is not enough
217-
use_gpu_memory = naux * npairs * 8 < 0.4 * avail_mem
215+
# CDERI will be equally distributed to the devices
216+
# Other devices usually have more memory available than Device 0
217+
# CDERI will use up to 40% of the available memory
218+
use_gpu_memory = naux * npairs * 8 < 0.4 * avail_mem * _num_devices
218219

219220
if use_gpu_memory:
220221
log.debug("Saving CDERI on GPU")
@@ -244,9 +245,7 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low,
244245
cd_low_f = cupy.array(cd_low, order='F', copy=False)
245246
cd_low_f = tag_array(cd_low_f, tag=cd_low.tag)
246247

247-
for gpu_id in range(_num_devices):
248-
cupy.cuda.Device(gpu_id).synchronize()
249-
248+
cupy.cuda.get_current_stream().synchronize()
250249
futures = []
251250
with ThreadPoolExecutor(max_workers=_num_devices) as executor:
252251
for device_id in range(_num_devices):
@@ -258,9 +257,6 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low,
258257
for future in futures:
259258
future.result()
260259

261-
for device_id in range(_num_devices):
262-
cupy.cuda.Device(device_id).synchronize()
263-
264260
if not use_gpu_memory:
265261
cupy.cuda.Device().synchronize()
266262

@@ -344,14 +340,14 @@ def _cderi_task(intopt, cd_low, task_list, _cderi, omega=None, sr_only=False, de
344340
# if CDERI is saved on CPU
345341
ij0 = pairs_loc[cp_ij_id]
346342
ij1 = pairs_loc[cp_ij_id+1]
347-
if isinstance(_cderi, np.ndarray):
343+
if isinstance(_cderi[0], np.ndarray):
348344
for slice_id, (p0,p1) in enumerate(lib.prange(0, naux, blksize)):
349345
for i in range(p0,p1):
350-
cderi_block[i].get(out=_cderi[slice_id][i,ij0:ij1])
346+
cderi_block[i].get(out=_cderi[slice_id][i-p0,ij0:ij1])
351347
else:
352348
# Copy data to other Devices
353349
for slice_id, (p0,p1) in enumerate(lib.prange(0, naux, blksize)):
354-
_cderi[slice_id][:,ij0:ij1] = cderi_block[p0:p1]
355-
350+
#_cderi[slice_id][:,ij0:ij1] = cderi_block[p0:p1]
351+
p2p_transfer(_cderi[slice_id][:,ij0:ij1], cderi_block[p0:p1])
356352
t1 = log.timer_debug1(f'transfer data for {cp_ij_id} / {nq} on Device {device_id}', *t1)
357353
return

gpu4pyscf/df/df_jk.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,7 @@ def _jk_task_with_mo(dfobj, dms, mo_coeff, mo_occ,
298298
rhok = rhok.reshape([-1,nao])
299299
vk[i] += cupy.dot(rhok.T, rhok)
300300
rhok = None
301-
cupy.cuda.get_current_stream().synchronize()
302-
301+
303302
if with_j:
304303
vj = cupy.zeros(dms_shape)
305304
vj[:,rows,cols] = vj_packed
@@ -390,13 +389,12 @@ def _jk_task_with_dm(dfobj, dms, with_j=True, with_k=True, hermi=0, device_id=0)
390389
else:
391390
dm_sparse *= 2
392391
dm_sparse[:, intopt.cderi_diag] *= .5
393-
392+
vj_sparse = cupy.zeros_like(dm_sparse)
393+
394394
if with_k:
395395
vk = cupy.zeros_like(dms)
396396

397397
nset = dms.shape[0]
398-
if with_j:
399-
vj_sparse = cupy.zeros_like(dm_sparse)
400398
blksize = dfobj.get_blksize()
401399
for cderi, cderi_sparse in dfobj.loop(blksize=blksize, unpack=with_k):
402400
if with_j:
@@ -406,7 +404,7 @@ def _jk_task_with_dm(dfobj, dms, with_j=True, with_k=True, hermi=0, device_id=0)
406404
for k in range(nset):
407405
rhok = contract('Lij,jk->Lki', cderi, dms[k]).reshape([-1,nao])
408406
#vk[k] += contract('Lki,Lkj->ij', rhok, cderi)
409-
vk[k] += cupy.dot(rhok.T, cderi.reshape([-1,nao]))
407+
vk[k] += cupy.dot(rhok.T, cderi.reshape([-1,nao]))
410408
if with_j:
411409
vj = cupy.zeros(dms_shape)
412410
vj[:,rows,cols] = vj_sparse
@@ -445,6 +443,7 @@ def get_jk(dfobj, dms_tag, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-
445443
intopt = dfobj.intopt
446444
dms = intopt.sort_orbitals(dms, axis=[1,2])
447445

446+
cupy.cuda.get_current_stream().synchronize()
448447
if getattr(dms_tag, 'mo_coeff', None) is not None:
449448
mo_occ = dms_tag.mo_occ
450449
mo_coeff = dms_tag.mo_coeff
@@ -498,13 +497,13 @@ def get_jk(dfobj, dms_tag, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-
498497
vj = vk = None
499498
if with_j:
500499
vj = [future.result()[0] for future in futures]
501-
vj = reduce_to_device(vj)
500+
vj = reduce_to_device(vj, inplace=True)
502501
vj = intopt.unsort_orbitals(vj, axis=[1,2])
503502
vj = vj.reshape(out_shape)
504-
503+
505504
if with_k:
506505
vk = [future.result()[1] for future in futures]
507-
vk = reduce_to_device(vk)
506+
vk = reduce_to_device(vk, inplace=True)
508507
vk = intopt.unsort_orbitals(vk, axis=[1,2])
509508
vk = vk.reshape(out_shape)
510509

gpu4pyscf/df/grad/jk.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from concurrent.futures import ThreadPoolExecutor
1717
import cupy
18-
from gpu4pyscf.lib.cupy_helper import contract
18+
from gpu4pyscf.lib.cupy_helper import contract, concatenate
1919
from gpu4pyscf.lib import logger
2020
from gpu4pyscf.__config__ import _streams, _num_devices
2121

@@ -58,6 +58,7 @@ def get_rhoj_rhok(with_df, dm, orbo, with_j=True, with_k=True):
5858
''' Calculate rhoj and rhok on Multi-GPU system
5959
'''
6060
futures = []
61+
cupy.cuda.get_current_stream().synchronize()
6162
with ThreadPoolExecutor(max_workers=_num_devices) as executor:
6263
for device_id in range(_num_devices):
6364
future = executor.submit(
@@ -74,8 +75,8 @@ def get_rhoj_rhok(with_df, dm, orbo, with_j=True, with_k=True):
7475

7576
rhoj = rhok = None
7677
if with_j:
77-
rhoj = cupy.concatenate(rhoj_total)
78+
rhoj = concatenate(rhoj_total)
7879
if with_k:
79-
rhok = cupy.concatenate(rhok_total)
80+
rhok = concatenate(rhok_total)
8081

8182
return rhoj, rhok

gpu4pyscf/df/hessian/rhf.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,29 @@ def partial_hess_elec(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
5454
atmlst, max_memory, verbose, True)
5555
return e1 + ej - ek
5656

57+
def _hk_ip1_ip1(rhok1_Pko, dm0, mocc_2):
58+
''' hk contributions due to (10|0)(0|10) + (10|0)(0|01)
59+
'''
60+
nnz = rhok1_Pko.shape[0]
61+
nao = dm0.shape[0]
62+
mem_avail = get_avail_mem()
63+
blksize = int((mem_avail*0.4/(nao*nao*3*8)/ALIGNED))*ALIGNED
64+
hk_ao_ao = cupy.zeros([nao,nao,3,3])
65+
for k0, k1 in lib.prange(0,nnz,blksize):
66+
rhok1_Pko_kslice = cupy.asarray(rhok1_Pko[k0:k1])
67+
68+
# (10|0)(0|10) without response of RI basis
69+
vk2_ip1_ip1 = contract('piox,pkoy->ikxy', rhok1_Pko_kslice, rhok1_Pko_kslice)
70+
hk_ao_ao += contract('ikxy,ik->ikxy', vk2_ip1_ip1, dm0)
71+
vk2_ip1_ip1 = None
72+
73+
# (10|0)(0|01) without response of RI basis
74+
rhok1_Pkl_kslice = contract('piox,ko->pikx', rhok1_Pko_kslice, mocc_2)
75+
hk_ao_ao += contract('pikx,pkiy->ikxy', rhok1_Pkl_kslice, rhok1_Pkl_kslice)
76+
rhok1_Pkl_kslice = None
77+
return hk_ao_ao
78+
79+
5780
def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
5881
atmlst=None, max_memory=4000, verbose=None, with_k=True, omega=None):
5982
'''Partial derivative
@@ -94,7 +117,7 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
94117
# ================================ sorted AO begin ===============================================
95118
intopt = int3c2e.VHFOpt(mol, auxmol, 'int2e')
96119
intopt.build(mf.direct_scf_tol, diag_block_with_triu=True, aosym=False, group_size=BLKSIZE, group_size_aux=BLKSIZE)
97-
naux = auxmol.nao #len(aux_ao_idx)
120+
naux = auxmol.nao
98121
mocc_2 = intopt.sort_orbitals(mocc_2, axis=[0])
99122
dm0 = intopt.sort_orbitals(dm0, axis=[0,1])
100123
dm0_tag = tag_array(dm0, occ_coeff=mocc_2)
@@ -118,7 +141,6 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
118141
rhoj0_P = solve_j2c(wj)
119142
rhok0_P__ = solve_j2c(wk_P__)
120143
wj = wk_P__ = None
121-
t1 = log.timer_debug1('intermediate variables with int3c2e', *t1)
122144

123145
# int3c_ip2 contributions
124146
wj_ip2, wk_ip2_P__ = int3c2e.get_int3c2e_ip2_wjk(intopt, dm0_tag, omega=omega)
@@ -188,36 +210,20 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
188210
rhok1_Pko[:,i0:i1] = contract('qp,qiox->piox', cd_low, wk1_tmp).get()
189211
wk1_tmp = None
190212
cd_low = None
191-
192-
mem_avail = get_avail_mem()
193-
blksize = int((mem_avail*0.4/(nao*nao*3*8)/ALIGNED))*ALIGNED
194-
log.debug(f'GPU Memory {mem_avail/GB:.1f} GB available, {blksize} aux AOs per block')
195-
for k0, k1 in lib.prange(0,nnz,blksize):
196-
rhok1_Pko_kslice = cupy.asarray(rhok1_Pko[k0:k1])
197-
198-
# (10|0)(0|10) without response of RI basis
199-
vk2_ip1_ip1 = contract('piox,pkoy->ikxy', rhok1_Pko_kslice, rhok1_Pko_kslice)
200-
hk_ao_ao += contract('ikxy,ik->ikxy', vk2_ip1_ip1, dm0)
201-
vk2_ip1_ip1 = None
202-
203-
# (10|0)(0|01) without response of RI basis
204-
rhok1_Pkl_kslice = contract('piox,ko->pikx', rhok1_Pko_kslice, mocc_2)
205-
hk_ao_ao += contract('pikx,pkiy->ikxy', rhok1_Pkl_kslice, rhok1_Pkl_kslice)
206-
rhok1_Pkl_kslice = None
207-
rhok1_Pko_kslice = None
208-
213+
214+
hk_ao_ao += _hk_ip1_ip1(rhok1_Pko, dm0, mocc_2)
209215
wk1_Pko = rhok1_Pko = None
210216
t1 = log.timer_debug1('intermediate variables with int3c2e_ip1', *t1)
211217

212218
cupy.get_default_memory_pool().free_all_blocks()
213219
# int3c_ipip1 contributions
214-
hj_ao_diag, hk_ao_diag = int3c2e.get_int3c2e_ipip1_hjk(intopt, rhoj0_P, rhok0_P__, dm0_tag, omega=omega, with_k=with_k)
220+
hj_ao_diag, hk_ao_diag = int3c2e.get_int3c2e_hjk(intopt, 'ipip1', rhoj0_P, rhok0_P__, dm0_tag, omega=omega, with_k=with_k)
215221
hj_ao_diag *= 2.0
216222
t1 = log.timer_debug1('intermediate variables with int3c2e_ipip1', *t1)
217223

218224
# int3c_ipvip1 contributions
219225
# (11|0), (0|00) without response of RI basis
220-
hj, hk = int3c2e.get_int3c2e_ipvip1_hjk(intopt, rhoj0_P, rhok0_P__, dm0_tag, omega=omega, with_k=with_k)
226+
hj, hk = int3c2e.get_int3c2e_hjk(intopt, 'ipvip1', rhoj0_P, rhok0_P__, dm0_tag, omega=omega, with_k=with_k)
221227
hj_ao_ao += 2.0*hj
222228
if with_k:
223229
hk_ao_ao += hk
@@ -227,7 +233,7 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
227233
# int3c_ip1ip2 contributions
228234
# (10|1), (0|0)(0|00)
229235
if hessobj.auxbasis_response:
230-
hj, hk = int3c2e.get_int3c2e_ip1ip2_hjk(intopt, rhoj0_P, rhok0_P__, dm0_tag, omega=omega, with_k=with_k)
236+
hj, hk = int3c2e.get_int3c2e_hjk(intopt, 'ip1ip2', rhoj0_P, rhok0_P__, dm0_tag, omega=omega, with_k=with_k)
231237
hj_ao_aux += hj
232238
if with_k:
233239
hk_ao_aux += hk
@@ -237,7 +243,7 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
237243
# int3c_ipip2 contributions
238244
if hessobj.auxbasis_response > 1:
239245
# (00|2), (0|0)(0|00)
240-
hj, hk = int3c2e.get_int3c2e_ipip2_hjk(intopt, rhoj0_P, rhok0_P__, dm0_tag, omega=omega, with_k=with_k)
246+
hj, hk = int3c2e.get_int3c2e_hjk(intopt, 'ipip2', rhoj0_P, rhok0_P__, dm0_tag, omega=omega, with_k=with_k)
241247
hj_aux_diag = hj
242248
if with_k:
243249
hk_aux_diag = .5*hk

gpu4pyscf/df/hessian/uhf.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -221,19 +221,19 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
221221

222222
cupy.get_default_memory_pool().free_all_blocks()
223223
# int3c_ipip1 contributions
224-
fn = int3c2e.get_int3c2e_ipip1_hjk
225-
hja_ao_diag, hka_ao_diag = fn(intopt, rhoj0_P, rhok0a_P__, dm0a_tag, omega=omega, with_k=with_k)
226-
hjb_ao_diag, hkb_ao_diag = fn(intopt, rhoj0_P, rhok0b_P__, dm0b_tag, omega=omega, with_k=with_k)
224+
fn = int3c2e.get_int3c2e_hjk
225+
hja_ao_diag, hka_ao_diag = fn(intopt, 'ipip1', rhoj0_P, rhok0a_P__, dm0a_tag, omega=omega, with_k=with_k)
226+
hjb_ao_diag, hkb_ao_diag = fn(intopt, 'ipip1', rhoj0_P, rhok0b_P__, dm0b_tag, omega=omega, with_k=with_k)
227227
hj_ao_diag = 2.0 * (hja_ao_diag + hjb_ao_diag)
228228
if with_k:
229229
hk_ao_diag = 2.0 * (hka_ao_diag + hkb_ao_diag)
230230
t1 = log.timer_debug1('intermediate variables with int3c2e_ipip1', *t1)
231231

232232
# int3c_ipvip1 contributions
233233
# (11|0), (0|00) without response of RI basis
234-
fn = int3c2e.get_int3c2e_ipvip1_hjk
235-
hja, hka = fn(intopt, rhoj0_P, rhok0a_P__, dm0a_tag, omega=omega, with_k=with_k)
236-
hjb, hkb = fn(intopt, rhoj0_P, rhok0b_P__, dm0b_tag, omega=omega, with_k=with_k)
234+
fn = int3c2e.get_int3c2e_hjk
235+
hja, hka = fn(intopt, 'ipvip1', rhoj0_P, rhok0a_P__, dm0a_tag, omega=omega, with_k=with_k)
236+
hjb, hkb = fn(intopt, 'ipvip1', rhoj0_P, rhok0b_P__, dm0b_tag, omega=omega, with_k=with_k)
237237
hj_ao_ao += 2.0*(hja + hjb)
238238
if with_k:
239239
hk_ao_ao += (hka + hkb)
@@ -243,9 +243,9 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
243243
# int3c_ip1ip2 contributions
244244
# (10|1), (0|0)(0|00)
245245
if hessobj.auxbasis_response:
246-
fn = int3c2e.get_int3c2e_ip1ip2_hjk
247-
hja, hka = fn(intopt, rhoj0_P, rhok0a_P__, dm0a_tag, omega=omega, with_k=with_k)
248-
hjb, hkb = fn(intopt, rhoj0_P, rhok0b_P__, dm0b_tag, omega=omega, with_k=with_k)
246+
fn = int3c2e.get_int3c2e_hjk
247+
hja, hka = fn(intopt, 'ip1ip2', rhoj0_P, rhok0a_P__, dm0a_tag, omega=omega, with_k=with_k)
248+
hjb, hkb = fn(intopt, 'ip1ip2', rhoj0_P, rhok0b_P__, dm0b_tag, omega=omega, with_k=with_k)
249249
hj_ao_aux += hja + hjb
250250
if with_k:
251251
hk_ao_aux += hka + hkb
@@ -255,9 +255,9 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
255255
# int3c_ipip2 contributions
256256
if hessobj.auxbasis_response > 1:
257257
# (00|2), (0|0)(0|00)
258-
fn = int3c2e.get_int3c2e_ipip2_hjk
259-
hja, hka = fn(intopt, rhoj0_P, rhok0a_P__, dm0a_tag, omega=omega, with_k=with_k)
260-
hjb, hkb = fn(intopt, rhoj0_P, rhok0b_P__, dm0b_tag, omega=omega, with_k=with_k)
258+
fn = int3c2e.get_int3c2e_hjk
259+
hja, hka = fn(intopt, 'ipip2', rhoj0_P, rhok0a_P__, dm0a_tag, omega=omega, with_k=with_k)
260+
hjb, hkb = fn(intopt, 'ipip2', rhoj0_P, rhok0b_P__, dm0b_tag, omega=omega, with_k=with_k)
261261
hj_aux_diag = hja + hjb
262262
if with_k:
263263
hk_aux_diag = (hka + hkb)

0 commit comments

Comments
 (0)