Skip to content

Commit ff391b4

Browse files
committed
ri_ccsdt_naive: code
1 parent eb42b95 commit ff391b4

File tree

3 files changed

+108
-9
lines changed

3 files changed

+108
-9
lines changed

src/lib.rs

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod diis;
66
pub mod rhf_slow;
77
pub mod ri_rccsd;
88
pub mod ri_rccsdt;
9+
pub mod ri_rccsdt_naive;
910
pub mod ri_rccsdt_slow;
1011
pub mod ri_rhf;
1112
pub mod ri_rhf_slow;
@@ -61,6 +62,41 @@ fn playground_ri_ccsd() {
6162
fn playground_ri_ccsdt() {
6263
use crate::prelude::*;
6364

65+
let cint_data = CInt::from_json("assets/h2o-tzvp.json");
66+
let aux_cint_data = CInt::from_json("assets/h2o-def2_jk.json");
67+
let rhf_results = ri_rhf::minimal_ri_rhf(&cint_data, &aux_cint_data);
68+
69+
let ccsd_info = RCCSDInfo {
70+
cint_data,
71+
aux_cint_data,
72+
mo_coeff: rhf_results.mo_coeff.clone(),
73+
mo_energy: rhf_results.mo_energy.clone(),
74+
};
75+
let ccsd_config = CCSDConfig::default();
76+
77+
let (ccsd_results, ccsd_intrm) = ri_rccsd::riccsd_iteration(&ccsd_info, &ccsd_config);
78+
println!("CCSD Corr Energy: {}", ccsd_results.e_corr);
79+
80+
println!("======");
81+
82+
let ccsdt_results = ri_rccsdt_naive::get_riccsd_pt_energy(&ccsd_info, &ccsd_intrm, &ccsd_results);
83+
println!("CCSD(T) Perturb Energy (naive algorithm): {}", ccsdt_results.e_corr_pt);
84+
85+
println!("======");
86+
87+
let ccsdt_results = ri_rccsdt_slow::get_riccsd_pt_energy(&ccsd_info, &ccsd_intrm, &ccsd_results);
88+
println!("CCSD(T) Perturb Energy (slow algorithm): {}", ccsdt_results.e_corr_pt);
89+
90+
println!("======");
91+
92+
let ccsdt_results = ri_rccsdt::get_riccsd_pt_energy(&ccsd_info, &ccsd_intrm, &ccsd_results);
93+
println!("CCSD(T) Perturb Energy (fast algorithm): {}", ccsdt_results.e_corr_pt);
94+
}
95+
96+
#[test]
97+
fn playground_ri_ccsdt_efficiency() {
98+
use crate::prelude::*;
99+
64100
let cint_data = CInt::from_json("assets/h2o_5-pvdz.json");
65101
let aux_cint_data = CInt::from_json("assets/h2o_5-pvdz_ri.json");
66102
let rhf_results = ri_rhf::minimal_ri_rhf(&cint_data, &aux_cint_data);
@@ -71,21 +107,18 @@ fn playground_ri_ccsdt() {
71107
mo_coeff: rhf_results.mo_coeff.clone(),
72108
mo_energy: rhf_results.mo_energy.clone(),
73109
};
110+
let ccsd_config = CCSDConfig::default();
74111

75-
// do not run actual CCSD, but use faked T1 amplitudes
76-
let mut ccsd_intrm = RCCSDIntermediates::default();
77-
ri_rccsd::get_riccsd_intermediates_cderi(&ccsd_info, &mut ccsd_intrm);
78-
let mut ccsd_results = ri_rccsd::get_riccsd_initial_guess(&ccsd_info, &ccsd_intrm);
79-
ccsd_results.t1 =
80-
rt::arange((ccsd_results.t1.size() as f64, ccsd_results.t1.device())).into_shape(ccsd_results.t1.shape());
112+
let (ccsd_results, ccsd_intrm) = ri_rccsd::riccsd_iteration(&ccsd_info, &ccsd_config);
113+
println!("CCSD Corr Energy: {}", ccsd_results.e_corr);
81114

82115
println!("======");
83116

84117
let ccsdt_results = ri_rccsdt_slow::get_riccsd_pt_energy(&ccsd_info, &ccsd_intrm, &ccsd_results);
85-
println!("CCSD(T) Perturb Energy: {}", ccsdt_results.e_corr_pt);
118+
println!("CCSD(T) Perturb Energy (slow algorithm): {}", ccsdt_results.e_corr_pt);
86119

87120
println!("======");
88121

89122
let ccsdt_results = ri_rccsdt::get_riccsd_pt_energy(&ccsd_info, &ccsd_intrm, &ccsd_results);
90-
println!("CCSD(T) Perturb Energy: {}", ccsdt_results.e_corr_pt);
123+
println!("CCSD(T) Perturb Energy (fast algorithm): {}", ccsdt_results.e_corr_pt);
91124
}

src/ri_rccsdt_naive.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use crate::prelude::*;
2+
3+
pub fn get_riccsd_pt_energy(
4+
mol_info: &RCCSDInfo,
5+
ccsd_intermediates: &RCCSDIntermediates,
6+
ccsd_results: &RCCSDResults,
7+
) -> RCCSDTResults {
8+
let time_outer = std::time::Instant::now();
9+
10+
let nocc = mol_info.nocc();
11+
let nvir = mol_info.nvir();
12+
let nmo = nocc + nvir;
13+
let mo_energy = &mol_info.mo_energy;
14+
15+
let time = std::time::Instant::now();
16+
let intermediates = ri_rccsdt_slow::prepare_intermediates(mol_info, ccsd_intermediates, ccsd_results);
17+
println!("Time elapsed (CCSD(T) preparation): {:?}", time.elapsed());
18+
19+
let t1_t = intermediates.t1_t.as_ref().unwrap();
20+
let t2_t = intermediates.t2_t.as_ref().unwrap();
21+
let eri_vvoo_t = intermediates.eri_vvoo_t.as_ref().unwrap();
22+
let eri_vooo_t = intermediates.eri_vooo_t.as_ref().unwrap();
23+
let eri_vvov_t = intermediates.eri_vvov_t.as_ref().unwrap();
24+
25+
let device = t1_t.device().clone();
26+
27+
let time = std::time::Instant::now();
28+
29+
let wp: Tsr = unsafe { rt::empty(([nvir, nvir, nvir, nocc, nocc, nocc], &device)) };
30+
(0..nvir).into_par_iter().for_each(|a| {
31+
(0..nvir).into_par_iter().for_each(|b| {
32+
(0..nvir).into_par_iter().for_each(|c| unsafe {
33+
let mut wp = wp.force_mut();
34+
let wp_1 = eri_vvov_t.i((a, b)) % t2_t.i(c).reshape((nvir, -1));
35+
let wp_2 = t2_t.i((a, b)).t() % eri_vooo_t.i(c).reshape((nocc, -1));
36+
wp.i_mut([a, b, c]).assign(wp_1.reshape((nocc, nocc, nocc)) - wp_2.reshape((nocc, nocc, nocc)));
37+
});
38+
});
39+
});
40+
let w = wp.transpose((0, 1, 2, 3, 4, 5))
41+
+ wp.transpose((0, 2, 1, 3, 5, 4))
42+
+ wp.transpose((1, 0, 2, 4, 3, 5))
43+
+ wp.transpose((1, 2, 0, 4, 5, 3))
44+
+ wp.transpose((2, 0, 1, 5, 3, 4))
45+
+ wp.transpose((2, 1, 0, 5, 4, 3));
46+
let v = &w
47+
+ t1_t.i((.., None, None, .., None, None)) * eri_vvoo_t.i((None, .., .., None, .., ..))
48+
+ t1_t.i((None, .., None, None, .., None)) * eri_vvoo_t.i((.., None, .., .., None, ..))
49+
+ t1_t.i((None, None, .., None, None, ..)) * eri_vvoo_t.i((.., .., None, .., .., None));
50+
let (so, sv) = (slice!(0, nocc), slice!(nocc, nmo));
51+
let d = -mo_energy.i((sv, None, None, None, None, None))
52+
- mo_energy.i((None, sv, None, None, None, None))
53+
- mo_energy.i((None, None, sv, None, None, None))
54+
+ mo_energy.i((None, None, None, so, None, None))
55+
+ mo_energy.i((None, None, None, None, so, None))
56+
+ mo_energy.i((None, None, None, None, None, so));
57+
let wt = 4.0_f64 * &w + w.transpose((1, 2, 0, 3, 4, 5)) + w.transpose((2, 0, 1, 3, 4, 5));
58+
let vt = &v - v.transpose((2, 1, 0, 3, 4, 5));
59+
let e_corr_pt = (wt * vt / &d).sum() / 3.0;
60+
61+
println!("Time elapsed (CCSD(T) computation): {:?}", time.elapsed());
62+
println!("Total time elapsed (CCSD(T) energy): {:?}", time_outer.elapsed());
63+
println!("Time elapsed (CCSD(T) energy): {:?}", time.elapsed());
64+
65+
RCCSDTResults { e_corr_pt }
66+
}

src/ri_rccsdt_slow.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ fn ccsd_t_energy_contribution(abc: [usize; 3], mol_info: &RCCSDInfo, intermediat
9797
+ get_w([c, b, a], intermediates).transpose([2, 1, 0]);
9898
let v = &w
9999
+ t1_t.i((a, .., None, None)) * eri_vvoo_t.i([b, c]).i((None, .., ..))
100-
+ t1_t.i((b, None, .., None)) * eri_vvoo_t.i([c, a]).t().i((.., None, ..))
100+
+ t1_t.i((b, None, .., None)) * eri_vvoo_t.i([a, c]).i((.., None, ..))
101101
+ t1_t.i((c, None, None, ..)) * eri_vvoo_t.i([a, b]).i((.., .., None));
102102
let d = -(ev[[a]] + ev[[b]] + ev[[c]]) + d_ooo;
103103
let z = 4.0 * &w + w.transpose([1, 2, 0]) + w.transpose([2, 0, 1])

0 commit comments

Comments
 (0)