@@ -375,6 +375,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
375375 Real* eigenvalue,
376376 const int & nk_nums)
377377{
378+ ModuleBase::timer::tick (" HSolverPW" , " solve_psik" );
378379#ifdef __MPI
379380 const diag_comm_info comm_info = {POOL_WORLD, this ->rank_in_pool , this ->nproc_in_pool };
380381#else
@@ -421,7 +422,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
421422
422423 // wrap the hpsi_func and spsi_func into a lambda function
423424 auto hpsi_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
424- ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
425425 // psi_in should be a 2D tensor:
426426 // psi_in.shape() = [nbands, nbasis]
427427 const auto ndim = psi_in.shape ().ndim ();
@@ -436,10 +436,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
436436 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
437437 hpsi_info info (&psi_wrapper, all_bands_range, hpsi_out.data <T>());
438438 hm->ops ->hPsi (info);
439- ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
440439 };
441440 auto spsi_func = [this , hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
442- ModuleBase::timer::tick (" DiagoCG_New" , " spsi_func" );
443441 // psi_in should be a 2D tensor:
444442 // psi_in.shape() = [nbands, nbasis]
445443 const auto ndim = psi_in.shape ().ndim ();
@@ -462,17 +460,18 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
462460 static_cast <size_t >((ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ))
463461 * (ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ))));
464462 }
465-
466- ModuleBase::timer::tick (" DiagoCG_New" , " spsi_func" );
467463 };
464+
468465 auto psi_tensor = ct::TensorMap (psi.get_pointer (),
469466 ct::DataTypeToEnum<T>::value,
470467 ct::DeviceTypeToEnum<ct_Device>::value,
471468 ct::TensorShape ({psi.get_nbands (), psi.get_nbasis ()}));
469+
472470 auto eigen_tensor = ct::TensorMap (eigenvalue,
473471 ct::DataTypeToEnum<Real>::value,
474472 ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
475473 ct::TensorShape ({psi.get_nbands ()}));
474+
476475 auto prec_tensor = ct::TensorMap (pre_condition.data (),
477476 ct::DataTypeToEnum<Real>::value,
478477 ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
@@ -491,7 +490,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
491490 const int ndim = psi.get_current_ngk ();
492491 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
493492 auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
494- ModuleBase::timer::tick (" diago_bpcg" , " hpsi_func" );
495493
496494 // Convert "pointer data stucture" to a psi::Psi object
497495 auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, cur_nbasis);
@@ -501,8 +499,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
501499 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
502500 hpsi_info info (&psi_iter_wrapper, bands_range, hpsi_out);
503501 hm->ops ->hPsi (info);
504-
505- ModuleBase::timer::tick (" diago_bpcg" , " hpsi_func" );
506502 };
507503 DiagoBPCG<T, Device> bpcg (pre_condition.data ());
508504 bpcg.init_iter (PARAM.inp .nbands , nband_l, nbasis, ndim);
@@ -512,7 +508,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
512508 {
513509 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
514510 auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
515- ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
516511
517512 // Convert "pointer data stucture" to a psi::Psi object
518513 auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, cur_nbasis);
@@ -522,8 +517,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
522517 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
523518 hpsi_info info (&psi_iter_wrapper, bands_range, hpsi_out);
524519 hm->ops ->hPsi (info);
525-
526- ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
527520 };
528521 bool scf = this ->calculation_type == " nscf" ? false : true ;
529522
@@ -565,7 +558,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
565558 // / wrap hpsi into lambda function, Matrix \times blockvector
566559 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
567560 auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
568- ModuleBase::timer::tick (" David" , " hpsi_func" );
569561
570562 // Convert pointer of psi_in to a psi::Psi object
571563 auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, cur_nbasis);
@@ -575,8 +567,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
575567 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
576568 hpsi_info info (&psi_iter_wrapper, bands_range, hpsi_out);
577569 hm->ops ->hPsi (info);
578-
579- ModuleBase::timer::tick (" David" , " hpsi_func" );
580570 };
581571
582572 // / wrap spsi into lambda function, Matrix \times blockvector
@@ -587,11 +577,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
587577 const int ld_psi, // Leading dimension of psi and spsi.
588578 const int nvec // Number of vectors(bands)
589579 ) {
590- ModuleBase::timer::tick (" David" , " spsi_func" );
591580 // sPsi determines S=I or not by PARAM.globalv.use_uspp inside
592581 // sPsi(psi, spsi, nrow, npw, nbands)
593582 hm->sPsi (psi_in, spsi_out, ld_psi, ld_psi, nvec);
594- ModuleBase::timer::tick (" David" , " spsi_func" );
595583 };
596584
597585 DiagoDavid<T, Device> david (pre_condition.data (), nband, dim, PARAM.inp .pw_diag_ndim , this ->use_paw , comm_info);
@@ -606,6 +594,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
606594 ntry_max,
607595 notconv_max));
608596 }
597+ ModuleBase::timer::tick (" HSolverPW" , " solve_psik" );
609598 return ;
610599}
611600
0 commit comments