Skip to content

Commit 51e2a78

Browse files
authored
Refactor: remove dependence of Tensor for cg interface (#7000)
* Refactor cg interface from Tensor to T * * Remove redundant code * Remove redundant code
1 parent 7035e85 commit 51e2a78

8 files changed

Lines changed: 223 additions & 228 deletions

File tree

source/source_hsolver/diago_cg.cpp

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ void DiagoCG<T, Device>::diag_once(const ct::Tensor& prec_in,
122122
{
123123
phi_m.sync(psi[m]);
124124
// copy psi_in into internal psi, m=0 has been done in Constructor
125-
this->spsi_func_(phi_m, sphi); // sphi = S|psi(m)>
125+
this->spsi_func_(phi_m.data<T>(), sphi.data<T>(), this->n_basis_, 1); // sphi = S|psi(m)>
126126
this->schmit_orth(m, psi, sphi, phi_m);
127-
this->spsi_func_(phi_m, sphi); // sphi = S|psi(m)>
128-
this->hpsi_func_(phi_m, hphi); // hphi = H|psi(m)>
127+
this->spsi_func_(phi_m.data<T>(), sphi.data<T>(), this->n_basis_, 1); // sphi = S|psi(m)>
128+
this->hpsi_func_(phi_m.data<T>(), hphi.data<T>(), this->n_basis_, 1); // hphi = H|psi(m)>
129129

130130
eigen_pack[m] = dot_real_op()(this->n_basis_, phi_m.data<T>(), hphi.data<T>());
131131

@@ -150,8 +150,8 @@ void DiagoCG<T, Device>::diag_once(const ct::Tensor& prec_in,
150150
g0,
151151
cg); // Tensor&
152152

153-
this->hpsi_func_(cg, pphi);
154-
this->spsi_func_(cg, scg);
153+
this->hpsi_func_(cg.data<T>(), pphi.data<T>(), this->n_basis_, 1);
154+
this->spsi_func_(cg.data<T>(), scg.data<T>(), this->n_basis_, 1);
155155

156156
converged = this->update_psi(pphi,
157157
cg,
@@ -264,7 +264,7 @@ void DiagoCG<T, Device>::orth_grad(const ct::Tensor& psi,
264264
ct::Tensor& scg,
265265
ct::Tensor& lagrange)
266266
{
267-
this->spsi_func_(grad, scg); // scg = S|grad>
267+
this->spsi_func_(grad.data<T>(), scg.data<T>(), this->n_basis_, 1); // scg = S|grad>
268268
ModuleBase::gemv_op<T, Device>()('C',
269269
this->n_basis_,
270270
m,
@@ -576,21 +576,47 @@ bool DiagoCG<T, Device>::test_exit_cond(const int& ntry, const int& notconv) con
576576
}
577577

578578
template <typename T, typename Device>
579-
double DiagoCG<T, Device>::diag(const Func& hpsi_func,
580-
const Func& spsi_func,
581-
ct::Tensor& psi,
582-
ct::Tensor& eigen,
583-
const std::vector<double>& ethr_band,
584-
const ct::Tensor& prec)
579+
double DiagoCG<T, Device>::diag(const HPsiFunc& hpsi_func,
580+
const SPsiFunc& spsi_func,
581+
const int ld_psi,
582+
const int nband,
583+
const int dim,
584+
T* psi_in,
585+
Real* eigenvalue_in,
586+
const std::vector<double>& ethr_band,
587+
const Real* prec)
585588
{
589+
REQUIRES_OK(ld_psi >= dim, "DiagoCG::diag: ld_psi must be >= dim");
590+
REQUIRES_OK(static_cast<int>(ethr_band.size()) >= nband,
591+
"DiagoCG::diag: ethr_band size must be >= nband");
592+
593+
auto psi = ct::TensorMap(psi_in,
594+
ct::DataTypeToEnum<T>::value,
595+
ct::DeviceTypeToEnum<ct_Device>::value,
596+
ct::TensorShape({nband, ld_psi}));
597+
auto eigen = ct::TensorMap(eigenvalue_in,
598+
ct::DataTypeToEnum<Real>::value,
599+
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
600+
ct::TensorShape({nband}));
601+
602+
ct::Tensor prec_tensor;
603+
if (prec != nullptr)
604+
{
605+
prec_tensor = ct::TensorMap(const_cast<Real*>(prec),
606+
ct::DataTypeToEnum<Real>::value,
607+
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
608+
ct::TensorShape({dim}))
609+
.template to_device<ct_Device>();
610+
}
611+
586612
/// record the times of trying iterative diagonalization
587613
int ntry = 0;
588614
this->notconv_ = 0;
589615
hpsi_func_ = hpsi_func;
590616
spsi_func_ = spsi_func;
591617

592618
// create a new slice of psi to do cg diagonalization
593-
ct::Tensor psi_temp = psi.slice({0, 0}, {int(psi.shape().dim_size(0)), int(prec.shape().dim_size(0))});
619+
ct::Tensor psi_temp = psi.slice({0, 0}, {nband, dim});
594620
do
595621
{
596622
// subspace diagonalization to get a better starting guess
@@ -601,21 +627,29 @@ double DiagoCG<T, Device>::diag(const Func& hpsi_func,
601627
{
602628
ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp);
603629
const bool assume_S_orthogonal = true;
604-
this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal);
630+
this->subspace_func_(psi_temp.data<T>(),
631+
psi_map.data<T>(),
632+
dim,
633+
nband,
634+
assume_S_orthogonal);
605635
psi_temp.sync(psi_map);
606636
}
607637
else if (need_subspace_)
608638
{
609639
ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp);
610640
const bool assume_S_orthogonal = false;
611-
this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal);
641+
this->subspace_func_(psi_temp.data<T>(),
642+
psi_map.data<T>(),
643+
dim,
644+
nband,
645+
assume_S_orthogonal);
612646
psi_temp.sync(psi_map);
613647
}
614648

615649

616650
++ntry;
617651
avg_iter_ += 1.0;
618-
this->diag_once(prec, psi_temp, eigen, ethr_band);
652+
this->diag_once(prec_tensor, psi_temp, eigen, ethr_band);
619653
} while (this->test_exit_cond(ntry, this->notconv_));
620654

621655
if (this->notconv_ > std::max(5, this->n_band_ / 4))

source/source_hsolver/diago_cg.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ class DiagoCG final
2222
using Real = typename GetTypeReal<T>::type;
2323
using ct_Device = typename ct::PsiToContainer<Device>::type;
2424
public:
25-
using Func = std::function<void(const ct::Tensor&, ct::Tensor&)>;
26-
using SubspaceFunc = std::function<void(const ct::Tensor&, ct::Tensor&, const bool)>;
25+
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
26+
using SPsiFunc = std::function<void(T*, T*, const int, const int)>;
27+
using SubspaceFunc = std::function<void(T*, T*, const int, const int, const bool)>;
2728
// Constructor need:
2829
// 1. temporary mock of Hamiltonian "Hamilt_PW"
2930
// 2. precondition pointer should point to place of precondition array.
@@ -43,12 +44,15 @@ class DiagoCG final
4344
// refactor hpsi_info
4445
// this is the diag() function for CG method
4546
// returns avg_iter
46-
double diag(const Func& hpsi_func,
47-
const Func& spsi_func,
48-
ct::Tensor& psi,
49-
ct::Tensor& eigen,
50-
const std::vector<double>& ethr_band,
51-
const ct::Tensor& prec = {});
47+
double diag(const HPsiFunc& hpsi_func,
48+
const SPsiFunc& spsi_func,
49+
const int ld_psi,
50+
const int nband,
51+
const int dim,
52+
T* psi_in,
53+
Real* eigenvalue_in,
54+
const std::vector<double>& ethr_band,
55+
const Real* prec = nullptr);
5256

5357
private:
5458
Device * ctx_ = {};
@@ -77,9 +81,9 @@ class DiagoCG final
7781

7882
bool need_subspace_ = false;
7983
/// A function object that performs the hPsi calculation.
80-
Func hpsi_func_ = nullptr;
84+
HPsiFunc hpsi_func_ = nullptr;
8185
/// A function object that performs the sPsi calculation.
82-
Func spsi_func_ = nullptr;
86+
SPsiFunc spsi_func_ = nullptr;
8387
/// A function object that performs the subspace calculation.
8488
SubspaceFunc subspace_func_ = nullptr;
8589

source/source_hsolver/hsolver_pw.cpp

Lines changed: 28 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -254,27 +254,15 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
254254
// wrap the subspace_func into a lambda function
255255
// if S_orth is true, then assume psi is S-orthogonal, solve standard eigenproblem
256256
// otherwise, solve generalized eigenproblem
257-
auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) {
258-
// psi_in should be a 2D tensor:
259-
// psi_in.shape() = [nbands, nbasis]
260-
const auto ndim = psi_in.shape().ndim();
261-
REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2");
262-
// Convert a Tensor object to a psi::Psi object
263-
auto psi_in_wrapper = psi::Psi<T, Device>(psi_in.data<T>(),
264-
1,
265-
psi_in.shape().dim_size(0),
266-
psi_in.shape().dim_size(1),
267-
cur_nbasis);
268-
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
269-
1,
270-
psi_out.shape().dim_size(0),
271-
psi_out.shape().dim_size(1),
272-
cur_nbasis);
273-
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
274-
ct::DeviceType::CpuDevice,
275-
ct::TensorShape({psi_in.shape().dim_size(0)}));
276-
277-
DiagoIterAssist<T, Device>::diag_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>());
257+
auto subspace_func = [hm, cur_nbasis](T* psi_in,
258+
T* psi_out,
259+
const int ld_psi,
260+
const int nband,
261+
const bool S_orth) {
262+
auto psi_in_wrapper = psi::Psi<T, Device>(psi_in, 1, nband, ld_psi, cur_nbasis);
263+
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out, 1, nband, ld_psi, cur_nbasis);
264+
std::vector<Real> eigen(nband, 0.0);
265+
DiagoIterAssist<T, Device>::diag_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data());
278266
};
279267
DiagoCG<T, Device> cg(this->basis_type,
280268
this->calculation_type,
@@ -284,70 +272,38 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
284272
this->diag_iter_max,
285273
this->nproc_in_pool);
286274

