@@ -21,7 +21,7 @@ class AXTest : public testing::Test
2121{
2222public:
2323 std::vector<matsize> sizes{
24- // {2, 3, 2, 1},
24+ // {2, 3, 2, 1}
2525 {2 , 13 , 7 , 4 },
2626 {2 , 14 , 8 , 5 }
2727 };
@@ -61,32 +61,32 @@ TEST_F(AXTest, DoubleSerial)
6161{
6262 for (auto s : this ->sizes )
6363 {
64- psi::Psi<double , base_device::DEVICE_CPU> X (s.nks , nstate, s.nocc * s.nvirt , nullptr , false );
65- psi::Psi<double , base_device::DEVICE_CPU> AX_for (s.nks , nstate, s.nocc * s.nvirt , nullptr , false );
66- psi::Psi<double , base_device::DEVICE_CPU> AX_blas (s.nks , nstate, s.nocc * s.nvirt , nullptr , false );
64+ psi::Psi<double , base_device::DEVICE_CPU> X (s.nks , nstate, s.nocc * s.nvirt , {} , false );
65+ psi::Psi<double , base_device::DEVICE_CPU> AX_for (s.nks , nstate, s.nocc * s.nvirt , {} , false );
66+ psi::Psi<double , base_device::DEVICE_CPU> AX_blas (s.nks , nstate, s.nocc * s.nvirt , {} , false );
6767 const int size_x = nstate * s.nks * s.nocc * s.nvirt ;
6868 set_rand (X.get_pointer (), size_x);
6969
7070 const int size_c = s.nks * (s.nocc + s.nvirt ) * s.naos ;
7171 const int size_v = s.naos * s.naos ;
7272 for (int istate = 0 ;istate < nstate;++istate)
7373 {
74- psi::Psi<double , base_device::DEVICE_CPU> c (s.nks , s.nocc + s.nvirt , s.naos );
74+ psi::Psi<double , base_device::DEVICE_CPU> c (s.nks , s.nocc + s.nvirt , s.naos , {}, true );
7575 std::vector<container::Tensor> V (s.nks , container::Tensor (DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos , s.naos }));
7676 set_rand (c.get_pointer (), size_c);
7777 for (auto & v : V)set_rand (v.data <double >(), size_v);
7878 X.fix_b (istate);
7979 AX_for.fix_b (istate);
8080 AX_blas.fix_b (istate);
8181 // occ
82- CVCX_occ_forloop_serial (V, c, X, s.naos , s.nocc , s.nvirt , AX_for);
83- CVCX_occ_blas (V, c, X, s.naos , s.nocc , s.nvirt , AX_blas, false );
82+ LR:: CVCX_occ_forloop_serial (V, c, X. get_pointer () , s.naos , s.nocc , s.nvirt , AX_for. get_pointer () );
83+ LR:: CVCX_occ_blas (V, c, X. get_pointer () , s.naos , s.nocc , s.nvirt , AX_blas. get_pointer () , false );
8484 AX_for.fix_k (0 );
8585 AX_blas.fix_k (0 );
8686 check_eq (AX_for.get_pointer (), AX_blas.get_pointer (), s.nks * s.nocc * s.nvirt );
8787 // virt
88- CVCX_virt_forloop_serial (V, c, X, s.naos , s.nocc , s.nvirt , AX_for);
89- CVCX_virt_blas (V, c, X, s.naos , s.nocc , s.nvirt , AX_blas, false );
88+ LR:: CVCX_virt_forloop_serial (V, c, X. get_pointer () , s.naos , s.nocc , s.nvirt , AX_for. get_pointer () );
89+ LR:: CVCX_virt_blas (V, c, X. get_pointer () , s.naos , s.nocc , s.nvirt , AX_blas. get_pointer () , false );
9090 AX_for.fix_k (0 );
9191 AX_blas.fix_k (0 );
9292 check_eq (AX_for.get_pointer (), AX_blas.get_pointer (), s.nks * s.nocc * s.nvirt );
@@ -98,32 +98,32 @@ TEST_F(AXTest, ComplexSerial)
9898{
9999 for (auto s : this ->sizes )
100100 {
101- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> X (s.nks , nstate, s.nocc * s.nvirt , nullptr , false );
102- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_for (s.nks , nstate, s.nocc * s.nvirt , nullptr , false );
103- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_blas (s.nks , nstate, s.nocc * s.nvirt , nullptr , false );
101+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> X (s.nks , nstate, s.nocc * s.nvirt , {} , false );
102+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_for (s.nks , nstate, s.nocc * s.nvirt , {} , false );
103+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_blas (s.nks , nstate, s.nocc * s.nvirt , {} , false );
104104 const int size_x = nstate * s.nks * s.nocc * s.nvirt ;
105105 set_rand (X.get_pointer (), size_x);
106106
107107 int size_c = s.nks * (s.nocc + s.nvirt ) * s.naos ;
108108 int size_v = s.naos * s.naos ;
109109 for (int istate = 0 ;istate < nstate;++istate)
110110 {
111- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> c (s.nks , s.nocc + s.nvirt , s.naos );
111+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> c (s.nks , s.nocc + s.nvirt , s.naos , {}, true );
112112 std::vector<container::Tensor> V (s.nks , container::Tensor (DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos , s.naos }));
113113 set_rand (c.get_pointer (), size_c);
114114 for (auto & v : V)set_rand (v.data <std::complex <double >>(), size_v);
115115 X.fix_b (istate);
116116 AX_for.fix_b (istate);
117117 AX_blas.fix_b (istate);
118118 // occ
119- CVCX_occ_forloop_serial (V, c, X, s.naos , s.nocc , s.nvirt , AX_for);
120- CVCX_occ_blas (V, c, X, s.naos , s.nocc , s.nvirt , AX_blas, false );
119+ LR:: CVCX_occ_forloop_serial (V, c, X. get_pointer () , s.naos , s.nocc , s.nvirt , AX_for. get_pointer () );
120+ LR:: CVCX_occ_blas (V, c, X. get_pointer () , s.naos , s.nocc , s.nvirt , AX_blas. get_pointer () , false );
121121 AX_for.fix_k (0 );
122122 AX_blas.fix_k (0 );
123123 check_eq (AX_for.get_pointer (), AX_blas.get_pointer (), s.nks * s.nocc * s.nvirt );
124124 // virt
125- CVCX_virt_forloop_serial (V, c, X, s.naos , s.nocc , s.nvirt , AX_for);
126- CVCX_virt_blas (V, c, X, s.naos , s.nocc , s.nvirt , AX_blas, false );
125+ LR:: CVCX_virt_forloop_serial (V, c, X. get_pointer () , s.naos , s.nocc , s.nvirt , AX_for. get_pointer () );
126+ LR:: CVCX_virt_blas (V, c, X. get_pointer () , s.naos , s.nocc , s.nvirt , AX_blas. get_pointer () , false );
127127 AX_for.fix_k (0 );
128128 AX_blas.fix_k (0 );
129129 check_eq (AX_for.get_pointer (), AX_blas.get_pointer (), s.nks * s.nocc * s.nvirt );
@@ -142,7 +142,7 @@ TEST_F(AXTest, DoubleParallel)
142142 std::vector<container::Tensor> V (s.nks , container::Tensor (DAT::DT_DOUBLE, DEV::CpuDevice, { pV.get_col_size (), pV.get_row_size () }));
143143 Parallel_2D pc;
144144 LR_Util::setup_2d_division (pc, s.nb , s.naos , s.nocc + s.nvirt , pV.blacs_ctxt );
145- psi::Psi<double , base_device::DEVICE_CPU> c (s.nks , pc.get_col_size (), pc.get_row_size ());
145+ psi::Psi<double , base_device::DEVICE_CPU> c (s.nks , pc.get_col_size (), pc.get_row_size (), {}, true );
146146 Parallel_2D px;
147147 LR_Util::setup_2d_division (px, s.nb , s.nvirt , s.nocc , pV.blacs_ctxt );
148148
@@ -152,13 +152,13 @@ TEST_F(AXTest, DoubleParallel)
152152 EXPECT_GE (s.nocc , px.dim1 );
153153 EXPECT_GE (s.naos , pc.dim0 );
154154
155- psi::Psi<double , base_device::DEVICE_CPU> AX_pblas_loc (s.nks , nstate, px.get_local_size ());
156- psi::Psi<double , base_device::DEVICE_CPU> AX_gather (s.nks , nstate, s.nocc * s.nvirt , nullptr , false );
155+ psi::Psi<double , base_device::DEVICE_CPU> AX_pblas_loc (s.nks , nstate, px.get_local_size (), {}, false );
156+ psi::Psi<double , base_device::DEVICE_CPU> AX_gather (s.nks , nstate, s.nocc * s.nvirt , {} , false );
157157
158158 // set X and X_full
159- psi::Psi<double , base_device::DEVICE_CPU> X (s.nks , nstate, px.get_local_size (), nullptr , false );
159+ psi::Psi<double , base_device::DEVICE_CPU> X (s.nks , nstate, px.get_local_size (), {} , false );
160160 set_rand (X.get_pointer (), nstate * s.nks * px.get_local_size ());
161- psi::Psi<double , base_device::DEVICE_CPU> X_full (s.nks , nstate, s.nocc * s.nvirt , nullptr , false ); // allocate X_full
161+ psi::Psi<double , base_device::DEVICE_CPU> X_full (s.nks , nstate, s.nocc * s.nvirt , {} , false ); // allocate X_full
162162 for (int istate = 0 ;istate < nstate;++istate)
163163 {
164164 X.fix_b (istate);
@@ -183,7 +183,7 @@ TEST_F(AXTest, DoubleParallel)
183183 X_full.fix_b (istate);
184184 AX_pblas_loc.fix_b (istate);
185185 AX_gather.fix_b (istate);
186- CVCX_occ_pblas (V, pV, c, pc, X, px, s.naos , s.nocc , s.nvirt , AX_pblas_loc, false );
186+ LR:: CVCX_occ_pblas (V, pV, c, pc, X. get_pointer () , px, s.naos , s.nocc , s.nvirt , AX_pblas_loc. get_pointer () , false );
187187 // gather AX and output
188188 for (int isk = 0 ;isk < s.nks ;++isk)
189189 {
@@ -193,7 +193,7 @@ TEST_F(AXTest, DoubleParallel)
193193 }
194194 // compare to global AX
195195 std::vector<container::Tensor> V_full (s.nks , container::Tensor (DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos , s.naos }));
196- psi::Psi<double , base_device::DEVICE_CPU> c_full (s.nks , s.nocc + s.nvirt , s.naos );
196+ psi::Psi<double , base_device::DEVICE_CPU> c_full (s.nks , s.nocc + s.nvirt , s.naos , {}, true );
197197 for (int isk = 0 ;isk < s.nks ;++isk)
198198 {
199199 LR_Util::gather_2d_to_full (pV, V.at (isk).data <double >(), V_full.at (isk).data <double >());
@@ -203,15 +203,19 @@ TEST_F(AXTest, DoubleParallel)
203203 }
204204 if (my_rank == 0 )
205205 {
206- psi::Psi<double , base_device::DEVICE_CPU> AX_full_istate (s.nks , 1 , s.nocc * s.nvirt , nullptr , false );
207- CVCX_occ_blas (V_full, c_full, X_full, s.naos , s.nocc , s.nvirt , AX_full_istate, false );
206+ psi::Psi<double , base_device::DEVICE_CPU> AX_full_istate (s.nks , 1 , s.nocc * s.nvirt , {}, true );
207+ LR:: CVCX_occ_blas (V_full, c_full, X_full. get_pointer () , s.naos , s.nocc , s.nvirt , AX_full_istate. get_pointer () , false );
208208 AX_full_istate.fix_b (0 );
209209 AX_gather.fix_b (istate);
210210 check_eq (AX_full_istate.get_pointer (), AX_gather.get_pointer (), s.nks * s.nocc * s.nvirt );
211211 }
212212
213213 // //============ the same for virtual ==========
214- CVCX_virt_pblas (V, pV, c, pc, X, px, s.naos , s.nocc , s.nvirt , AX_pblas_loc, false );
214+ X.fix_b (istate);
215+ X_full.fix_b (istate);
216+ AX_pblas_loc.fix_b (istate);
217+ AX_gather.fix_b (istate);
218+ LR::CVCX_virt_pblas (V, pV, c, pc, X.get_pointer (), px, s.naos , s.nocc , s.nvirt , AX_pblas_loc.get_pointer (), false );
215219 for (int isk = 0 ;isk < s.nks ;++isk)
216220 {
217221 AX_pblas_loc.fix_k (isk);
@@ -220,8 +224,8 @@ TEST_F(AXTest, DoubleParallel)
220224 }
221225 if (my_rank == 0 )
222226 {
223- psi::Psi<double , base_device::DEVICE_CPU> AX_full_istate (s.nks , 1 , s.nocc * s.nvirt , nullptr , false );
224- CVCX_virt_blas (V_full, c_full, X_full, s.naos , s.nocc , s.nvirt , AX_full_istate, false );
227+ psi::Psi<double , base_device::DEVICE_CPU> AX_full_istate (s.nks , 1 , s.nocc * s.nvirt , {}, true );
228+ LR:: CVCX_virt_blas (V_full, c_full, X_full. get_pointer () , s.naos , s.nocc , s.nvirt , AX_full_istate. get_pointer () , false );
225229 AX_full_istate.fix_b (0 );
226230 AX_gather.fix_b (istate);
227231 check_eq (AX_full_istate.get_pointer (), AX_gather.get_pointer (), s.nks * s.nocc * s.nvirt );
@@ -240,17 +244,17 @@ TEST_F(AXTest, ComplexParallel)
240244 std::vector<container::Tensor> V (s.nks , container::Tensor (DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pV.get_col_size (), pV.get_row_size () }));
241245 Parallel_2D pc;
242246 LR_Util::setup_2d_division (pc, s.nb , s.naos , s.nocc + s.nvirt , pV.blacs_ctxt );
243- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> c (s.nks , pc.get_col_size (), pc.get_row_size ());
247+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> c (s.nks , pc.get_col_size (), pc.get_row_size (), {}, true );
244248 Parallel_2D px;
245249 LR_Util::setup_2d_division (px, s.nb , s.nvirt , s.nocc , pV.blacs_ctxt );
246250
247- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_pblas_loc (s.nks , nstate, px.get_local_size ());
248- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_gather (s.nks , nstate, s.nocc * s.nvirt , nullptr , false );
251+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_pblas_loc (s.nks , nstate, px.get_local_size (), {}, false );
252+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_gather (s.nks , nstate, s.nocc * s.nvirt , {} , false );
249253
250254 // set X and X_full
251- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> X (s.nks , nstate, px.get_local_size (), nullptr , false );
255+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> X (s.nks , nstate, px.get_local_size (), {} , false );
252256 set_rand (X.get_pointer (), nstate * s.nks * px.get_local_size ());
253- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> X_full (s.nks , nstate, s.nocc * s.nvirt , nullptr , false ); // allocate X_full
257+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> X_full (s.nks , nstate, s.nocc * s.nvirt , {} , false ); // allocate X_full
254258 for (int istate = 0 ;istate < nstate;++istate)
255259 {
256260 X.fix_b (istate);
@@ -275,7 +279,7 @@ TEST_F(AXTest, ComplexParallel)
275279 X_full.fix_b (istate);
276280 AX_pblas_loc.fix_b (istate);
277281 AX_gather.fix_b (istate);
278- CVCX_occ_pblas (V, pV, c, pc, X, px, s.naos , s.nocc , s.nvirt , AX_pblas_loc, false );
282+ LR:: CVCX_occ_pblas (V, pV, c, pc, X. get_pointer () , px, s.naos , s.nocc , s.nvirt , AX_pblas_loc. get_pointer () , false );
279283
280284 // gather AX and output
281285 for (int isk = 0 ;isk < s.nks ;++isk)
@@ -286,7 +290,7 @@ TEST_F(AXTest, ComplexParallel)
286290 }
287291 // compare to global AX
288292 std::vector<container::Tensor> V_full (s.nks , container::Tensor (DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos , s.naos }));
289- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> c_full (s.nks , s.nocc + s.nvirt , s.naos );
293+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> c_full (s.nks , s.nocc + s.nvirt , s.naos , {}, true );
290294 for (int isk = 0 ;isk < s.nks ;++isk)
291295 {
292296 LR_Util::gather_2d_to_full (pV, V.at (isk).data <std::complex <double >>(), V_full.at (isk).data <std::complex <double >>());
@@ -296,14 +300,18 @@ TEST_F(AXTest, ComplexParallel)
296300 }
297301 if (my_rank == 0 )
298302 {
299- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_full_istate (s.nks , 1 , s.nocc * s.nvirt , nullptr , false );
300- CVCX_occ_blas (V_full, c_full, X_full, s.naos , s.nocc , s.nvirt , AX_full_istate, false );
303+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_full_istate (s.nks , 1 , s.nocc * s.nvirt , {} , false );
304+ LR:: CVCX_occ_blas (V_full, c_full, X_full. get_pointer () , s.naos , s.nocc , s.nvirt , AX_full_istate. get_pointer () , false );
301305 AX_full_istate.fix_b (0 );
302306 AX_gather.fix_b (istate);
303307 check_eq (AX_full_istate.get_pointer (), AX_gather.get_pointer (), s.nks * s.nocc * s.nvirt );
304308 }
305309 // //============ the same for virtual ==========
306- CVCX_virt_pblas (V, pV, c, pc, X, px, s.naos , s.nocc , s.nvirt , AX_pblas_loc, false );
310+ X.fix_b (istate);
311+ X_full.fix_b (istate);
312+ AX_pblas_loc.fix_b (istate);
313+ AX_gather.fix_b (istate);
314+ LR::CVCX_virt_pblas (V, pV, c, pc, X.get_pointer (), px, s.naos , s.nocc , s.nvirt , AX_pblas_loc.get_pointer (), false );
307315 for (int isk = 0 ;isk < s.nks ;++isk)
308316 {
309317 AX_pblas_loc.fix_k (isk);
@@ -312,8 +320,8 @@ TEST_F(AXTest, ComplexParallel)
312320 }
313321 if (my_rank == 0 )
314322 {
315- psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_full_istate (s.nks , 1 , s.nocc * s.nvirt , nullptr , false );
316- CVCX_virt_blas (V_full, c_full, X_full, s.naos , s.nocc , s.nvirt , AX_full_istate, false );
323+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU> AX_full_istate (s.nks , 1 , s.nocc * s.nvirt , {} , false );
324+ LR:: CVCX_virt_blas (V_full, c_full, X_full. get_pointer () , s.naos , s.nocc , s.nvirt , AX_full_istate. get_pointer () , false );
317325 AX_full_istate.fix_b (0 );
318326 AX_gather.fix_b (istate);
319327 check_eq (AX_full_istate.get_pointer (), AX_gather.get_pointer (), s.nks * s.nocc * s.nvirt );
0 commit comments