@@ -254,27 +254,15 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
254254 // Wrap the subspace diagonalization in a lambda function.
255255 // Intended contract: if S_orth is true, psi is assumed S-orthogonal and a standard eigenproblem is solved;
256256 // otherwise a generalized eigenproblem is solved. NOTE: the lambda body below never reads S_orth.
257- auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) {
258- // psi_in should be a 2D tensor:
259- // psi_in.shape() = [nbands, nbasis]
260- const auto ndim = psi_in.shape ().ndim ();
261- REQUIRES_OK (ndim == 2 , " dims of psi_in should be less than or equal to 2" );
262- // Convert a Tensor object to a psi::Psi object
263- auto psi_in_wrapper = psi::Psi<T, Device>(psi_in.data <T>(),
264- 1 ,
265- psi_in.shape ().dim_size (0 ),
266- psi_in.shape ().dim_size (1 ),
267- cur_nbasis);
268- auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data <T>(),
269- 1 ,
270- psi_out.shape ().dim_size (0 ),
271- psi_out.shape ().dim_size (1 ),
272- cur_nbasis);
273- auto eigen = ct::Tensor (ct::DataTypeToEnum<Real>::value,
274- ct::DeviceType::CpuDevice,
275- ct::TensorShape ({psi_in.shape ().dim_size (0 )}));
276-
277- DiagoIterAssist<T, Device>::diag_subspace (hm, psi_in_wrapper, psi_out_wrapper, eigen.data <Real>());
// subspace_func: callback that views two raw buffers as psi::Psi objects and forwards them,
// together with the captured Hamiltonian pointer hm, to DiagoIterAssist<T, Device>::diag_subspace.
//   psi_in  - raw pointer wrapped as the input wavefunction object
//   psi_out - raw pointer wrapped as the output wavefunction object
//   ld_psi  - passed as the fourth psi::Psi constructor argument (presumably the leading dimension
//             of each band, i.e. nbasis; confirm against the psi::Psi constructor)
//   nband   - passed as the third psi::Psi constructor argument and used as the eigenvalue count
//   S_orth  - accepted for interface compatibility; not read anywhere in this body
257+ auto subspace_func = [hm, cur_nbasis](T* psi_in,
258+ T* psi_out,
259+ const int ld_psi,
260+ const int nband,
261+ const bool S_orth) {
// Both wrappers are built with the same (1, nband, ld_psi, cur_nbasis) arguments.
// NOTE(review): the leading 1 is presumably the number of k-points; confirm.
262+ auto psi_in_wrapper = psi::Psi<T, Device>(psi_in, 1 , nband, ld_psi, cur_nbasis);
263+ auto psi_out_wrapper = psi::Psi<T, Device>(psi_out, 1 , nband, ld_psi, cur_nbasis);
// Scratch buffer of nband zero-initialized values handed to diag_subspace; it is local to the
// lambda, so whatever diag_subspace writes into it is discarded on return.
264+ std::vector<Real> eigen (nband, 0.0 );
265+ DiagoIterAssist<T, Device>::diag_subspace (hm, psi_in_wrapper, psi_out_wrapper, eigen.data ());
278266 };
279267 DiagoCG<T, Device> cg (this ->basis_type ,
280268 this ->calculation_type ,
@@ -284,70 +272,38 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
284272 this ->diag_iter_max ,
285273 this ->nproc_in_pool );
286274
287- // wrap the hpsi_func and spsi_func into a lambda function
288- using ct_Device = typename ct::PsiToContainer<Device>::type;
289-
290- // wrap the hpsi_func and spsi_func into a lambda function
291- auto hpsi_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
292- // psi_in should be a 2D tensor:
293- // psi_in.shape() = [nbands, nbasis]
294- const auto ndim = psi_in.shape ().ndim ();
295- REQUIRES_OK (ndim <= 2 , " dims of psi_in should be less than or equal to 2" );
296- // Convert a Tensor object to a psi::Psi object
297- auto psi_wrapper = psi::Psi<T, Device>(psi_in.data <T>(),
298- 1 ,
299- ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ),
300- ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ),
301- cur_nbasis);
302- psi::Range all_bands_range (true , psi_wrapper.get_current_k (), 0 , psi_wrapper.get_nbands () - 1 );
275+ // wrap the hpsi_func and spsi_func into lambda functions
// hpsi_func: callback that applies the Hamiltonian operator chain (hm->ops->hPsi) to nvec vectors
// stored at psi_in and directs the result to hpsi_out.
//   psi_in   - raw pointer wrapped as a psi::Psi object
//   hpsi_out - raw pointer passed through to hpsi_info as the output location
//   ld_psi   - passed as the fourth psi::Psi constructor argument (presumably the leading
//              dimension of each vector; confirm against the psi::Psi constructor)
//   nvec     - number of vectors; also sets the upper bound (nvec - 1) of the band range
276+ auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
277+ auto psi_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, cur_nbasis);
// Range is built from the literal arguments (true, 0, 0, nvec - 1).
// NOTE(review): presumably k-point index 0 and bands 0..nvec-1 inclusive; confirm against psi::Range.
278+ psi::Range all_bands_range (true , 0 , 0 , nvec - 1 );
303279 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
304- hpsi_info info (&psi_wrapper, all_bands_range, hpsi_out. data <T>() );
// hpsi_info bundles the wrapped input, the band range and the raw output pointer for hPsi.
280+ hpsi_info info (&psi_wrapper, all_bands_range, hpsi_out);
305281 hm->ops ->hPsi (info);
306282 };
307- auto spsi_func = [this , hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
308- // psi_in should be a 2D tensor:
309- // psi_in.shape() = [nbands, nbasis]
310- const auto ndim = psi_in.shape ().ndim ();
311- REQUIRES_OK (ndim <= 2 , " dims of psi_in should be less than or equal to 2" );
312-
// spsi_func: callback that fills spsi_out from psi_in for nvec vectors of leading dimension ld_psi.
// When this->use_uspp is set it delegates to hm->sPsi; otherwise it hands psi_in and spsi_out to
// synchronize_memory_op (presumably a plain copy, i.e. S acts as the identity; confirm).
//   psi_in   - raw pointer to the input vectors
//   spsi_out - raw pointer to the output vectors
//   ld_psi   - passed twice to hm->sPsi and used as one factor of the element count
//   nvec     - number of vectors; last argument of hm->sPsi and other factor of the element count
283+ auto spsi_func = [this , hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
313284 if (this ->use_uspp )
314285 {
315- // Convert a Tensor object to a psi::Psi object
316- hm->sPsi (psi_in.data <T>(),
317- spsi_out.data <T>(),
318- ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ),
319- ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ),
320- ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ));
// ld_psi is supplied for both the third and fourth arguments of sPsi.
// NOTE(review): presumably (row count, leading dimension); confirm against Hamilt::sPsi.
286+ hm->sPsi (psi_in, spsi_out, ld_psi, ld_psi, nvec);
321287 }
322288 else
323289 {
324290 base_device::memory::synchronize_memory_op<T, Device, Device>()(
325- spsi_out.data <T>(),
326- psi_in.data <T>(),
327- static_cast <size_t >((ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ))
328- * (ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ))));
// Arguments are (spsi_out, psi_in, element count).
// NOTE(review): presumably destination, source, size; confirm against synchronize_memory_op.
291+ spsi_out,
292+ psi_in,
// Each factor is converted to size_t before the multiplication, so the product is formed in size_t.
293+ static_cast <size_t >(nvec) * static_cast <size_t >(ld_psi));
329294 }
330295 };
331296
332- auto psi_tensor = ct::TensorMap (psi.get_pointer (),
333- ct::DataTypeToEnum<T>::value,
334- ct::DeviceTypeToEnum<ct_Device>::value,
335- ct::TensorShape ({psi.get_nbands (), psi.get_nbasis ()}));
336-
337- auto eigen_tensor = ct::TensorMap (eigenvalue,
338- ct::DataTypeToEnum<Real>::value,
339- ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
340- ct::TensorShape ({psi.get_nbands ()}));
341-
342- auto prec_tensor = ct::TensorMap (pre_condition.data (),
343- ct::DataTypeToEnum<Real>::value,
344- ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
345- ct::TensorShape ({static_cast <int >(pre_condition.size ())}))
346- .to_device <ct_Device>()
347- .slice ({0 }, {psi.get_current_ngk ()});
348-
349297 DiagoIterAssist<T, Device>::avg_iter += static_cast <double >(
350- cg.diag (hpsi_func, spsi_func, psi_tensor, eigen_tensor, this ->ethr_band , prec_tensor)
298+ cg.diag (hpsi_func,
299+ spsi_func,
300+ psi.get_nbasis (),
301+ psi.get_nbands (),
302+ psi.get_current_ngk (),
303+ psi.get_pointer (),
304+ eigenvalue,
305+ this ->ethr_band ,
306+ pre_condition.data ())
351307 );
352308 // TODO: Double-check for potential problems with TensorMap (NOTE: the sync below is commented out and refers to psi_tensor)
353309 // ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
0 commit comments