287-
// wrap the hpsi_func and spsi_func into a lambda function
288-
using ct_Device = typename ct::PsiToContainer<Device>::type;
289-
290-
// wrap the hpsi_func and spsi_func into a lambda function
291-
auto hpsi_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
292-
// psi_in should be a 2D tensor:
293-
// psi_in.shape() = [nbands, nbasis]
294-
const auto ndim = psi_in.shape().ndim();
295-
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
296-
// Convert a Tensor object to a psi::Psi object
297-
auto psi_wrapper = psi::Psi<T, Device>(psi_in.data<T>(),
298-
1,
299-
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
300-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
301-
cur_nbasis);
302-
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
275+
// wrap the hpsi_func and spsi_func into lambda functions
276+
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
277+
auto psi_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);
278+
psi::Range all_bands_range(true, 0, 0, nvec - 1);
303279
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
304-
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
280+
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out);
305281
hm->ops->hPsi(info);
306282
};
307-
auto spsi_func = [this, hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
308-
// psi_in should be a 2D tensor:
309-
// psi_in.shape() = [nbands, nbasis]
310-
const auto ndim = psi_in.shape().ndim();
311-
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
312-
283+
auto spsi_func = [this, hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
313284
if (this->use_uspp)
314285
{
315-
// Convert a Tensor object to a psi::Psi object
316-
hm->sPsi(psi_in.data<T>(),
317-
spsi_out.data<T>(),
318-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
319-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
320-
ndim == 1 ? 1 : psi_in.shape().dim_size(0));
286+
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
321287
}
322288
else
323289
{
324290
base_device::memory::synchronize_memory_op<T, Device, Device>()(
325-
spsi_out.data<T>(),
326-
psi_in.data<T>(),
327-
static_cast<size_t>((ndim == 1 ? 1 : psi_in.shape().dim_size(0))
328-
* (ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1))));
291+
spsi_out,
292+
psi_in,
293+
static_cast<size_t>(nvec) * static_cast<size_t>(ld_psi));
329294
}
330295
};
331296

