Skip to content

Commit e5eeeb3

Browse files
committed
wip transpose + contract
1 parent cec4ca5 commit e5eeeb3

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

forte2/jkbuilder/jkbuilder.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -528,12 +528,14 @@ def __init__(
528528
system,
529529
use_aux_corr=False,
530530
memory_threshold_mb=4000,
531+
store_B_nPm=False,
531532
metric_ortho_rtol=None,
532533
):
533534
self.system = system
534535
self.use_aux_corr = use_aux_corr
535536
self.nbf = system.nbf
536537
self.memory_threshold_mb = memory_threshold_mb
538+
self.store_B_nPm = store_B_nPm
537539
self.auxbasis = (
538540
self.system.auxiliary_basis_corr
539541
if self.use_aux_corr
@@ -565,6 +567,8 @@ def _allocate_buffers(self):
565567
nbytes = 16 if _cmplx else 8
566568
# number of buffers of variable types (Pmn is always real, but Pmi/Qmi are complex for two-component systems)
567569
nbuf_vt = 3 if _cmplx else 2
570+
# the transposed buffer makes K builds much faster
571+
nbuf_vt += 1 if self.store_B_nPm else 0
568572
# total size = 8 * nb^2 p + nbytes * nbuf_vt * nb * na * i ~= (nbuf_vt * nbytes + 8) * nb * na * i
569573
total_bytes_per_iblk = (nbuf_vt * nbytes + 8) * self.nbf * self.naux
570574
self.iblksize = min(
@@ -603,6 +607,11 @@ def _allocate_buffers(self):
603607
(self.naux, self.nbf, self.iblksize),
604608
dtype=np.complex128 if _cmplx else float,
605609
)
610+
if self.store_B_nPm:
611+
self._mPi_buf = np.zeros(
612+
(self.nbf, self.naux, self.iblksize),
613+
dtype=np.complex128 if _cmplx else float,
614+
)
606615
alloc_size_mb_Q = self.naux * self.nbf * self.iblksize * nbytes / 1024**2
607616
logger.log_info1(
608617
f"[FockBuilderOTF]: Allocated buffers for X_Qm[i] and X_Pm[i] with shape {self._Qmi_buf.shape} and size {alloc_size_mb_Q*nbuf_vt:.2f} MB"
@@ -727,12 +736,27 @@ def _K_kernel(self, C):
727736
optimize=True,
728737
out=self._Pmi_buf[:, :, : i1 - i0],
729738
)
730-
K += np.einsum(
731-
"Pmi,Pni->mn",
732-
self._Pmi_buf[:, :, : i1 - i0],
733-
self._Pmi_buf[:, :, : i1 - i0].conj(),
734-
optimize=True,
735-
)
739+
if self.store_B_nPm:
740+
# store the transposed Pmi buffer for faster K builds
741+
np.einsum(
742+
"Pmi->mPi",
743+
self._Pmi_buf[:, :, : i1 - i0],
744+
optimize=True,
745+
out=self._mPi_buf[:, :, : i1 - i0],
746+
)
747+
K += np.einsum(
748+
"mPi,nPi->mn",
749+
self._mPi_buf[:, :, : i1 - i0].conj(),
750+
self._mPi_buf[:, :, : i1 - i0],
751+
optimize=True,
752+
)
753+
else:
754+
K += np.einsum(
755+
"Pmi,Pni->mn",
756+
self._Pmi_buf[:, :, : i1 - i0],
757+
self._Pmi_buf[:, :, : i1 - i0].conj(),
758+
optimize=True,
759+
)
736760
i0 = i1
737761
return K
738762

0 commit comments

Comments
 (0)