Skip to content

Commit 7938206

Browse files
authored
Refactor: refactor solve func in hsolver-lcao class (#5257)
* refactor hsolver-lcao code * refactor hsolver-lcao
1 parent 3aec2be commit 7938206

File tree

1 file changed

+22
-160
lines changed

1 file changed

+22
-160
lines changed

source/module_hsolver/hsolver_lcao.cpp

Lines changed: 22 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#ifdef __MPI
44
#include "diago_scalapack.h"
5+
#include "module_base/scalapack_connector.h"
56
#else
67
#include "diago_lapack.h"
78
#endif
@@ -24,22 +25,12 @@
2425
#include "module_elecstate/elecstate_lcao.h"
2526
#endif
2627

27-
#include "diago_cg.h"
2828
#include "module_base/global_variable.h"
2929
#include "module_base/memory.h"
30-
#include "module_base/scalapack_connector.h"
3130
#include "module_base/timer.h"
32-
#include "module_hsolver/diago_iter_assist.h"
33-
#include "module_hsolver/kernels/math_kernel_op.h"
3431
#include "module_hsolver/parallel_k2d.h"
35-
#include "module_io/write_HS.h"
3632
#include "module_parameter/parameter.h"
3733

38-
#include <ATen/core/tensor.h>
39-
#include <ATen/core/tensor_map.h>
40-
#include <ATen/core/tensor_types.h>
41-
#include <unistd.h>
42-
4334
namespace hsolver
4435
{
4536

@@ -74,51 +65,39 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
7465
}
7566
#endif
7667

77-
if (this->method == "cg_in_lcao")
78-
{
79-
this->precondition_lcao.resize(psi.get_nbasis());
80-
81-
using Real = typename GetTypeReal<T>::type;
82-
// set precondition
83-
for (size_t i = 0; i < precondition_lcao.size(); i++)
84-
{
85-
precondition_lcao[i] = 1.0;
86-
}
87-
}
88-
89-
#ifdef __MPI
9068
if (GlobalV::KPAR_LCAO > 1
9169
&& (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx"))
9270
{
71+
#ifdef __MPI
9372
this->parakSolve(pHamilt, psi, pes, GlobalV::KPAR_LCAO);
94-
}
95-
else
9673
#endif
74+
}
75+
else if (GlobalV::KPAR_LCAO == 1)
9776
{
98-
/// Loop over k points for solve Hamiltonian to charge density
77+
/// Loop over k points for solve Hamiltonian to eigenpairs(eigenvalues and eigenvectors).
9978
for (int ik = 0; ik < psi.get_nk(); ++ik)
10079
{
10180
/// update H(k) for each k point
10281
pHamilt->updateHk(ik);
10382

83+
/// find psi pointer for each k point
10484
psi.fix_k(ik);
10585

106-
// solve eigenvector and eigenvalue for H(k)
86+
/// solve eigenvector and eigenvalue for H(k)
10787
this->hamiltSolvePsiK(pHamilt, psi, &(pes->ekb(ik, 0)));
10888
}
10989
}
11090

111-
if (skip_charge) // used in nscf calculation
91+
if (!skip_charge) // used in scf calculation
11292
{
113-
ModuleBase::timer::tick("HSolverLCAO", "solve");
93+
// calculate charge by eigenpairs(eigenvalues and eigenvectors)
94+
pes->psiToRho(psi);
11495
}
115-
else // used in scf calculation
96+
else // used in nscf calculation
11697
{
117-
// calculate charge by psi
118-
pes->psiToRho(psi);
119-
ModuleBase::timer::tick("HSolverLCAO", "solve");
12098
}
12199

100+
ModuleBase::timer::tick("HSolverLCAO", "solve");
122101
return;
123102
}
124103

@@ -135,6 +114,7 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
135114
sa.diag(hm, psi, eigenvalue);
136115
#endif
137116
}
117+
138118
#ifdef __ELPA
139119
else if (this->method == "genelpa")
140120
{
@@ -147,151 +127,33 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
147127
el.diag(hm, psi, eigenvalue);
148128
}
149129
#endif
130+
150131
#ifdef __CUDA
151132
else if (this->method == "cusolver")
152133
{
153134
DiagoCusolver<T> cs(this->ParaV);
154135
cs.diag(hm, psi, eigenvalue);
155136
}
137+
#ifdef __CUSOLVERMP
156138
else if (this->method == "cusolvermp")
157139
{
158-
#ifdef __CUSOLVERMP
159140
DiagoCusolverMP<T> cm;
160141
cm.diag(hm, psi, eigenvalue);
161-
#else
162-
ModuleBase::WARNING_QUIT("HSolverLCAO", "CUSOLVERMP did not compiled!");
163-
#endif
164142
}
165143
#endif
166-
else if (this->method == "lapack")
167-
{
144+
#endif
145+
168146
#ifndef __MPI
147+
else if (this->method == "lapack") // only for single core
148+
{
169149
DiagoLapack<T> la;
170150
la.diag(hm, psi, eigenvalue);
171-
#else
172-
ModuleBase::WARNING_QUIT("HSolverLCAO::solve", "This type of eigensolver is not supported!");
173-
#endif
174151
}
152+
#endif
153+
175154
else
176155
{
177-
178-
using ct_Device = typename ct::PsiToContainer<base_device::DEVICE_CPU>::type;
179-
180-
auto subspace_func = [](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
181-
// psi_in should be a 2D tensor:
182-
// psi_in.shape() = [nbands, nbasis]
183-
const auto ndim = psi_in.shape().ndim();
184-
REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2");
185-
};
186-
187-
DiagoCG<T, Device> cg(PARAM.inp.basis_type,
188-
PARAM.inp.calculation,
189-
DiagoIterAssist<T, Device>::need_subspace,
190-
subspace_func,
191-
DiagoIterAssist<T, Device>::PW_DIAG_THR,
192-
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
193-
GlobalV::NPROC_IN_POOL);
194-
195-
hamilt::MatrixBlock<T> h_mat, s_mat;
196-
hm->matrix(h_mat, s_mat);
197-
198-
// set h_mat & s_mat
199-
for (int i = 0; i < h_mat.row; i++)
200-
{
201-
for (int j = i; j < h_mat.col; j++)
202-
{
203-
h_mat.p[h_mat.row * j + i] = hsolver::get_conj(h_mat.p[h_mat.row * i + j]);
204-
s_mat.p[s_mat.row * j + i] = hsolver::get_conj(s_mat.p[s_mat.row * i + j]);
205-
}
206-
}
207-
208-
const T *one_ = nullptr, *zero_ = nullptr;
209-
one_ = new T(static_cast<T>(1.0));
210-
zero_ = new T(static_cast<T>(0.0));
211-
212-
auto hpsi_func = [h_mat, one_, zero_](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
213-
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
214-
// psi_in should be a 2D tensor:
215-
// psi_in.shape() = [nbands, nbasis]
216-
const auto ndim = psi_in.shape().ndim();
217-
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
218-
219-
Device* ctx = {};
220-
221-
gemv_op<T, Device>()(ctx,
222-
'N',
223-
h_mat.row,
224-
h_mat.col,
225-
one_,
226-
h_mat.p,
227-
h_mat.row,
228-
psi_in.data<T>(),
229-
1,
230-
zero_,
231-
hpsi_out.data<T>(),
232-
1);
233-
234-
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
235-
};
236-
237-
auto spsi_func = [s_mat, one_, zero_](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
238-
ModuleBase::timer::tick("DiagoCG_New", "spsi_func");
239-
// psi_in should be a 2D tensor:
240-
// psi_in.shape() = [nbands, nbasis]
241-
const auto ndim = psi_in.shape().ndim();
242-
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
243-
244-
Device* ctx = {};
245-
246-
gemv_op<T, Device>()(ctx,
247-
'N',
248-
s_mat.row,
249-
s_mat.col,
250-
one_,
251-
s_mat.p,
252-
s_mat.row,
253-
psi_in.data<T>(),
254-
1,
255-
zero_,
256-
spsi_out.data<T>(),
257-
1);
258-
259-
ModuleBase::timer::tick("DiagoCG_New", "spsi_func");
260-
};
261-
262-
// if (this->is_first_scf)
263-
// {
264-
for (size_t i = 0; i < psi.get_nbands(); i++)
265-
{
266-
for (size_t j = 0; j < psi.get_nbasis(); j++)
267-
{
268-
psi(i, j) = *zero_;
269-
}
270-
psi(i, i) = *one_;
271-
}
272-
// }
273-
274-
auto psi_tensor = ct::TensorMap(psi.get_pointer(),
275-
ct::DataTypeToEnum<T>::value,
276-
ct::DeviceTypeToEnum<ct_Device>::value,
277-
ct::TensorShape({psi.get_nbands(), psi.get_nbasis()}))
278-
.slice({0, 0}, {psi.get_nbands(), psi.get_current_nbas()});
279-
280-
auto eigen_tensor = ct::TensorMap(eigenvalue,
281-
ct::DataTypeToEnum<Real>::value,
282-
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
283-
ct::TensorShape({psi.get_nbands()}));
284-
285-
auto prec_tensor = ct::TensorMap(this->precondition_lcao.data(),
286-
ct::DataTypeToEnum<Real>::value,
287-
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
288-
ct::TensorShape({static_cast<int>(this->precondition_lcao.size())}))
289-
.slice({0}, {psi.get_current_nbas()});
290-
291-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
292-
293-
// TODO: Double check tensormap's potential problem
294-
ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
156+
ModuleBase::WARNING_QUIT("HSolverLCAO::solve", "This method is not supported for lcao basis in ABACUS!");
295157
}
296158

297159
ModuleBase::timer::tick("HSolverLCAO", "hamiltSolvePsiK");

0 commit comments

Comments
 (0)