@@ -15,10 +15,6 @@ Evolve_elec<Device>::Evolve_elec(){};
1515template <typename Device>
1616Evolve_elec<Device>::~Evolve_elec (){};
1717
18- template <typename Device>
19- Device* Evolve_elec<Device>::ctx = {};
20- template <typename Device>
21- base_device::DEVICE_CPU* Evolve_elec<Device>::cpu_ctx = {};
2218template <typename Device>
2319ct::DeviceType Evolve_elec<Device>::ct_device_type = ct::DeviceTypeToEnum<Device>::value;
2420
@@ -89,53 +85,69 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
8985 }
9086 else
9187 {
92- // std::cout << "nband = " << nband << std::endl;
93- // std::cout << "psi->get_nbands() = " << psi->get_nbands() << std::endl;
94- // std::cout << "nlocal = " << nlocal << std::endl;
95- // std::cout << "psi->get_nbasis() = " << psi->get_nbasis() << std::endl;
96- // std::cout << "para_orb.nloc = " << para_orb.nloc << std::endl;
97- // std::cout << "para_orb.nrow = " << para_orb.nrow << std::endl;
98- // std::cout << "para_orb.ncol = " << para_orb.ncol << std::endl;
99- // std::cout << "ekb.nr = " << ekb.nr << std::endl;
100- // std::cout << "ekb.nc = " << ekb.nc << std::endl;
88+ const int len_psi_k_1 = use_lapack ? nband : psi->get_nbands ();
89+ const int len_psi_k_2 = use_lapack ? nlocal : psi->get_nbasis ();
90+ const int len_HS_laststep = use_lapack ? nlocal * nlocal : para_orb.nloc ;
10191
10292 // Create Tensor for psi_k, psi_k_laststep, H_laststep, S_laststep, ekb
10393 ct::Tensor psi_k_tensor (ct::DataType::DT_COMPLEX_DOUBLE,
10494 ct_device_type,
105- ct::TensorShape ({psi-> get_nbands (), psi-> get_nbasis () }));
95+ ct::TensorShape ({len_psi_k_1, len_psi_k_2 }));
10696 ct::Tensor psi_k_laststep_tensor (ct::DataType::DT_COMPLEX_DOUBLE,
10797 ct_device_type,
108- ct::TensorShape ({psi-> get_nbands (), psi-> get_nbasis () }));
98+ ct::TensorShape ({len_psi_k_1, len_psi_k_2 }));
10999 ct::Tensor H_laststep_tensor (ct::DataType::DT_COMPLEX_DOUBLE,
110100 ct_device_type,
111- ct::TensorShape ({para_orb. nloc }));
101+ ct::TensorShape ({len_HS_laststep }));
112102 ct::Tensor S_laststep_tensor (ct::DataType::DT_COMPLEX_DOUBLE,
113103 ct_device_type,
114- ct::TensorShape ({para_orb. nloc }));
104+ ct::TensorShape ({len_HS_laststep }));
115105 ct::Tensor ekb_tensor (ct::DataType::DT_DOUBLE, ct_device_type, ct::TensorShape ({nband}));
116106
117- // Syncronize data from CPU to Device
118- syncmem_complex_h2d_op ()(ctx,
119- cpu_ctx,
120- psi_k_tensor.data <std::complex <double >>(),
121- psi[0 ].get_pointer (),
122- psi->get_nbands () * psi->get_nbasis ());
123- syncmem_complex_h2d_op ()(ctx,
124- cpu_ctx,
125- psi_k_laststep_tensor.data <std::complex <double >>(),
126- psi_laststep[0 ].get_pointer (),
127- psi->get_nbands () * psi->get_nbasis ());
128- syncmem_complex_h2d_op ()(ctx,
129- cpu_ctx,
130- H_laststep_tensor.data <std::complex <double >>(),
107+ // Global psi
108+ ModuleESolver::Matrix_g<std::complex <double >> psi_g;
109+ ModuleESolver::Matrix_g<std::complex <double >> psi_laststep_g;
110+
111+ if (use_lapack)
112+ {
113+ // Need to gather the psi to the root process on CPU
114+ // H_laststep and S_laststep are already gathered in esolver_ks_lcao_tddft.cpp
115+ #ifdef __MPI
116+ // Access the rank of the calling process in the communicator
117+ int myid, root_proc = 0 ;
118+ MPI_Comm_rank (MPI_COMM_WORLD, &myid);
119+
120+ // Gather psi to the root process
121+ gatherPsi (myid, root_proc, psi[0 ].get_pointer (), para_orb, psi_g);
122+ gatherPsi (myid, root_proc, psi_laststep[0 ].get_pointer (), para_orb, psi_laststep_g);
123+
124+ // Syncronize data from CPU to Device
125+ syncmem_complex_h2d_op ()(psi_k_tensor.data <std::complex <double >>(),
126+ psi_g.p .get (),
127+ len_psi_k_1 * len_psi_k_2);
128+ syncmem_complex_h2d_op ()(psi_k_laststep_tensor.data <std::complex <double >>(),
129+ psi_laststep_g.p .get (),
130+ len_psi_k_1 * len_psi_k_2);
131+ #endif
132+ }
133+ else
134+ {
135+ // Syncronize data from CPU to Device
136+ syncmem_complex_h2d_op ()(psi_k_tensor.data <std::complex <double >>(),
137+ psi[0 ].get_pointer (),
138+ len_psi_k_1 * len_psi_k_2);
139+ syncmem_complex_h2d_op ()(psi_k_laststep_tensor.data <std::complex <double >>(),
140+ psi_laststep[0 ].get_pointer (),
141+ len_psi_k_1 * len_psi_k_2);
142+ }
143+
144+ syncmem_complex_h2d_op ()(H_laststep_tensor.data <std::complex <double >>(),
131145 Hk_laststep[ik],
132- para_orb.nloc );
133- syncmem_complex_h2d_op ()(ctx,
134- cpu_ctx,
135- S_laststep_tensor.data <std::complex <double >>(),
146+ len_HS_laststep);
147+ syncmem_complex_h2d_op ()(S_laststep_tensor.data <std::complex <double >>(),
136148 Sk_laststep[ik],
137- para_orb. nloc );
138- syncmem_double_h2d_op ()(ctx, cpu_ctx, ekb_tensor.data <double >(), &(ekb (ik, 0 )), nband);
149+ len_HS_laststep );
150+ syncmem_double_h2d_op ()(ekb_tensor.data <double >(), &(ekb (ik, 0 )), nband);
139151
140152 evolve_psi_tensor<Device>(nband,
141153 nlocal,
@@ -151,28 +163,40 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
151163 print_matrix,
152164 use_lapack);
153165
154- // Syncronize data from Device to CPU
155- syncmem_complex_d2h_op ()(cpu_ctx,
156- ctx,
157- psi[0 ].get_pointer (),
158- psi_k_tensor.data <std::complex <double >>(),
159- psi->get_nbands () * psi->get_nbasis ());
160- syncmem_complex_d2h_op ()(cpu_ctx,
161- ctx,
162- psi_laststep[0 ].get_pointer (),
163- psi_k_laststep_tensor.data <std::complex <double >>(),
164- psi->get_nbands () * psi->get_nbasis ());
165- syncmem_complex_d2h_op ()(cpu_ctx,
166- ctx,
167- Hk_laststep[ik],
166+ // Need to distribute global psi back to all processes
167+ if (use_lapack)
168+ {
169+ #ifdef __MPI
170+ // Syncronize data from Device to CPU
171+ syncmem_complex_d2h_op ()(psi_g.p .get (),
172+ psi_k_tensor.data <std::complex <double >>(),
173+ len_psi_k_1 * len_psi_k_2);
174+ syncmem_complex_d2h_op ()(psi_laststep_g.p .get (),
175+ psi_k_laststep_tensor.data <std::complex <double >>(),
176+ len_psi_k_1 * len_psi_k_2);
177+
178+ // Distribute psi to all processes
179+ distributePsi (para_orb, psi[0 ].get_pointer (), psi_g);
180+ distributePsi (para_orb, psi_laststep[0 ].get_pointer (), psi_laststep_g);
181+ #endif
182+ }
183+ else
184+ {
185+ // Syncronize data from Device to CPU
186+ syncmem_complex_d2h_op ()(psi[0 ].get_pointer (),
187+ psi_k_tensor.data <std::complex <double >>(),
188+ len_psi_k_1 * len_psi_k_2);
189+ syncmem_complex_d2h_op ()(psi_laststep[0 ].get_pointer (),
190+ psi_k_laststep_tensor.data <std::complex <double >>(),
191+ len_psi_k_1 * len_psi_k_2);
192+ }
193+ syncmem_complex_d2h_op ()(Hk_laststep[ik],
168194 H_laststep_tensor.data <std::complex <double >>(),
169- para_orb.nloc );
170- syncmem_complex_d2h_op ()(cpu_ctx,
171- ctx,
172- Sk_laststep[ik],
195+ len_HS_laststep);
196+ syncmem_complex_d2h_op ()(Sk_laststep[ik],
173197 S_laststep_tensor.data <std::complex <double >>(),
174- para_orb. nloc );
175- syncmem_double_d2h_op ()(cpu_ctx, ctx, &(ekb (ik, 0 )), ekb_tensor.data <double >(), nband);
198+ len_HS_laststep );
199+ syncmem_double_d2h_op ()(&(ekb (ik, 0 )), ekb_tensor.data <double >(), nband);
176200
177201 // std::cout << "Print ekb tensor: " << std::endl;
178202 // ekb.print(std::cout);
0 commit comments