@@ -70,7 +70,8 @@ TEST_F(AXTest, DoubleSerial)
7070 int size_v = s.naos * s.naos ;
7171 for (int istate = 0 ;istate < nstate;++istate)
7272 {
73- psi::Psi<double > c (s.nks , s.nocc + s.nvirt , s.naos );
73+ std::vector<int > temp (s.nks , s.naos );
74+ psi::Psi<double > c (s.nks , s.nocc + s.nvirt , s.naos , temp.data (), true );
7475 std::vector<container::Tensor> V (s.nks , container::Tensor (DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos , s.naos }));
7576 set_rand (&c (0 , 0 , 0 ), size_c);
7677 for (auto & v : V) { set_rand (v.data <double >(), size_v); }
@@ -91,7 +92,8 @@ TEST_F(AXTest, ComplexSerial)
9192 int size_v = s.naos * s.naos ;
9293 for (int istate = 0 ;istate < nstate;++istate)
9394 {
94- psi::Psi<std::complex <double >> c (s.nks , s.nocc + s.nvirt , s.naos );
95+ std::vector<int > temp (s.nks , s.naos );
96+ psi::Psi<std::complex <double >> c (s.nks , s.nocc + s.nvirt , s.naos , temp.data (), true );
9597 std::vector<container::Tensor> V (s.nks , container::Tensor (DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos , s.naos }));
9698 set_rand (&c (0 , 0 , 0 ), size_c);
9799 for (auto & v : V) { set_rand (v.data <std::complex <double >>(), size_v); }
@@ -113,7 +115,9 @@ TEST_F(AXTest, DoubleParallel)
113115 std::vector<container::Tensor> V (s.nks , container::Tensor (DAT::DT_DOUBLE, DEV::CpuDevice, { pV.get_col_size (), pV.get_row_size () }));
114116 Parallel_2D pc;
115117 LR_Util::setup_2d_division (pc, s.nb , s.naos , s.nocc + s.nvirt , pV.blacs_ctxt );
116- psi::Psi<double > c (s.nks , pc.get_col_size (), pc.get_row_size ());
118+
119+ std::vector<int > ngk_temp (s.nks , pc.get_row_size ());
120+ psi::Psi<double > c (s.nks , pc.get_col_size (), pc.get_row_size (), ngk_temp.data (), true );
117121 Parallel_2D px;
118122 LR_Util::setup_2d_division (px, s.nb , s.nvirt , s.nocc , pV.blacs_ctxt );
119123
@@ -139,7 +143,9 @@ TEST_F(AXTest, DoubleParallel)
139143 }
140144 // compare to global AX
141145 std::vector<container::Tensor> V_full (s.nks , container::Tensor (DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos , s.naos }));
142- psi::Psi<double > c_full (s.nks , s.nocc + s.nvirt , s.naos );
146+
147+ std::vector<int > ngk_temp_1 (s.nks , s.naos );
148+ psi::Psi<double > c_full (s.nks , s.nocc + s.nvirt , s.naos , ngk_temp_1.data (), true );
143149 for (int isk = 0 ;isk < s.nks ;++isk)
144150 {
145151 LR_Util::gather_2d_to_full (pV, V.at (isk).data <double >(), V_full.at (isk).data <double >(), false , s.naos , s.naos );
@@ -165,7 +171,9 @@ TEST_F(AXTest, ComplexParallel)
165171 std::vector<container::Tensor> V (s.nks , container::Tensor (DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pV.get_col_size (), pV.get_row_size () }));
166172 Parallel_2D pc;
167173 LR_Util::setup_2d_division (pc, s.nb , s.naos , s.nocc + s.nvirt , pV.blacs_ctxt );
168- psi::Psi<std::complex <double >> c (s.nks , pc.get_col_size (), pc.get_row_size ());
174+
175+ std::vector<int > ngk_temp_1 (s.nks , pc.get_row_size ());
176+ psi::Psi<std::complex <double >> c (s.nks , pc.get_col_size (), pc.get_row_size (), ngk_temp_1.data (), true );
169177 Parallel_2D px;
170178 LR_Util::setup_2d_division (px, s.nb , s.nvirt , s.nocc , pV.blacs_ctxt );
171179
@@ -187,7 +195,10 @@ TEST_F(AXTest, ComplexParallel)
187195 }
188196 // compare to global AX
189197 std::vector<container::Tensor> V_full (s.nks , container::Tensor (DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos , s.naos }));
190- psi::Psi<std::complex <double >> c_full (s.nks , s.nocc + s.nvirt , s.naos );
198+
199+
200+ std::vector<int > ngk_temp_2 (s.nks , s.naos );
201+ psi::Psi<std::complex <double >> c_full (s.nks , s.nocc + s.nvirt , s.naos , ngk_temp_2.data (), true );
191202 for (int isk = 0 ;isk < s.nks ;++isk)
192203 {
193204 LR_Util::gather_2d_to_full (pV, V.at (isk).data <std::complex <double >>(), V_full.at (isk).data <std::complex <double >>(), false , s.naos , s.naos );
0 commit comments