@@ -528,12 +528,14 @@ def __init__(
528528 system ,
529529 use_aux_corr = False ,
530530 memory_threshold_mb = 4000 ,
531+ store_B_nPm = False ,
531532 metric_ortho_rtol = None ,
532533 ):
533534 self .system = system
534535 self .use_aux_corr = use_aux_corr
535536 self .nbf = system .nbf
536537 self .memory_threshold_mb = memory_threshold_mb
538+ self .store_B_nPm = store_B_nPm
537539 self .auxbasis = (
538540 self .system .auxiliary_basis_corr
539541 if self .use_aux_corr
@@ -565,6 +567,8 @@ def _allocate_buffers(self):
565567 nbytes = 16 if _cmplx else 8
566568 # number of buffers of variable types (Pmn is always real, but Pmi/Qmi are complex for two-component systems)
567569 nbuf_vt = 3 if _cmplx else 2
570+ # the transposed buffer makes K builds much faster, at the cost of one extra buffer
571+ nbuf_vt += 1 if self .store_B_nPm else 0
568572 # total size = 8 * nb^2 p + nbytes * nbuf_vt * nb * na * i ~= (nbuf_vt * nbytes + 8) * nb * na * i
569573 total_bytes_per_iblk = (nbuf_vt * nbytes + 8 ) * self .nbf * self .naux
570574 self .iblksize = min (
@@ -603,6 +607,11 @@ def _allocate_buffers(self):
603607 (self .naux , self .nbf , self .iblksize ),
604608 dtype = np .complex128 if _cmplx else float ,
605609 )
610+ if self .store_B_nPm :
611+ self ._mPi_buf = np .zeros (
612+ (self .nbf , self .naux , self .iblksize ),
613+ dtype = np .complex128 if _cmplx else float ,
614+ )
606615 alloc_size_mb_Q = self .naux * self .nbf * self .iblksize * nbytes / 1024 ** 2
607616 logger .log_info1 (
608617 f"[FockBuilderOTF]: Allocated buffers for X_Qm[i] and X_Pm[i] with shape { self ._Qmi_buf .shape } and size { alloc_size_mb_Q * nbuf_vt :.2f} MB"
@@ -727,12 +736,27 @@ def _K_kernel(self, C):
727736 optimize = True ,
728737 out = self ._Pmi_buf [:, :, : i1 - i0 ],
729738 )
730- K += np .einsum (
731- "Pmi,Pni->mn" ,
732- self ._Pmi_buf [:, :, : i1 - i0 ],
733- self ._Pmi_buf [:, :, : i1 - i0 ].conj (),
734- optimize = True ,
735- )
739+ if self .store_B_nPm :
740+ # store the transposed Pmi buffer for faster K builds
741+ np .einsum (
742+ "Pmi->mPi" ,
743+ self ._Pmi_buf [:, :, : i1 - i0 ],
744+ optimize = True ,
745+ out = self ._mPi_buf [:, :, : i1 - i0 ],
746+ )
747+ K += np .einsum (
748+ "mPi,nPi->mn" ,
749+ self ._mPi_buf [:, :, : i1 - i0 ].conj (),
750+ self ._mPi_buf [:, :, : i1 - i0 ],
751+ optimize = True ,
752+ )
753+ else :
754+ K += np .einsum (
755+ "Pmi,Pni->mn" ,
756+ self ._Pmi_buf [:, :, : i1 - i0 ],
757+ self ._Pmi_buf [:, :, : i1 - i0 ].conj (),
758+ optimize = True ,
759+ )
736760 i0 = i1
737761 return K
738762
0 commit comments