Skip to content

Commit 3728bfa

Browse files
committed
update Constructor 8
1 parent 55f7fdb commit 3728bfa

File tree

5 files changed

+103
-24
lines changed

5 files changed

+103
-24
lines changed

source/module_hsolver/hsolver_pw.cpp

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,11 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
310310
#endif
311311

312312
/// solve eigenvector and eigenvalue for H(k)
313-
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands());
313+
this->hamiltSolvePsiK(pHamilt,
314+
psi,
315+
precondition,
316+
eigenvalues.data() + ik * psi.get_nbands(),
317+
this->wfc_basis->nks);
314318

315319
if (skip_charge)
316320
{
@@ -357,19 +361,28 @@ template <typename T, typename Device>
357361
void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
358362
psi::Psi<T, Device>& psi,
359363
std::vector<Real>& pre_condition,
360-
Real* eigenvalue)
364+
Real* eigenvalue,
365+
const int& nk_nums)
361366
{
362367
#ifdef __MPI
363368
const diag_comm_info comm_info = {POOL_WORLD, this->rank_in_pool, this->nproc_in_pool};
364369
#else
365370
const diag_comm_info comm_info = {this->rank_in_pool, this->nproc_in_pool};
366371
#endif
367372

373+
auto ngk_pointer = psi.get_ngk_pointer();
374+
375+
std::vector<int> ngk_vector_temp(nk_nums, 0);
376+
377+
for (size_t i = 0; i < nk_nums; i++)
378+
{
379+
ngk_vector_temp[i] = ngk_pointer[i];
380+
}
381+
368382
if (this->method == "cg")
369383
{
370384
// wrap the subspace_func into a lambda function
371-
auto ngk_pointer = psi.get_ngk_pointer();
372-
auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
385+
auto subspace_func = [hm, ngk_vector_temp](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
373386
// psi_in should be a 2D tensor:
374387
// psi_in.shape() = [nbands, nbasis]
375388
const auto ndim = psi_in.shape().ndim();
@@ -379,12 +392,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
379392
1,
380393
psi_in.shape().dim_size(0),
381394
psi_in.shape().dim_size(1),
382-
ngk_pointer);
395+
ngk_vector_temp);
383396
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
384397
1,
385398
psi_out.shape().dim_size(0),
386399
psi_out.shape().dim_size(1),
387-
ngk_pointer);
400+
ngk_vector_temp);
388401
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
389402
ct::DeviceType::CpuDevice,
390403
ct::TensorShape({psi_in.shape().dim_size(0)}));
@@ -403,7 +416,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
403416
using ct_Device = typename ct::PsiToContainer<Device>::type;
404417

405418
// wrap the hpsi_func and spsi_func into a lambda function
406-
auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
419+
auto hpsi_func = [hm, ngk_vector_temp](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
407420
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
408421
// psi_in should be a 2D tensor:
409422
// psi_in.shape() = [nbands, nbasis]
@@ -414,7 +427,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
414427
1,
415428
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
416429
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
417-
ngk_pointer);
430+
ngk_vector_temp);
418431
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
419432
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
420433
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
@@ -473,13 +486,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
473486
{
474487
const int nband = psi.get_nbands();
475488
const int nbasis = psi.get_nbasis();
476-
auto ngk_pointer = psi.get_ngk_pointer();
477489
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
478-
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
490+
auto hpsi_func = [hm, ngk_vector_temp](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
479491
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
480492

481493
// Convert "pointer data stucture" to a psi::Psi object
482-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
494+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector_temp);
483495

484496
psi::Range bands_range(true, 0, 0, nvec - 1);
485497

@@ -495,13 +507,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
495507
}
496508
else if (this->method == "dav_subspace")
497509
{
498-
auto ngk_pointer = psi.get_ngk_pointer();
499510
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
500-
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
511+
auto hpsi_func = [hm, ngk_vector_temp](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
501512
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
502513

503514
// Convert "pointer data stucture" to a psi::Psi object
504-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
515+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector_temp);
505516

506517
psi::Range bands_range(true, 0, 0, nvec - 1);
507518

@@ -546,15 +557,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
546557
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi
547558

548559
// Davidson matrix-blockvector functions
549-
550-
auto ngk_pointer = psi.get_ngk_pointer();
551560
/// wrap hpsi into lambda function, Matrix \times blockvector
552561
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
553-
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
562+
auto hpsi_func = [hm, ngk_vector_temp](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
554563
ModuleBase::timer::tick("David", "hpsi_func");
555564

556565
// Convert pointer of psi_in to a psi::Psi object
557-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
566+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector_temp);
558567

559568
psi::Range bands_range(true, 0, 0, nvec - 1);
560569

source/module_hsolver/hsolver_pw.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class HSolverPW
5656
void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
5757
psi::Psi<T, Device>& psi,
5858
std::vector<Real>& pre_condition,
59-
Real* eigenvalue);
59+
Real* eigenvalue,
60+
const int& nk_nums);
6061

