Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
16a26a3
Add dimension parameter for BPCG method
Cstandardlib Jan 6, 2025
d7c66b1
Add utils for hsolver gemm_op
Cstandardlib Jan 6, 2025
bcf492f
Change code to fit new bpcg init interface
Cstandardlib Jan 6, 2025
31d78ff
using gemm instead of einsum in orth_cholesky
Cstandardlib Jan 6, 2025
487b6b2
using gemm instead of einsum in orth_projection
Cstandardlib Jan 6, 2025
d901d00
replace einsum by gemm in orth_projection
Cstandardlib Jan 6, 2025
e2dc4a1
replace einsum by gemm in rotate_wf
Cstandardlib Jan 6, 2025
3a70e1b
replace einsum by gemm in diag_hsub
Cstandardlib Jan 6, 2025
14c3693
Merge branch 'develop' into fix/bpcg-gemm-insteadof-einsum
Cstandardlib Jan 6, 2025
fe19e07
fix wrong dimension of gemm_op in rotate_wf
Cstandardlib Jan 6, 2025
87f0d36
Revert change of einsum in rotate_wf:
Cstandardlib Jan 7, 2025
830188a
Merge branch 'develop' into fix/bpcg-gemm-insteadof-einsum
Cstandardlib Jan 7, 2025
a3350b7
Revert gemm substitute for einsum
Cstandardlib Jan 7, 2025
ab08f46
Revert last commit, substitute gemm for einsum
Cstandardlib Jan 7, 2025
1813a56
Update 102_PW_BPCG totalstressref reference value
Cstandardlib Jan 7, 2025
570131e
Merge branch 'develop' into fix/bpcg-gemm-insteadof-einsum
Cstandardlib Jan 8, 2025
d8045b8
Merge branch 'develop' into fix/bpcg-gemm-insteadof-einsum
Cstandardlib Jan 9, 2025
15d28ed
Merge branch 'develop' into fix/bpcg-gemm-insteadof-einsum
Cstandardlib Jan 9, 2025
0d87baa
Update 102_PW_BPCG totalstressref reference value
Cstandardlib Jan 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 93 additions & 6 deletions source/module_hsolver/diago_bpcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ DiagoBPCG<T, Device>::DiagoBPCG(const Real* precondition_in)
this->device_type = ct::DeviceTypeToEnum<Device>::value;

this->h_prec = std::move(ct::TensorMap((void *) precondition_in, r_type, device_type, {this->n_basis}));

this->one = &one_;
this->zero = &zero_;
this->neg_one = &neg_one_;
}

template<typename T, typename Device>
Expand All @@ -30,11 +34,11 @@ DiagoBPCG<T, Device>::~DiagoBPCG() {
}

