Skip to content

Commit d5634b3

Browse files
committed
update get_cur_effective_basis
1 parent 0b0604c commit d5634b3

File tree

9 files changed

+18
-25
lines changed

9 files changed

+18
-25
lines changed

source/module_hsolver/diago_iter_assist.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
4949
setmem_complex_op()(ctx, scc, 0, nstart * nstart);
5050
setmem_complex_op()(ctx, vcc, 0, nstart * nstart);
5151

52-
const int dmin = psi.get_current_nbas();
52+
const int dmin = psi.get_cur_effective_basis();
5353
const int dmax = psi.get_nbasis();
5454

5555
T* temp = nullptr;
@@ -167,7 +167,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
167167
const int nstart = psi_nr;
168168
const int n_band = evc.get_nbands();
169169
const int dmax = evc.get_nbasis();
170-
const int dmin = evc.get_current_nbas();
170+
const int dmin = evc.get_cur_effective_basis();
171171

172172
// skip the diagonalization if the operators are not allocated
173173
if (pHamilt->ops == nullptr)
@@ -264,7 +264,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
264264

265265
T* spsi = temp;
266266
// do sPsi for all bands
267-
pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_current_nbas(), psi_temp.get_nbands());
267+
pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_cur_effective_basis(), psi_temp.get_nbands());
268268

269269
gemm_op<T, Device>()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, spsi, dmax, &zero, scc, nstart);
270270
delmem_complex_op()(ctx, temp);
@@ -423,7 +423,7 @@ void DiagoIterAssist<T, Device>::cal_hs_subspace(const hamilt::Hamilt<T, Device>
423423
setmem_complex_op()(ctx, hcc, 0, nstart * nstart);
424424
setmem_complex_op()(ctx, scc, 0, nstart * nstart);
425425

426-
const int dmin = psi.get_current_nbas();
426+
const int dmin = psi.get_cur_effective_basis();
427427
const int dmax = psi.get_nbasis();
428428

429429
T* temp = nullptr;
@@ -549,7 +549,7 @@ void DiagoIterAssist<T, Device>::diag_subspace_psi(const T* hcc,
549549
DiagoIterAssist::diagH_LAPACK(nstart, nstart, hcc, scc, nstart, en, vcc);
550550

551551
{ // code block to calculate tar_mat
552-
const int dmin = evc.get_current_nbas();
552+
const int dmin = evc.get_cur_effective_basis();
553553
const int dmax = evc.get_nbasis();
554554
T* temp = nullptr;
555555
resmem_complex_op()(ctx, temp, nstart * dmax, "DiagSub::temp");

source/module_hsolver/hsolver_pw.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
480480
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
481481
ct::TensorShape({static_cast<int>(pre_condition.size())}))
482482
.to_device<ct_Device>()
483-
.slice({0}, {psi.get_current_nbas()});
483+
.slice({0}, {psi.get_cur_effective_basis()});
484484

485485
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor);
486486
// TODO: Double check tensormap's potential problem
@@ -530,7 +530,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
530530

531531
Diago_DavSubspace<T, Device> dav_subspace(pre_condition,
532532
psi.get_nbands(),
533-
psi.get_k_first() ? psi.get_current_nbas()
533+
psi.get_k_first() ? psi.get_cur_effective_basis()
534534
: psi.get_nk() * psi.get_nbasis(),
535535
PARAM.inp.pw_diag_ndim,
536536
this->diag_thr,
@@ -556,7 +556,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
556556
const int david_maxiter = this->diag_iter_max;
557557

558558
// dimensions of matrix to be solved
559-
const int dim = psi.get_current_nbas(); /// dimension of matrix
559+
const int dim = psi.get_cur_effective_basis(); /// dimension of matrix
560560
const int nband = psi.get_nbands(); /// number of eigenpairs sought
561561
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi
562562

source/module_hsolver/test/diago_cg_float_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class DiagoCGPrepare
182182
psi_local.get_pointer(),
183183
ct::DataType::DT_COMPLEX,
184184
ct::DeviceType::CpuDevice,
185-
ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_nbas()});
185+
ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_cur_effective_basis()});
186186
auto eigen_tensor = ct::TensorMap(
187187
en,
188188
ct::DataType::DT_FLOAT,
@@ -192,7 +192,7 @@ class DiagoCGPrepare
192192
precondition_local,
193193
ct::DataType::DT_FLOAT,
194194
ct::DeviceType::CpuDevice,
195-
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});
195+
ct::TensorShape({static_cast<int>(psi_local.get_cur_effective_basis())})).slice({0}, {psi_local.get_cur_effective_basis()});
196196