6162
// calculate the precondition array for diagonalization in PW base
6263
void update_precondition(std::vector<Real>& h_diag, const int ik, const int npw, const Real vl_of_0);

source/module_hsolver/hsolver_pw_sdft.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
5454
this->update_precondition(precondition, ik, this->wfc_basis->npwk[ik], pes->pot->get_vl_of_0());
5555
/// solve eigenvector and eigenvalue for H(k)
5656
double* p_eigenvalues = &(pes->ekb(ik, 0));
57-
this->hamiltSolvePsiK(pHamilt, psi, precondition, p_eigenvalues);
57+
this->hamiltSolvePsiK(pHamilt, psi, precondition, p_eigenvalues, nks);
5858
}
5959

6060
#ifdef __MPI

source/module_psi/psi.cpp

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,57 @@ template <typename T, typename Device> Psi<T, Device>::Psi(const int nk_in, cons
6666
sizeof(T) * nk_in * nbd_in * nbs_in);
6767
}
6868

69-
template <typename T, typename Device> Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in)
69+
// template <typename T, typename Device> Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in)
70+
// {
71+
// this->k_first = k_first_in;
72+
// this->ngk = ngk_in;
73+
// this->current_b = 0;
74+
// this->current_k = 0;
75+
// this->npol = PARAM.globalv.npol;
76+
// this->device = base_device::get_device_type<Device>(this->ctx);
77+
// this->nk = nk_in;
78+
// this->nbands = nbd_in;
79+
// this->nbasis = nbs_in;
80+
// this->current_nbasis = nbs_in;
81+
// this->psi_current = this->psi = psi_pointer;
82+
// this->allocate_inside = false;
83+
// // Currently only GPU's implementation is supported for device recording!
84+
// base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
85+
// }
86+
87+
template <typename T, typename Device>
88+
Psi<T, Device>::Psi(T* psi_pointer,
89+
const int nk_in,
90+
const int nbd_in,
91+
const int nbs_in,
92+
const bool k_first_in)
7093
{
7194
this->k_first = k_first_in;
72-
this->ngk = ngk_in;
95+
this->ngk = nullptr;
96+
this->current_b = 0;
97+
this->current_k = 0;
98+
this->npol = PARAM.globalv.npol;
99+
this->device = base_device::get_device_type<Device>(this->ctx);
100+
this->nk = nk_in;
101+
this->nbands = nbd_in;
102+
this->nbasis = nbs_in;
103+
this->current_nbasis = nbs_in;
104+
this->psi_current = this->psi = psi_pointer;
105+
this->allocate_inside = false;
106+
// Currently only GPU's implementation is supported for device recording!
107+
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
108+
}
109+
110+
/// New Constructor
111+
template <typename T, typename Device> Psi<T, Device>::Psi(T* psi_pointer,
112+
const int nk_in,
113+
const int nbd_in,
114+
const int nbs_in,
115+
const std::vector<int> ngk_vector_in,
116+
const bool k_first_in)
117+
{
118+
this->k_first = k_first_in;
119+
this->ngk = ngk_vector_in.data();
73120
this->current_b = 0;
74121
this->current_k = 0;
75122
this->npol = PARAM.globalv.npol;

source/module_psi/psi.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,31 @@ class Psi
5252
// Constructor 7: initialize a new psi from the given psi_in with a different class template
5353
// in this case, psi_in may have a different device type.
5454
template <typename T_in, typename Device_in = Device> Psi(const Psi<T_in, Device_in>& psi_in);
55-
// Constructor 8: a pointer version of constructor 3
56-
Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true);
55+
5756
// Destructor for deleting the psi array manually
5857
~Psi();
5958

59+
60+
61+
// // Constructor 8: a pointer version of constructor 3
62+
// Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true);
63+
64+
65+
Psi(T* psi_pointer,
66+
const int nk_in,
67+
const int nbd_in,
68+
const int nbs_in,
69+
const bool k_first_in = true);
70+
71+
/// New Constructor
72+
Psi(T* psi_pointer,
73+
const int nk_in,
74+
const int nbd_in,
75+
const int nbs_in,
76+
const std::vector<int> ngk_vector_in,
77+
const bool k_first_in = true);
78+
79+
6080
// allocate psi for three dimensions
6181
void resize(const int nks_in, const int nbands_in, const int nbasis_in);
6282

@@ -140,6 +160,8 @@ class Psi
140160

141161
const int* ngk = nullptr;
142162

163+
std::vector<int> ngk_vector;
164+
143165
bool k_first = true;
144166

145167
bool allocate_inside = true; ///<whether allocate psi inside Psi class

0 commit comments

Comments
 (0)