Skip to content

Commit cc3ed9b

Browse files
committed
Add ctx in bpcg to support gemm_op
1 parent 146a509 commit cc3ed9b

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,21 @@ void DiagoBPCG<T, Device>::orth_cholesky(
9696
ct::EinsumOption option(
9797
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
9898
hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
99+
// using gemm instead of einsum to handle a leading dimension that differs from nbasis
100+
// gemm_op<T, Device>()(this->ctx,
101+
// 'N',
102+
// 'N',
103+
// this->dim,
104+
// notconv,
105+
// nbase,
106+
// this->one,
107+
// hphi,
108+
// this->dim,
109+
// vcc,
110+
// this->nbase_x,
111+
// this->zero,
112+
// psi_iter + (nbase) * this->dim,
113+
// this->dim);
99114

100115
// set hsub matrix to lower format;
101116
ct::kernels::set_matrix<T, ct_Device>()(

source/module_hsolver/diago_bpcg.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ class DiagoBPCG
6868

6969

7070
private:
71+
/// ctx is simply the device handle passed to gemm_op; it defaults to a null pointer (Device * ctx = nullptr;).
72+
Device * ctx = {};
7173
/// the number of rows of the input psi
7274
int n_band = 0;
7375
/// the number of cols of the input psi

0 commit comments

Comments
 (0)