template<typename T, typename Device>
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis, const int ndim) {
// Specify the problem size n_basis, n_band, while lda is n_basis
this->n_band = nband;
this->n_basis = nbasis;

this->n_dim = ndim;

// All column major tensors

Expand Down Expand Up @@ -93,7 +97,23 @@ void DiagoBPCG<T, Device>::orth_cholesky(
// hsub_out = psi_out * transc(psi_out)
ct::EinsumOption option(
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
// hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);

// gemm: hsub_out(n_band x n_band) = psi_out^H(n_band x n_dim) * psi_out(n_dim x n_band); op 'C' is the conjugate transpose, leading dimension is n_basis
gemm_op()(this->ctx,
'C',
'N',
this->n_band, //m
this->n_band, //n
this->n_dim, //k
this->one, //1.0
psi_out.data<T>(),
this->n_basis, //lda
psi_out.data<T>(),
this->n_basis, //ldb
this->zero, //0.0
hsub_out.data<T>(),
this->n_band); //ldc

// set hsub matrix to lower format;
ct::kernels::set_matrix<T, ct_Device>()(
Expand Down Expand Up @@ -145,12 +165,45 @@ void DiagoBPCG<T, Device>::orth_projection(
{
ct::EinsumOption option(
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in);
hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
// hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);

// this->orth_projection(this->psi, this->hsub, this->grad);
// gemm: hsub_in(n_band x n_band) = psi_in^H(n_band x n_dim) * grad_out(n_dim x n_band); op 'C' is the conjugate transpose, leading dimension is n_basis
gemm_op()(this->ctx,
'C',
'N',
this->n_band, //m
this->n_band, //n
this->n_dim, //k
this->one, //1.0
psi_in.data<T>(),
this->n_basis, //lda
grad_out.data<T>(),
this->n_basis, //ldb
this->zero, //0.0
hsub_in.data<T>(),
this->n_band); //ldc

// set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
option = ct::EinsumOption(
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
// grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);

// grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
gemm_op()(this->ctx,
'N',
'N',
this->n_dim, //m
this->n_band, //n
this->n_band, //k
this->neg_one, //-1.0
psi_in.data<T>(),
this->n_basis, //lda
hsub_in.data<T>(),
this->n_band, //ldb
this->one, //1.0
grad_out.data<T>(),
this->n_basis); //ldc

return;
}
Expand All @@ -165,6 +218,24 @@ void DiagoBPCG<T, Device>::rotate_wf(
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in);
workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);

// this->rotate_wf(hsub_out, psi_out, workspace_in);
// this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
// gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
// gemm_op()(this->ctx,
// 'N',
// 'N',
// this->n_basis, //m
// this->n_band, //n
// this->n_band, //k
// this->one, //1.0
// psi_out.data<T>(),
// this->n_basis, //lda
// hsub_in.data<T>(),
// this->n_band, //ldb
// this->zero, //0.0
// workspace_in.data<T>(),
// this->n_basis); //ldc

syncmem_complex_op()(psi_out.template data<T>(), workspace_in.template data<T>(), this->n_band * this->n_basis);

return;
Expand Down Expand Up @@ -192,7 +263,23 @@ void DiagoBPCG<T, Device>::diag_hsub(
// it controls the ops to use the corresponding device to calculate results
ct::EinsumOption option(
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
// hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);

// gemm: hsub_out(n_band x n_band) = hpsi_in^H(n_band x n_dim) * psi_in(n_dim x n_band); op 'C' is the conjugate transpose, leading dimension is n_basis
gemm_op()(this->ctx,
'C',
'N',
this->n_band, //m
this->n_band, //n
this->n_dim, //k
this->one, //1.0
hpsi_in.data<T>(),
this->n_basis, //lda
psi_in.data<T>(),
this->n_basis, //ldb
this->zero, //0.0
hsub_out.data<T>(),
this->n_band); //ldc

ct::kernels::lapack_dnevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());

Expand Down
13 changes: 12 additions & 1 deletion source/module_hsolver/diago_bpcg.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ class DiagoBPCG
*
* @param nband The number of bands.
* @param nbasis The number of basis functions. Leading dimension of psi.
* @param ndim The number of valid entries along the leading dimension of psi (used as the contraction length in gemm).
*/
void init_iter(const int nband, const int nbasis);
void init_iter(const int nband, const int nbasis, const int ndim);

using HPsiFunc = std::function<void(T*, T*, const int, const int)>;

Expand All @@ -77,6 +78,8 @@ class DiagoBPCG
int n_band = 0;
/// the number of cols of the input psi
int n_basis = 0;
/// valid dimension of psi
int n_dim = 0;
/// max iter steps for all-band cg loop
int nline = 4;

Expand Down Expand Up @@ -107,6 +110,13 @@ class DiagoBPCG
/// work for some calculations within this class, including rotate_wf call
ct::Tensor work = {};

// These are for hsolver gemm_op use
/// ctx is the device handle passed to gemm_op (Device* ctx = nullptr).
Device * ctx = {};
// Pointers to the constants 1, 0 and -1 used as alpha/beta in gemm
const T *one = nullptr, *zero = nullptr, *neg_one = nullptr;
const T one_ = static_cast<T>(1.0), zero_ = static_cast<T>(0.0), neg_one_ = static_cast<T>(-1.0);

/**
* @brief Update the precondition array.
*
Expand Down Expand Up @@ -332,6 +342,7 @@ class DiagoBPCG

using calc_grad_with_block_op = hsolver::calc_grad_with_block_op<T, Device>;
using line_minimize_with_block_op = hsolver::line_minimize_with_block_op<T, Device>;
using gemm_op = hsolver::gemm_op<T, Device>;

};

Expand Down
3 changes: 2 additions & 1 deletion source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
{
const int nband = psi.get_nbands();
const int nbasis = psi.get_nbasis();
const int ndim = psi.get_current_ngk();
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
Expand All @@ -499,7 +500,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
};
DiagoBPCG<T, Device> bpcg(pre_condition.data());
bpcg.init_iter(nband, nbasis);
bpcg.init_iter(nband, nbasis, ndim);
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band);
}
else if (this->method == "dav_subspace")
Expand Down
3 changes: 2 additions & 1 deletion source/module_hsolver/test/diago_bpcg_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ class DiagoBPCGPrepare
zero_,
hpsi_out, ld_psi);
};
bpcg.init_iter(nband, npw);
const int ndim = psi_local.get_current_ngk();
bpcg.init_iter(nband, npw, ndim);
std::vector<double> ethr_band(nband, 1e-5);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
Expand Down
2 changes: 1 addition & 1 deletion tests/integrate/102_PW_BPCG/result.ref
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
etotref -4869.74705201
etotperatomref -2434.87352600
totalforceref 5.19483000
totalstressref 37241.44843500
totalstressref 37241.45334600
pointgroupref C_1
spacegroupref C_1
nksibzref 8
Expand Down
Loading