diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 63d0956464..3a8e7b02a8 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -35,6 +35,7 @@ void DiagoBPCG::init_iter(const psi::Psi &psi_in) { // Specify the problem size n_basis, n_band, while lda is n_basis this->n_band = psi_in.get_nbands(); this->n_basis = psi_in.get_nbasis(); + this->n_dim = psi_in.get_current_nbas(); // All column major tensors @@ -96,6 +97,21 @@ void DiagoBPCG::orth_cholesky( ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); + // using gemm instead einsum for different leading dimension and nbasis + // gemm_op()(this->ctx, + // 'N', + // 'C', + // this->n_band, + // this->n_band, + // this->n_dim, + // this->one, + // psi_out.data(), + // this->n_basis, + // psi_out.data(), + // this->n_basis, + // this->zero, + // hsub_out.data(), + // this->n_band); // set hsub matrix to lower format; ct::kernels::set_matrix()( diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index c14472de2e..531e66e61e 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -67,15 +67,23 @@ class DiagoBPCG private: + /// ctx is nothing but the devices used in gemm_op (Device * ctx = nullptr;), + Device * ctx = {}; /// the number of rows of the input psi int n_band = 0; - /// the number of cols of the input psi + /// the number of cols of the input psi, leading dimension int n_basis = 0; + /// the real-time column size of the input psi + int n_dim = 0; /// max iter steps for all-band cg loop int nline = 4; /// cg convergence thr Real all_band_cg_thr = 1E-5; + // Pointer to objects of 1 and 0 for gemm + const T *one = nullptr, *zero = nullptr, *neg_one = nullptr; + const T one_ = static_cast(1.0), zero_ = static_cast(0.0), neg_one_ = static_cast(-1.0); + ct::DataType r_type = ct::DataType::DT_INVALID; ct::DataType t_type = ct::DataType::DT_INVALID; ct::DeviceType device_type = ct::DeviceType::UnKnown; @@ -330,6 +338,7 @@ class DiagoBPCG using calc_grad_with_block_op = hsolver::calc_grad_with_block_op; using line_minimize_with_block_op = hsolver::line_minimize_with_block_op; + using gemm_op = hsolver::gemm_op; };