From 16a26a330669111e0fe82bec6687ea667f8b1783 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 6 Jan 2025 16:36:16 +0800 Subject: [PATCH 01/14] Add dimension parameter for BPCG method --- source/module_hsolver/diago_bpcg.cpp | 4 ++-- source/module_hsolver/diago_bpcg.h | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 635e3a7943..88cfb0076b 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -30,11 +30,11 @@ DiagoBPCG::~DiagoBPCG() { } template -void DiagoBPCG::init_iter(const int nband, const int nbasis) { +void DiagoBPCG::init_iter(const int nband, const int nbasis, const int ndim) { // Specify the problem size n_basis, n_band, while lda is n_basis this->n_band = nband; this->n_basis = nbasis; - + this->n_dim = ndim; // All column major tensors diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index c57ed5e5ee..dfc51ceba4 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -52,8 +52,9 @@ class DiagoBPCG * * @param nband The number of bands. * @param nbasis The number of basis functions. Leading dimension of psi. + * @param ndim The number of valid dimension of psi. */ - void init_iter(const int nband, const int nbasis); + void init_iter(const int nband, const int nbasis, const int ndim); using HPsiFunc = std::function; @@ -77,6 +78,8 @@ class DiagoBPCG int n_band = 0; /// the number of cols of the input psi int n_basis = 0; + /// valid dimension of psi + int n_dim = 0; /// max iter steps for all-band cg loop int nline = 4; @@ -332,6 +335,7 @@ class DiagoBPCG using calc_grad_with_block_op = hsolver::calc_grad_with_block_op; using line_minimize_with_block_op = hsolver::line_minimize_with_block_op; + using gemm_op = hsolver::gemm_op; }; From d7c66b15273ce59780daa527dda3f435e355ace2 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 6 Jan 2025 16:37:42 +0800 Subject: [PATCH 02/14] Add utils for hsovler gemm_op --- source/module_hsolver/diago_bpcg.cpp | 4 ++++ source/module_hsolver/diago_bpcg.h | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 88cfb0076b..dea6925bbf 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -22,6 +22,10 @@ DiagoBPCG::DiagoBPCG(const Real* precondition_in) this->device_type = ct::DeviceTypeToEnum::value; this->h_prec = std::move(ct::TensorMap((void *) precondition_in, r_type, device_type, {this->n_basis})); + + this->one = &one_; + this->zero = &zero_; + this->neg_one = &neg_one_; } template diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index dfc51ceba4..44ddd9736f 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -110,6 +110,13 @@ class DiagoBPCG /// work for some calculations within this class, including rotate_wf call ct::Tensor work = {}; + // These are for hsolver gemm_op use + /// ctx is nothing but the devices used in gemm_op (Device * ctx = nullptr;), + Device * ctx = {}; + // Pointer to objects of 1 and 0 for gemm + const T *one = nullptr, *zero = nullptr, *neg_one = nullptr; + const T one_ = static_cast(1.0), zero_ = static_cast(0.0), neg_one_ = static_cast(-1.0); + /** * @brief Update the precondition array. * From bcf492f678482d89ed14452a93ea8f5d181106e3 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 6 Jan 2025 16:50:44 +0800 Subject: [PATCH 03/14] Change code to fit new bpcg init interface --- source/module_hsolver/hsolver_pw.cpp | 3 ++- source/module_hsolver/test/diago_bpcg_test.cpp | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index dbfca81061..0eb57e4763 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -483,6 +483,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, { const int nband = psi.get_nbands(); const int nbasis = psi.get_nbasis(); + const int ndim = psi.get_current_ngk(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); @@ -499,7 +500,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ModuleBase::timer::tick("DavSubspace", "hpsi_func"); }; DiagoBPCG bpcg(pre_condition.data()); - bpcg.init_iter(nband, nbasis); + bpcg.init_iter(nband, nbasis, ndim); bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band); } else if (this->method == "dav_subspace") diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index 0ca9ff2444..1448588394 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -151,7 +151,8 @@ class DiagoBPCGPrepare zero_, hpsi_out, ld_psi); }; - bpcg.init_iter(nband, npw); + const int ndim = psi_local.get_current_ngk(); + bpcg.init_iter(nband, npw, ndim); std::vector ethr_band(nband, 1e-5); bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); From 31d78ff3aee198b94d6f36f305cea48533990db0 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:11:28 +0800 Subject: [PATCH 04/14] using gemm instead of einsum in orth_cholesky --- source/module_hsolver/diago_bpcg.cpp | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index dea6925bbf..6ff8079681 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -97,7 +97,23 @@ void DiagoBPCG::orth_cholesky( // hsub_out = psi_out * transc(psi_out) ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); + // hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); + + // gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band) + gemm_op()(this->ctx, + 'C', + 'N', + this->n_band, //m + this->n_band, //n + this->n_dim, //k + this->one, + psi_out.data(), + this->n_basis, //lda + psi_out.data(), + this->n_basis, //ldb + this->zero, + hsub_out.data(), + this->n_band); //ldc // set hsub matrix to lower format; ct::kernels::set_matrix()( From 487b6b25604f76645d3dde782b8b0ff9b2248b5a Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:21:23 +0800 Subject: [PATCH 05/14] using gemm instead of einsum in orth_projection --- source/module_hsolver/diago_bpcg.cpp | 33 ++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 6ff8079681..1599ac0d65 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -167,11 +167,44 @@ void DiagoBPCG::orth_projection( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in); hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); + // this->orth_projection(this->psi, this->hsub, this->grad); + // gemm: hsub_in(n_band x n_band) = grad_out^T(n_band x n_basis) * psi_in(n_basis x n_band) + // gemm_op()(this->ctx, + // 'C', + // 'N', + // this->n_band, //m + // this->n_band, //n + // this->n_dim, //k + // this->one, + // grad_out.data(), + // this->n_basis, //lda + // psi_in.data(), + // this->n_basis, //ldb + // this->zero, + // hsub_in.data(), + // this->n_band); //ldc + // set_matrix_op()('L', hsub_in->data(), this->n_band); option = ct::EinsumOption( /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out); grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); + // grad_out(n_basis x n_band) = psi_in(n_basis x n_band) * hsub_in(n_band x n_band) + // gemm_op()(this->ctx, + // 'N', + // 'N', + // this->n_basis, //m + // this->n_band, //n + // this->n_band, //k + // this->one, + // psi_in.data(), + // this->n_basis, //lda + // hsub_in.data(), + // this->n_band, //ldb + // this->zero, + // grad_out.data(), + // this->n_basis); //ldc + return; } From d901d00f0ad24f798312c9cc41e4a95ffdb7b56f Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 6 Jan 2025 20:16:02 +0800 Subject: [PATCH 06/14] replace einsum by gemm in orth_projection --- source/module_hsolver/diago_bpcg.cpp | 62 +++++++++---------- .../module_hsolver/test/diago_bpcg_test.cpp | 10 ++- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 1599ac0d65..c67af37049 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -165,45 +165,45 @@ void DiagoBPCG::orth_projection( { ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in); - hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); + // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); // this->orth_projection(this->psi, this->hsub, this->grad); - // gemm: hsub_in(n_band x n_band) = grad_out^T(n_band x n_basis) * psi_in(n_basis x n_band) - // gemm_op()(this->ctx, - // 'C', - // 'N', - // this->n_band, //m - // this->n_band, //n - // this->n_dim, //k - // this->one, - // grad_out.data(), - // this->n_basis, //lda - // psi_in.data(), - // this->n_basis, //ldb - // this->zero, - // hsub_in.data(), - // this->n_band); //ldc + // gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band) + gemm_op()(this->ctx, + 'C', + 'N', + this->n_band, //m + this->n_band, //n + this->n_dim, //k + this->one, + psi_in.data(), + this->n_basis, //lda + grad_out.data(), + this->n_basis, //ldb + this->zero, + hsub_in.data(), + this->n_band); //ldc // set_matrix_op()('L', hsub_in->data(), this->n_band); option = ct::EinsumOption( /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out); - grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); + // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); // grad_out(n_basis x n_band) = psi_in(n_basis x n_band) * hsub_in(n_band x n_band) - // gemm_op()(this->ctx, - // 'N', - // 'N', - // this->n_basis, //m - // this->n_band, //n - // this->n_band, //k - // this->one, - // psi_in.data(), - // this->n_basis, //lda - // hsub_in.data(), - // this->n_band, //ldb - // this->zero, - // grad_out.data(), - // this->n_basis); //ldc + gemm_op()(this->ctx, + 'N', + 'N', + this->n_dim, //m + this->n_band, //n + this->n_band, //k + this->neg_one, + psi_in.data(), + this->n_basis, //lda + hsub_in.data(), + this->n_band, //ldb + this->one, + grad_out.data(), + this->n_basis); //ldc return; } diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index 1448588394..0ebe40e08f 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -249,7 +249,15 @@ TEST(DiagoBPCGTest, readH) // read Hamilt matrix from file data-H std::vector> hm; std::ifstream ifs; - ifs.open("H-KPoints-Si64.dat"); + std::string filename = "H-KPoints-Si64.dat"; //"H-small-6x6.dat"; + std::cout << "Reading file " << filename << std::endl; + ifs.open(filename); + // open file and check status + if (!ifs.is_open()) + { + std::cout << "Error opening file " << filename << std::endl; + exit(1); + } DIAGOTEST::readh(ifs, hm); ifs.close(); int dim = DIAGOTEST::npw; From e2dc4a17ddf1c68dc760ab129349d0c7d572a380 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 6 Jan 2025 20:26:55 +0800 Subject: [PATCH 07/14] replace einsum by gemm in rotate_wf --- source/module_hsolver/diago_bpcg.cpp | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index c67af37049..67f7fd5f6e 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -216,7 +216,25 @@ void DiagoBPCG::rotate_wf( { ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in); - workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option); + // workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option); + + // this->rotate_wf(hsub_out, psi_out, workspace_in); + // this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub); + // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band) + gemm_op()(this->ctx, + 'N', + 'N', + this->n_dim, //m + this->n_band, //n + this->n_band, //k + this->one, + psi_out.data(), + this->n_basis, //lda + hsub_in.data(), + this->n_band, //ldb + this->zero, + workspace_in.data(), + this->n_basis); //ldc syncmem_complex_op()(psi_out.template data(), workspace_in.template data(), this->n_band * this->n_basis); From 3a70e1b28cd3cb5a006ef24c75ba9b47321081a4 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 6 Jan 2025 20:58:12 +0800 Subject: [PATCH 08/14] replace einsum by gemm in diag_hsub --- source/module_hsolver/diago_bpcg.cpp | 48 ++++++++++++++++++---------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 67f7fd5f6e..9d1c50493e 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -106,12 +106,12 @@ void DiagoBPCG::orth_cholesky( this->n_band, //m this->n_band, //n this->n_dim, //k - this->one, + this->one, //1.0 psi_out.data(), this->n_basis, //lda psi_out.data(), this->n_basis, //ldb - this->zero, + this->zero, //0.0 hsub_out.data(), this->n_band); //ldc @@ -175,12 +175,12 @@ void DiagoBPCG::orth_projection( this->n_band, //m this->n_band, //n this->n_dim, //k - this->one, + this->one, //1.0 psi_in.data(), this->n_basis, //lda grad_out.data(), this->n_basis, //ldb - this->zero, + this->zero, //0.0 hsub_in.data(), this->n_band); //ldc @@ -189,21 +189,21 @@ void DiagoBPCG::orth_projection( /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out); // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); - // grad_out(n_basis x n_band) = psi_in(n_basis x n_band) * hsub_in(n_band x n_band) + // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band) gemm_op()(this->ctx, 'N', 'N', - this->n_dim, //m + this->n_dim, //m this->n_band, //n this->n_band, //k - this->neg_one, + this->neg_one, //-1.0 psi_in.data(), this->n_basis, //lda hsub_in.data(), this->n_band, //ldb - this->one, + this->one, //1.0 grad_out.data(), - this->n_basis); //ldc + this->n_basis); //ldc return; } @@ -224,15 +224,15 @@ void DiagoBPCG::rotate_wf( gemm_op()(this->ctx, 'N', 'N', - this->n_dim, //m - this->n_band, //n + this->n_dim, //m + this->n_band, //n this->n_band, //k - this->one, + this->one, //1.0 psi_out.data(), - this->n_basis, //lda + this->n_basis, //lda hsub_in.data(), - this->n_band, //ldb - this->zero, + this->n_band, //ldb + this->zero, //0.0 workspace_in.data(), this->n_basis); //ldc @@ -263,7 +263,23 @@ void DiagoBPCG::diag_hsub( // it controls the ops to use the corresponding device to calculate results ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); + // hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); + + // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band) + gemm_op()(this->ctx, + 'C', + 'N', + this->n_band, //m + this->n_band, //n + this->n_dim, //k + this->one, //1.0 + hpsi_in.data(), + this->n_basis, //lda + psi_in.data(), + this->n_basis, //ldb + this->zero, //0.0 + hsub_out.data(), + this->n_band); //ldc ct::kernels::lapack_dnevd()('V', 'U', hsub_out.data(), this->n_band, eigenvalue_out.data()); From fe19e073e138a6b5555c8d876e184986ca1f1b55 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 6 Jan 2025 22:07:29 +0800 Subject: [PATCH 09/14] fix wrong dimension of gemm_op in rotate_wf --- source/module_hsolver/diago_bpcg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 9d1c50493e..d389ab85c9 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -224,7 +224,7 @@ void DiagoBPCG::rotate_wf( gemm_op()(this->ctx, 'N', 'N', - this->n_dim, //m + this->n_basis, //m this->n_band, //n this->n_band, //k this->one, //1.0 From 87f0d3612cbbaab7c852bc709175a7bd4a4b67ca Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:21:03 +0800 Subject: [PATCH 10/14] Revert change of einsum in rotate_wf: --- source/module_hsolver/diago_bpcg.cpp | 30 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index d389ab85c9..7830af2f67 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -216,25 +216,25 @@ void DiagoBPCG::rotate_wf( { ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in); - // workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option); + workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option); // this->rotate_wf(hsub_out, psi_out, workspace_in); // this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub); // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band) - gemm_op()(this->ctx, - 'N', - 'N', - this->n_basis, //m - this->n_band, //n - this->n_band, //k - this->one, //1.0 - psi_out.data(), - this->n_basis, //lda - hsub_in.data(), - this->n_band, //ldb - this->zero, //0.0 - workspace_in.data(), - this->n_basis); //ldc + // gemm_op()(this->ctx, + // 'N', + // 'N', + // this->n_basis, //m + // this->n_band, //n + // this->n_band, //k + // this->one, //1.0 + // psi_out.data(), + // this->n_basis, //lda + // hsub_in.data(), + // this->n_band, //ldb + // this->zero, //0.0 + // workspace_in.data(), + // this->n_basis); //ldc syncmem_complex_op()(psi_out.template data(), workspace_in.template data(), this->n_band * this->n_basis); From a3350b782c91b6a169c60a5237c3a3124bbba95d Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Tue, 7 Jan 2025 12:22:04 +0800 Subject: [PATCH 11/14] Revert gemm substitute for einsum --- source/module_hsolver/diago_bpcg.cpp | 120 +++++++++++++-------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 7830af2f67..8838f93391 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -97,23 +97,23 @@ void DiagoBPCG::orth_cholesky( // hsub_out = psi_out * transc(psi_out) ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - // hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); + hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); // gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band) - gemm_op()(this->ctx, - 'C', - 'N', - this->n_band, //m - this->n_band, //n - this->n_dim, //k - this->one, //1.0 - psi_out.data(), - this->n_basis, //lda - psi_out.data(), - this->n_basis, //ldb - this->zero, //0.0 - hsub_out.data(), - this->n_band); //ldc + // gemm_op()(this->ctx, + // 'C', + // 'N', + // this->n_band, //m + // this->n_band, //n + // this->n_dim, //k + // this->one, //1.0 + // psi_out.data(), + // this->n_basis, //lda + // psi_out.data(), + // this->n_basis, //ldb + // this->zero, //0.0 + // hsub_out.data(), + // this->n_band); //ldc // set hsub matrix to lower format; ct::kernels::set_matrix()( @@ -165,45 +165,45 @@ void DiagoBPCG::orth_projection( { ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in); - // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); + hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); // this->orth_projection(this->psi, this->hsub, this->grad); // gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band) - gemm_op()(this->ctx, - 'C', - 'N', - this->n_band, //m - this->n_band, //n - this->n_dim, //k - this->one, //1.0 - psi_in.data(), - this->n_basis, //lda - grad_out.data(), - this->n_basis, //ldb - this->zero, //0.0 - hsub_in.data(), - this->n_band); //ldc + // gemm_op()(this->ctx, + // 'C', + // 'N', + // this->n_band, //m + // this->n_band, //n + // this->n_dim, //k + // this->one, //1.0 + // psi_in.data(), + // this->n_basis, //lda + // grad_out.data(), + // this->n_basis, //ldb + // this->zero, //0.0 + // hsub_in.data(), + // this->n_band); //ldc // set_matrix_op()('L', hsub_in->data(), this->n_band); option = ct::EinsumOption( /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out); - // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); + grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band) - gemm_op()(this->ctx, - 'N', - 'N', - this->n_dim, //m - this->n_band, //n - this->n_band, //k - this->neg_one, //-1.0 - psi_in.data(), - this->n_basis, //lda - hsub_in.data(), - this->n_band, //ldb - this->one, //1.0 - grad_out.data(), - this->n_basis); //ldc + // gemm_op()(this->ctx, + // 'N', + // 'N', + // this->n_dim, //m + // this->n_band, //n + // this->n_band, //k + // this->neg_one, //-1.0 + // psi_in.data(), + // this->n_basis, //lda + // hsub_in.data(), + // this->n_band, //ldb + // this->one, //1.0 + // grad_out.data(), + // this->n_basis); //ldc return; } @@ -263,23 +263,23 @@ void DiagoBPCG::diag_hsub( // it controls the ops to use the corresponding device to calculate results ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - // hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); + hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band) - gemm_op()(this->ctx, - 'C', - 'N', - this->n_band, //m - this->n_band, //n - this->n_dim, //k - this->one, //1.0 - hpsi_in.data(), - this->n_basis, //lda - psi_in.data(), - this->n_basis, //ldb - this->zero, //0.0 - hsub_out.data(), - this->n_band); //ldc + // gemm_op()(this->ctx, + // 'C', + // 'N', + // this->n_band, //m + // this->n_band, //n + // this->n_dim, //k + // this->one, //1.0 + // hpsi_in.data(), + // this->n_basis, //lda + // psi_in.data(), + // this->n_basis, //ldb + // this->zero, //0.0 + // hsub_out.data(), + // this->n_band); //ldc ct::kernels::lapack_dnevd()('V', 'U', hsub_out.data(), this->n_band, eigenvalue_out.data()); From ab08f460e807cca4cbc4715d6b9209382a2d71df Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Tue, 7 Jan 2025 21:26:39 +0800 Subject: [PATCH 12/14] Revert last commit, substitute gemm for einsum This reverts commit a3350b782c91b6a169c60a5237c3a3124bbba95d. --- source/module_hsolver/diago_bpcg.cpp | 120 +++++++++++++-------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 8838f93391..7830af2f67 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -97,23 +97,23 @@ void DiagoBPCG::orth_cholesky( // hsub_out = psi_out * transc(psi_out) ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); + // hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); // gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band) - // gemm_op()(this->ctx, - // 'C', - // 'N', - // this->n_band, //m - // this->n_band, //n - // this->n_dim, //k - // this->one, //1.0 - // psi_out.data(), - // this->n_basis, //lda - // psi_out.data(), - // this->n_basis, //ldb - // this->zero, //0.0 - // hsub_out.data(), - // this->n_band); //ldc + gemm_op()(this->ctx, + 'C', + 'N', + this->n_band, //m + this->n_band, //n + this->n_dim, //k + this->one, //1.0 + psi_out.data(), + this->n_basis, //lda + psi_out.data(), + this->n_basis, //ldb + this->zero, //0.0 + hsub_out.data(), + this->n_band); //ldc // set hsub matrix to lower format; ct::kernels::set_matrix()( @@ -165,45 +165,45 @@ void DiagoBPCG::orth_projection( { ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in); - hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); + // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); // this->orth_projection(this->psi, this->hsub, this->grad); // gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band) - // gemm_op()(this->ctx, - // 'C', - // 'N', - // this->n_band, //m - // this->n_band, //n - // this->n_dim, //k - // this->one, //1.0 - // psi_in.data(), - // this->n_basis, //lda - // grad_out.data(), - // this->n_basis, //ldb - // this->zero, //0.0 - // hsub_in.data(), - // this->n_band); //ldc + gemm_op()(this->ctx, + 'C', + 'N', + this->n_band, //m + this->n_band, //n + this->n_dim, //k + this->one, //1.0 + psi_in.data(), + this->n_basis, //lda + grad_out.data(), + this->n_basis, //ldb + this->zero, //0.0 + hsub_in.data(), + this->n_band); //ldc // set_matrix_op()('L', hsub_in->data(), this->n_band); option = ct::EinsumOption( /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out); - grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); + // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band) - // gemm_op()(this->ctx, - // 'N', - // 'N', - // this->n_dim, //m - // this->n_band, //n - // this->n_band, //k - // this->neg_one, //-1.0 - // psi_in.data(), - // this->n_basis, //lda - // hsub_in.data(), - // this->n_band, //ldb - // this->one, //1.0 - // grad_out.data(), - // this->n_basis); //ldc + gemm_op()(this->ctx, + 'N', + 'N', + this->n_dim, //m + this->n_band, //n + this->n_band, //k + this->neg_one, //-1.0 + psi_in.data(), + this->n_basis, //lda + hsub_in.data(), + this->n_band, //ldb + this->one, //1.0 + grad_out.data(), + this->n_basis); //ldc return; } @@ -263,23 +263,23 @@ void DiagoBPCG::diag_hsub( // it controls the ops to use the corresponding device to calculate results ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); + // hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band) - // gemm_op()(this->ctx, - // 'C', - // 'N', - // this->n_band, //m - // this->n_band, //n - // this->n_dim, //k - // this->one, //1.0 - // hpsi_in.data(), - // this->n_basis, //lda - // psi_in.data(), - // this->n_basis, //ldb - // this->zero, //0.0 - // hsub_out.data(), - // this->n_band); //ldc + gemm_op()(this->ctx, + 'C', + 'N', + this->n_band, //m + this->n_band, //n + this->n_dim, //k + this->one, //1.0 + hpsi_in.data(), + this->n_basis, //lda + psi_in.data(), + this->n_basis, //ldb + this->zero, //0.0 + hsub_out.data(), + this->n_band); //ldc ct::kernels::lapack_dnevd()('V', 'U', hsub_out.data(), this->n_band, eigenvalue_out.data()); From 1813a56e362cee87f672e99def6fb8f6f024d504 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Wed, 8 Jan 2025 01:20:08 +0800 Subject: [PATCH 13/14] Update 102_PW_BPCG totalstressref reference value --- tests/integrate/102_PW_BPCG/result.ref | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integrate/102_PW_BPCG/result.ref b/tests/integrate/102_PW_BPCG/result.ref index b94b93bfc2..470d983117 100644 --- a/tests/integrate/102_PW_BPCG/result.ref +++ b/tests/integrate/102_PW_BPCG/result.ref @@ -1,7 +1,7 @@ etotref -4869.7470519163989593 etotperatomref -2434.8735259582 totalforceref 5.198676 -totalstressref 37241.091710 +totalstressref 37241.09727600 pointgroupref C_1 spacegroupref C_1 nksibzref 8 From 0d87baa7d803cbb45994f838c67051c363a98aa3 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Thu, 9 Jan 2025 23:02:45 +0800 Subject: [PATCH 14/14] Update 102_PW_BPCG totalstressref reference value --- tests/integrate/102_PW_BPCG/result.ref | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integrate/102_PW_BPCG/result.ref b/tests/integrate/102_PW_BPCG/result.ref index e702dfbb6b..2972395a15 100644 --- a/tests/integrate/102_PW_BPCG/result.ref +++ b/tests/integrate/102_PW_BPCG/result.ref @@ -1,7 +1,7 @@ etotref -4869.74705201 etotperatomref -2434.87352600 totalforceref 5.19483000 -totalstressref 37241.44843500 +totalstressref 37241.45334600 pointgroupref C_1 spacegroupref C_1 nksibzref 8