Skip to content

Commit 9b6651c

Browse files
committed
Update subspace_func where needed
1 parent e6a55ed commit 9b6651c

File tree

4 files changed

+27
-19
lines changed

4 files changed

+27
-19
lines changed

source/source_hsolver/diago_cg.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ DiagoCG<T, Device>::~DiagoCG()
5454
}
5555

5656
template <typename T, typename Device>
57-
void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in,
57+
void DiagoCG<T, Device>::diag_once(const ct::Tensor& prec_in,
5858
ct::Tensor& psi,
5959
ct::Tensor& eigen,
6060
const std::vector<double>& ethr_band)
@@ -592,16 +592,29 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
592592
ct::Tensor psi_temp = psi.slice({0, 0}, {int(psi.shape().dim_size(0)), int(prec.shape().dim_size(0))});
593593
do
594594
{
595-
if (need_subspace_ || ntry > 0)
595+
// subspace diagonalization to get a better starting guess
596+
// for cg diagonalization, restart from current psi approximation
597+
// Note: if not the first try, then psi is already S-orthogonalized by CG iterations!
598+
// Otherwise, if the first try, then psi is not assumed to be S-orthogonalized
599+
if (ntry > 0)
596600
{
597601
ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp);
598-
this->subspace_func_(psi_temp, psi_map);
602+
const bool assume_S_orthogonal = true;
603+
this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal);
599604
psi_temp.sync(psi_map);
600605
}
606+
else if (need_subspace_)
607+
{
608+
ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp);
609+
const bool assume_S_orthogonal = false;
610+
this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal);
611+
psi_temp.sync(psi_map);
612+
}
613+
601614

602615
++ntry;
603616
avg_iter_ += 1.0;
604-
this->diag_mock(prec, psi_temp, eigen, ethr_band);
617+
this->diag_once(prec, psi_temp, eigen, ethr_band);
605618
} while (this->test_exit_cond(ntry, this->notconv_));
606619

607620
if (this->notconv_ > std::max(5, this->n_band_ / 4))

source/source_hsolver/diago_cg.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class DiagoCG final
2222
using ct_Device = typename ct::PsiToContainer<Device>::type;
2323
public:
2424
using Func = std::function<void(const ct::Tensor&, ct::Tensor&)>;
25+
using SubspaceFunc = std::function<void(const ct::Tensor&, ct::Tensor&, const bool)>;
2526
// Constructor need:
2627
// 1. temporary mock of Hamiltonian "Hamilt_PW"
2728
// 2. precondition pointer should point to place of precondition array.
@@ -30,7 +31,7 @@ class DiagoCG final
3031
const std::string& basis_type,
3132
const std::string& calculation,
3233
const bool& need_subspace,
33-
const Func& subspace_func,
34+
const SubspaceFunc& subspace_func,
3435
const Real& pw_diag_thr,
3536
const int& pw_diag_nmax,
3637
const int& nproc_in_pool);
@@ -72,11 +73,11 @@ class DiagoCG final
7273

7374
bool need_subspace_ = false;
7475
/// A function object that performs the hPsi calculation.
75-
std::function<void(const ct::Tensor&, ct::Tensor&)> hpsi_func_ = nullptr;
76+
Func hpsi_func_ = nullptr;
7677
/// A function object that performs the sPsi calculation.
77-
std::function<void(const ct::Tensor&, ct::Tensor&)> spsi_func_ = nullptr;
78+
Func spsi_func_ = nullptr;
7879
/// A function object that performs the subspace calculation.
79-
std::function<void(const ct::Tensor&, ct::Tensor&)> subspace_func_ = nullptr;
80+
SubspaceFunc subspace_func_ = nullptr;
8081

8182
void calc_grad(
8283
const ct::Tensor& prec,
@@ -119,7 +120,7 @@ class DiagoCG final
119120
void schmit_orth(const int& m, const ct::Tensor& psi, const ct::Tensor& sphi, ct::Tensor& phi_m);
120121

121122
// used in diag() for template replace Hamilt with Hamilt_PW
122-
void diag_mock(const ct::Tensor& prec,
123+
void diag_once(const ct::Tensor& prec,
123124
ct::Tensor& psi,
124125
ct::Tensor& eigen,
125126
const std::vector<double>& ethr_band);

source/source_hsolver/diago_dav_subspace.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "diago_dav_subspace.h"
22

33
#include "diago_iter_assist.h"
4-
#include "source_base/memory.h"
4+
55
#include "source_base/module_device/device.h"
66
#include "source_base/timer.h"
77
#include "source_hsolver/kernels/hegvd_op.h"
@@ -26,12 +26,11 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
2626
const int& david_ndim_in,
2727
const double& diag_thr_in,
2828
const int& diag_nmax_in,
29-
const bool& need_subspace_in,
3029
const diag_comm_info& diag_comm_in,
3130
const int diag_subspace_in,
3231
const int diago_subspace_bs_in)
3332
: precondition(precondition_in), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in * david_ndim_in),
34-
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), is_subspace(need_subspace_in), diag_comm(diag_comm_in),
33+
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), diag_comm(diag_comm_in),
3534
diag_subspace(diag_subspace_in), diago_subspace_bs(diago_subspace_bs_in)
3635
{
3736
this->device = base_device::get_device_type<Device>(this->ctx);
@@ -76,7 +75,7 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
7675
{
7776
resmem_real_op()(this->d_precondition, nbasis_in);
7877
// syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition.data(), nbasis_in);
79-
base_device::memory::resize_memory_op<T, Device>()(this->d_scc, this->nbase_x * this->nbase_x);
78+
resmem_complex_op()(this->d_scc, this->nbase_x * this->nbase_x);
8079
resmem_real_op()(this->d_eigenvalue, this->nbase_x);
8180
}
8281
#endif
@@ -370,7 +369,6 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
370369
{
371370
Real* psi_norm = nullptr;
372371
resmem_real_op()(psi_norm, notconv);
373-
using setmem_real_op = base_device::memory::set_memory_op<Real, Device>;
374372
setmem_real_op()(psi_norm, 0.0, notconv);
375373

376374
normalize_op<T, Device>()(this->dim,
@@ -541,7 +539,7 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
541539
#if defined(__CUDA) || defined(__ROCM)
542540
if (this->diag_comm.rank == 0)
543541
{
544-
base_device::memory::synchronize_memory_op<T, Device, Device>()(this->d_scc, scc, nbase * this->nbase_x);
542+
syncmem_complex_op()(this->d_scc, scc, nbase * this->nbase_x);
545543
hegvd_op<T, Device>()(this->ctx, nbase, this->nbase_x, this->hcc, this->d_scc, this->d_eigenvalue, this->vcc);
546544
syncmem_var_d2h_op()((*eigenvalue_iter).data(), this->d_eigenvalue, this->nbase_x);
547545
}

source/source_hsolver/diago_dav_subspace.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class Diago_DavSubspace
3030
const int& david_ndim_in,
3131
const double& diag_thr_in,
3232
const int& diag_nmax_in,
33-
const bool& need_subspace_in,
3433
const diag_comm_info& diag_comm_in,
3534
const int diago_dav_method_in,
3635
const int block_size_in);
@@ -58,9 +57,6 @@ class Diago_DavSubspace
5857
/// maximal iteration number
5958
const int iter_nmax;
6059

61-
/// is diagH_subspace needed?
62-
const bool is_subspace;
63-
6460
/// the first dimension of the matrix to be diagonalized
6561
const int n_band = 0;
6662

0 commit comments

Comments
 (0)