197197
std::vector<double> ethr_band(nband, 1e-5);
198198
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);

source/module_hsolver/test/diago_cg_real_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class DiagoCGPrepare
185185
psi_local.get_pointer(),
186186
ct::DataType::DT_DOUBLE,
187187
ct::DeviceType::CpuDevice,
188-
ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_nbas()});
188+
ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_cur_effective_basis()});
189189
auto eigen_tensor = ct::TensorMap(
190190
en,
191191
ct::DataType::DT_DOUBLE,
@@ -195,7 +195,7 @@ class DiagoCGPrepare
195195
precondition_local,
196196
ct::DataType::DT_DOUBLE,
197197
ct::DeviceType::CpuDevice,
198-
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});
198+
ct::TensorShape({static_cast<int>(psi_local.get_cur_effective_basis())})).slice({0}, {psi_local.get_cur_effective_basis()});
199199

200200
std::vector<double> ethr_band(nband, 1e-5);
201201
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);

source/module_hsolver/test/diago_cg_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class DiagoCGPrepare
176176
psi_local.get_pointer(),
177177
ct::DataType::DT_COMPLEX_DOUBLE,
178178
ct::DeviceType::CpuDevice,
179-
ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_nbas()});
179+
ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_cur_effective_basis()});
180180
auto eigen_tensor = ct::TensorMap(
181181
en,
182182
ct::DataType::DT_DOUBLE,
@@ -186,7 +186,7 @@ class DiagoCGPrepare
186186
precondition_local,
187187
ct::DataType::DT_DOUBLE,
188188
ct::DeviceType::CpuDevice,
189-
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});
189+
ct::TensorShape({static_cast<int>(psi_local.get_cur_effective_basis())})).slice({0}, {psi_local.get_cur_effective_basis()});
190190

191191
std::vector<double> ethr_band(nband, 1e-5);
192192
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);

source/module_hsolver/test/diago_david_float_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class DiagoDavPrepare
9090
const hsolver::diag_comm_info comm_info = {mypnum, nprocs};
9191
#endif
9292

93-
const int dim = phi.get_current_nbas() ;
93+
const int dim = phi.get_cur_effective_basis() ;
9494
const int nband = phi.get_nbands();
9595
const int ld_psi =phi.get_nbasis();
9696
hsolver::DiagoDavid<std::complex<float>> dav(precondition, nband, dim, order, false, comm_info);

source/module_hsolver/test/diago_david_real_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class DiagoDavPrepare
8989
const hsolver::diag_comm_info comm_info = {mypnum, nprocs};
9090
#endif
9191

92-
const int dim = phi.get_current_nbas();
92+
const int dim = phi.get_cur_effective_basis();
9393
const int nband = phi.get_nbands();
9494
const int ld_psi = phi.get_nbasis();
9595
hsolver::DiagoDavid<double> dav(precondition, nband, dim, order, false, comm_info);

source/module_hsolver/test/diago_david_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class DiagoDavPrepare
9292
const hsolver::diag_comm_info comm_info = {mypnum, nprocs};
9393
#endif
9494

95-
const int dim = phi.get_current_nbas();
95+
const int dim = phi.get_cur_effective_basis();
9696
const int nband = phi.get_nbands();
9797
const int ld_psi = phi.get_nbasis();
9898
hsolver::DiagoDavid<std::complex<double>> dav(precondition, nband, dim, order, false, comm_info);

source/module_psi/psi.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,7 @@ const int& Psi<T, Device>::get_cur_effective_basis() const
300300
{
301301
if (this->npol == 1)
302302
{
303-
// if (this->ngk != nullptr)
304-
// {
305-
// return this->ngk[this->current_k];
306-
// }
307-
// else
308-
{
309-
return this->current_nbasis;
310-
}
303+
return this->current_nbasis;
311304
}
312305
else
313306
{

0 commit comments

Comments
 (0)