|
| 1 | +use crate::prelude::*; |
| 2 | + |
| 3 | +pub fn get_riccsd_pt_energy( |
| 4 | + mol_info: &RCCSDInfo, |
| 5 | + ccsd_intermediates: &RCCSDIntermediates, |
| 6 | + ccsd_results: &RCCSDResults, |
| 7 | +) -> RCCSDTResults { |
| 8 | + let time_outer = std::time::Instant::now(); |
| 9 | + |
| 10 | + let nocc = mol_info.nocc(); |
| 11 | + let nvir = mol_info.nvir(); |
| 12 | + let nmo = nocc + nvir; |
| 13 | + let mo_energy = &mol_info.mo_energy; |
| 14 | + |
| 15 | + let time = std::time::Instant::now(); |
| 16 | + let intermediates = ri_rccsdt_slow::prepare_intermediates(mol_info, ccsd_intermediates, ccsd_results); |
| 17 | + println!("Time elapsed (CCSD(T) preparation): {:?}", time.elapsed()); |
| 18 | + |
| 19 | + let t1_t = intermediates.t1_t.as_ref().unwrap(); |
| 20 | + let t2_t = intermediates.t2_t.as_ref().unwrap(); |
| 21 | + let eri_vvoo_t = intermediates.eri_vvoo_t.as_ref().unwrap(); |
| 22 | + let eri_vooo_t = intermediates.eri_vooo_t.as_ref().unwrap(); |
| 23 | + let eri_vvov_t = intermediates.eri_vvov_t.as_ref().unwrap(); |
| 24 | + |
| 25 | + let device = t1_t.device().clone(); |
| 26 | + |
| 27 | + let time = std::time::Instant::now(); |
| 28 | + |
| 29 | + let wp: Tsr = unsafe { rt::empty(([nvir, nvir, nvir, nocc, nocc, nocc], &device)) }; |
| 30 | + (0..nvir).into_par_iter().for_each(|a| { |
| 31 | + (0..nvir).into_par_iter().for_each(|b| { |
| 32 | + (0..nvir).into_par_iter().for_each(|c| unsafe { |
| 33 | + let mut wp = wp.force_mut(); |
| 34 | + let wp_1 = eri_vvov_t.i((a, b)) % t2_t.i(c).reshape((nvir, -1)); |
| 35 | + let wp_2 = t2_t.i((a, b)).t() % eri_vooo_t.i(c).reshape((nocc, -1)); |
| 36 | + wp.i_mut([a, b, c]).assign(wp_1.reshape((nocc, nocc, nocc)) - wp_2.reshape((nocc, nocc, nocc))); |
| 37 | + }); |
| 38 | + }); |
| 39 | + }); |
| 40 | + let w = wp.transpose((0, 1, 2, 3, 4, 5)) |
| 41 | + + wp.transpose((0, 2, 1, 3, 5, 4)) |
| 42 | + + wp.transpose((1, 0, 2, 4, 3, 5)) |
| 43 | + + wp.transpose((1, 2, 0, 4, 5, 3)) |
| 44 | + + wp.transpose((2, 0, 1, 5, 3, 4)) |
| 45 | + + wp.transpose((2, 1, 0, 5, 4, 3)); |
| 46 | + let v = &w |
| 47 | + + t1_t.i((.., None, None, .., None, None)) * eri_vvoo_t.i((None, .., .., None, .., ..)) |
| 48 | + + t1_t.i((None, .., None, None, .., None)) * eri_vvoo_t.i((.., None, .., .., None, ..)) |
| 49 | + + t1_t.i((None, None, .., None, None, ..)) * eri_vvoo_t.i((.., .., None, .., .., None)); |
| 50 | + let (so, sv) = (slice!(0, nocc), slice!(nocc, nmo)); |
| 51 | + let d = -mo_energy.i((sv, None, None, None, None, None)) |
| 52 | + - mo_energy.i((None, sv, None, None, None, None)) |
| 53 | + - mo_energy.i((None, None, sv, None, None, None)) |
| 54 | + + mo_energy.i((None, None, None, so, None, None)) |
| 55 | + + mo_energy.i((None, None, None, None, so, None)) |
| 56 | + + mo_energy.i((None, None, None, None, None, so)); |
| 57 | + let wt = 4.0_f64 * &w + w.transpose((1, 2, 0, 3, 4, 5)) + w.transpose((2, 0, 1, 3, 4, 5)); |
| 58 | + let vt = &v - v.transpose((2, 1, 0, 3, 4, 5)); |
| 59 | + let e_corr_pt = (wt * vt / &d).sum() / 3.0; |
| 60 | + |
| 61 | + println!("Time elapsed (CCSD(T) computation): {:?}", time.elapsed()); |
| 62 | + println!("Total time elapsed (CCSD(T) energy): {:?}", time_outer.elapsed()); |
| 63 | + println!("Time elapsed (CCSD(T) energy): {:?}", time.elapsed()); |
| 64 | + |
| 65 | + RCCSDTResults { e_corr_pt } |
| 66 | +} |
0 commit comments