Skip to content

Commit b67a2d3

Browse files
committed
Add n_dim var in bpcg class to support different leading dimension vs matrix dim
1 parent cc3ed9b commit b67a2d3

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {
3535
// Specify the problem size n_basis, n_band, while lda is n_basis
3636
this->n_band = psi_in.get_nbands();
3737
this->n_basis = psi_in.get_nbasis();
38+
this->n_dim = psi_in.get_current_nbas();
3839

3940
// All column major tensors
4041

@@ -93,24 +94,24 @@ void DiagoBPCG<T, Device>::orth_cholesky(
9394
ct::Tensor& hsub_out)
9495
{
9596
// hsub_out = psi_out * transc(psi_out)
96-
ct::EinsumOption option(
97-
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
98-
hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
97+
// ct::EinsumOption option(
98+
// /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
99+
// hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
99100
// use gemm instead of einsum so that the leading dimension (n_basis) may differ from the actual column size (n_dim)
100-
// gemm_op<T, Device>()(this->ctx,
101-
// 'N',
102-
// 'N',
103-
// this->dim,
104-
// notconv,
105-
// nbase,
106-
// this->one,
107-
// hphi,
108-
// this->dim,
109-
// vcc,
110-
// this->nbase_x,
111-
// this->zero,
112-
// psi_iter + (nbase) * this->dim,
113-
// this->dim);
101+
gemm_op<T, Device>()(this->ctx,
102+
'N',
103+
'C',
104+
this->n_band,
105+
this->n_band,
106+
this->n_dim,
107+
this->one,
108+
psi_out.data<T>(),
109+
this->n_basis,
110+
psi_out.data<T>(),
111+
this->n_basis,
112+
this->zero,
113+
hsub_out.data<T>(),
114+
this->n_band);
114115

115116
// set hsub matrix to lower format;
116117
ct::kernels::set_matrix<T, ct_Device>()(

source/module_hsolver/diago_bpcg.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ class DiagoBPCG
7272
Device * ctx = {};
7373
/// the number of rows of the input psi
7474
int n_band = 0;
75-
/// the number of cols of the input psi
75+
/// the number of cols of the input psi, leading dimension
7676
int n_basis = 0;
77+
/// the current (actually used) column size of the input psi; may differ from the leading dimension n_basis
78+
int n_dim = 0;
7779
/// max iter steps for all-band cg loop
7880
int nline = 4;
7981
/// cg convergence thr

0 commit comments

Comments
 (0)