332-
auto psi_tensor = ct::TensorMap(psi.get_pointer(),
333-
ct::DataTypeToEnum<T>::value,
334-
ct::DeviceTypeToEnum<ct_Device>::value,
335-
ct::TensorShape({psi.get_nbands(), psi.get_nbasis()}));
336-
337-
auto eigen_tensor = ct::TensorMap(eigenvalue,
338-
ct::DataTypeToEnum<Real>::value,
339-
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
340-
ct::TensorShape({psi.get_nbands()}));
341-
342-
auto prec_tensor = ct::TensorMap(pre_condition.data(),
343-
ct::DataTypeToEnum<Real>::value,
344-
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
345-
ct::TensorShape({static_cast<int>(pre_condition.size())}))
346-
.to_device<ct_Device>()
347-
.slice({0}, {psi.get_current_ngk()});
348-
349297
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
350-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor)
298+
cg.diag(hpsi_func,
299+
spsi_func,
300+
psi.get_nbasis(),
301+
psi.get_nbands(),
302+
psi.get_current_ngk(),
303+
psi.get_pointer(),
304+
eigenvalue,
305+
this->ethr_band,
306+
pre_condition.data())
351307
);
352308
// TODO: Double check tensormap's potential problem
353309
// ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);

source/source_hsolver/test/diago_cg_float_test.cpp

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,19 @@ class DiagoCGPrepare
142142
// New interface of cg method
143143
/**************************************************************/
144144
// warp the subspace_func into a lambda function
145-
auto subspace_func = [ha](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) { /*do nothing*/ };
145+
auto subspace_func = [ha](std::complex<float>* psi_in,
146+
std::complex<float>* psi_out,
147+
const int ld_psi,
148+
const int nband,
149+
const bool S_orth) {
150+
auto psi_in_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nband, ld_psi, true);
151+
auto psi_out_wrapper = psi::Psi<std::complex<float>>(psi_out, 1, nband, ld_psi, true);
152+
std::vector<float> eigen(nband, 0.0f);
153+
hsolver::DiagoIterAssist<std::complex<float>>::diag_subspace(ha,
154+
psi_in_wrapper,
155+
psi_out_wrapper,
156+
eigen.data());
157+
};
146158
hsolver::DiagoCG<std::complex<float>> cg(
147159
PARAM.input.basis_type,
148160
PARAM.input.calculation,
@@ -156,46 +168,33 @@ class DiagoCGPrepare
156168
float start, end;
157169
start = MPI_Wtime();
158170

159-
auto hpsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
160-
const auto ndim = psi_in.shape().ndim();
161-
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
162-
auto psi_wrapper = psi::Psi<std::complex<float>>(
163-
psi_in.data<std::complex<float>>(), 1,
164-
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
165-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
166-
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
171+
auto hpsi_func = [ha](std::complex<float>* psi_in,
172+
std::complex<float>* hpsi_out,
173+
const int ld_psi,
174+
const int nvec) {
175+
auto psi_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nvec, ld_psi, true);
176+
psi::Range all_bands_range(true, 0, 0, nvec - 1);
167177
using hpsi_info = typename hamilt::Operator<std::complex<float>>::hpsi_info;
168-
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<std::complex<float>>());
178+
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out);
169179
ha->ops->hPsi(info);
170180
};
171-
auto spsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
172-
const auto ndim = psi_in.shape().ndim();
173-
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
174-
ha->sPsi(psi_in.data<std::complex<float>>(), spsi_out.data<std::complex<float>>(),
175-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
176-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
177-
ndim == 1 ? 1 : psi_in.shape().dim_size(0));
181+
auto spsi_func = [ha](std::complex<float>* psi_in,
182+
std::complex<float>* spsi_out,
183+
const int ld_psi,
184+
const int nvec) {
185+
ha->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
178186
};
179-
auto psi_tensor = ct::TensorMap(
180-
psi_local.get_pointer(),
181-
ct::DataType::DT_COMPLEX,
182-
ct::DeviceType::CpuDevice,
183-
ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()});
184-
auto eigen_tensor = ct::TensorMap(
185-
en,
186-
ct::DataType::DT_FLOAT,
187-
ct::DeviceType::CpuDevice,
188-
ct::TensorShape({psi_local.get_nbands()}));
189-
auto prec_tensor = ct::TensorMap(
190-
precondition_local,
191-
ct::DataType::DT_FLOAT,
192-
ct::DeviceType::CpuDevice,
193-
ct::TensorShape({static_cast<int>(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()});
194187

195188
std::vector<double> ethr_band(nband, 1e-5);
196-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
197-
// TODO: Double check tensormap's potential problem
198-
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
189+
cg.diag(hpsi_func,
190+
spsi_func,
191+
psi_local.get_nbasis(),
192+
psi_local.get_nbands(),
193+
psi_local.get_current_ngk(),
194+
psi_local.get_pointer(),
195+
en,
196+
ethr_band,
197+
precondition_local);
199198
/**************************************************************/
200199

201200
end = MPI_Wtime();

0 commit comments

Comments
 (0)