Skip to content

Commit 487b6b2

Browse files
committed
using gemm instead of einsum in orth_projection
1 parent 31d78ff commit 487b6b2

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,44 @@ void DiagoBPCG<T, Device>::orth_projection(
167167
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in);
168168
hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
169169

170+
// this->orth_projection(this->psi, this->hsub, this->grad);
171+
// gemm: hsub_in(n_band x n_band) = grad_out^H(n_band x n_basis) * psi_in(n_basis x n_band)   ('C' selects the conjugate transpose)
172+
// gemm_op()(this->ctx,
173+
// 'C',
174+
// 'N',
175+
// this->n_band, //m
176+
// this->n_band, //n
177+
// this->n_dim, //k
178+
// this->one,
179+
// grad_out.data<T>(),
180+
// this->n_basis, //lda
181+
// psi_in.data<T>(),
182+
// this->n_basis, //ldb
183+
// this->zero,
184+
// hsub_in.data<T>(),
185+
// this->n_band); //ldc
186+
170187
// set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
171188
option = ct::EinsumOption(
172189
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
173190
grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
174191

192+
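// grad_out(n_basis x n_band) -= psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
// NOTE(review): the einsum above uses alpha=-1, beta=1 (it accumulates into grad_out); the gemm draft below
// passes this->one / this->zero, which would presumably overwrite grad_out instead -- confirm before enabling.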
193+
// gemm_op()(this->ctx,
194+
// 'N',
195+
// 'N',
196+
// this->n_basis, //m
197+
// this->n_band, //n
198+
// this->n_band, //k
199+
// this->one,
200+
// psi_in.data<T>(),
201+
// this->n_basis, //lda
202+
// hsub_in.data<T>(),
203+
// this->n_band, //ldb
204+
// this->zero,
205+
// grad_out.data<T>(),
206+
// this->n_basis); //ldc
207+
175208
return;
176209
}
177210

0 commit comments

Comments
 (0)