Skip to content

Commit 743b5a9

Browse files
committed
add Constructor 8-2
1 parent a364211 commit 743b5a9

File tree

5 files changed

+59
-16
lines changed

5 files changed

+59
-16
lines changed

source/module_hsolver/hsolver_pw.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
310310
#endif
311311

312312
/// solve eigenvector and eigenvalue for H(k)
313-
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands());
313+
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands(), this->wfc_basis->nks);
314314

315315
if (skip_charge)
316316
{
@@ -357,7 +357,8 @@ template <typename T, typename Device>
357357
void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
358358
psi::Psi<T, Device>& psi,
359359
std::vector<Real>& pre_condition,
360-
Real* eigenvalue)
360+
Real* eigenvalue,
361+
const int& nk_nums)
361362
{
362363
#ifdef __MPI
363364
const diag_comm_info comm_info = {POOL_WORLD, this->rank_in_pool, this->nproc_in_pool};
@@ -367,10 +368,16 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
367368

368369
auto ngk_pointer = psi.get_ngk_pointer();
369370

371+
std::vector<int> ngk_vector(nk_nums, 0);
372+
for (int i = 0; i < nk_nums; i++)
373+
{
374+
ngk_vector[i] = ngk_pointer[i];
375+
}
376+
370377
if (this->method == "cg")
371378
{
372379
// wrap the subspace_func into a lambda function
373-
auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
380+
auto subspace_func = [hm, ngk_pointer, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
374381
// psi_in should be a 2D tensor:
375382
// psi_in.shape() = [nbands, nbasis]
376383
const auto ndim = psi_in.shape().ndim();
@@ -380,12 +387,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
380387
1,
381388
psi_in.shape().dim_size(0),
382389
psi_in.shape().dim_size(1),
383-
nullptr);
390+
ngk_vector);
384391
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
385392
1,
386393
psi_out.shape().dim_size(0),
387394
psi_out.shape().dim_size(1),
388-
nullptr);
395+
ngk_vector);
389396
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
390397
ct::DeviceType::CpuDevice,
391398
ct::TensorShape({psi_in.shape().dim_size(0)}));
@@ -404,7 +411,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
404411
using ct_Device = typename ct::PsiToContainer<Device>::type;
405412

406413
// wrap the hpsi_func and spsi_func into a lambda function
407-
auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
414+
auto hpsi_func = [hm, ngk_pointer, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
408415
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
409416
// psi_in should be a 2D tensor:
410417
// psi_in.shape() = [nbands, nbasis]
@@ -415,7 +422,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
415422
1,
416423
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
417424
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
418-
nullptr);
425+
ngk_vector);
419426
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
420427
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
421428
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
@@ -475,11 +482,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
475482
const int nband = psi.get_nbands();
476483
const int nbasis = psi.get_nbasis();
477484
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
478-
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
485+
auto hpsi_func = [hm, ngk_pointer, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
479486
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
480487

481488
// Convert "pointer data stucture" to a psi::Psi object
482-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, nullptr);
489+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
483490

484491
psi::Range bands_range(true, 0, 0, nvec - 1);
485492

@@ -496,11 +503,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
496503
else if (this->method == "dav_subspace")
497504
{
498505
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
499-
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
506+
auto hpsi_func = [hm, ngk_pointer, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
500507
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
501508

502509
// Convert "pointer data stucture" to a psi::Psi object
503-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, nullptr);
510+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
504511

505512
psi::Range bands_range(true, 0, 0, nvec - 1);
506513

@@ -547,11 +554,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
547554
// Davidson matrix-blockvector functions
548555
/// wrap hpsi into lambda function, Matrix \times blockvector
549556
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
550-
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
557+
auto hpsi_func = [hm, ngk_pointer, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
551558
ModuleBase::timer::tick("David", "hpsi_func");
552559

553560
// Convert pointer of psi_in to a psi::Psi object
554-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, nullptr);
561+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
555562

556563
psi::Range bands_range(true, 0, 0, nvec - 1);
557564

source/module_hsolver/hsolver_pw.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class HSolverPW
5656
void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
5757
psi::Psi<T, Device>& psi,
5858
std::vector<Real>& pre_condition,
59-
Real* eigenvalue);
59+
Real* eigenvalue,
60+
const int& nk_nums);
6061

6162
// calculate the precondition array for diagonalization in PW base
6263
void update_precondition(std::vector<Real>& h_diag, const int ik, const int npw, const Real vl_of_0);

source/module_hsolver/hsolver_pw_sdft.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
5454
this->update_precondition(precondition, ik, this->wfc_basis->npwk[ik], pes->pot->get_vl_of_0());
5555
/// solve eigenvector and eigenvalue for H(k)
5656
double* p_eigenvalues = &(pes->ekb(ik, 0));
57-
this->hamiltSolvePsiK(pHamilt, psi, precondition, p_eigenvalues);
57+
this->hamiltSolvePsiK(pHamilt, psi, precondition, p_eigenvalues, nks);
5858
}
5959

6060
#ifdef __MPI

source/module_psi/psi.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,31 @@ Psi<T, Device>::Psi(T* psi_pointer,
9393
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
9494
}
9595

96+
97+
template <typename T, typename Device>
98+
Psi<T, Device>::Psi(T* psi_pointer,
99+
const int nk_in,
100+
const int nbd_in,
101+
const int nbs_in,
102+
const std::vector<int>& ngk_vector_in,
103+
const bool k_first_in)
104+
{
105+
this->k_first = k_first_in;
106+
this->ngk = ngk_vector_in.data();
107+
this->current_b = 0;
108+
this->current_k = 0;
109+
this->npol = PARAM.globalv.npol;
110+
this->device = base_device::get_device_type<Device>(this->ctx);
111+
this->nk = nk_in;
112+
this->nbands = nbd_in;
113+
this->nbasis = nbs_in;
114+
this->current_nbasis = nbs_in;
115+
this->psi_current = this->psi = psi_pointer;
116+
this->allocate_inside = false;
117+
// Currently only GPU's implementation is supported for device recording!
118+
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
119+
}
120+
96121
template <typename T, typename Device>
97122
Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, int nband_in)
98123
{

source/module_psi/psi.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,23 @@ class Psi
5757
template <typename T_in, typename Device_in = Device>
5858
Psi(const Psi<T_in, Device_in>& psi_in);
5959

60-
// Constructor 8: a pointer version of constructor 3
60+
// Constructor 8-1: a pointer version of constructor 3
6161
Psi(T* psi_pointer,
6262
const int nk_in,
6363
const int nbd_in,
6464
const int nbs_in,
6565
const int* ngk_in = nullptr,
6666
const bool k_first_in = true);
67+
68+
// Constructor 8-2: a pointer version of constructor 3
69+
Psi(T* psi_pointer,
70+
const int nk_in,
71+
const int nbd_in,
72+
const int nbs_in,
73+
const std::vector<int>& ngk_vector_in,
74+
const bool k_first_in = true);
75+
76+
6777
// Destructor for deleting the psi array manually
6878
~Psi();
6979

0 commit comments

Comments
 (0)