Skip to content

Commit d901d00

Browse files
committed
replace einsum by gemm in orth_projection
1 parent 487b6b2 commit d901d00

File tree

2 files changed

+40
-32
lines changed

2 files changed

+40
-32
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -165,45 +165,45 @@ void DiagoBPCG<T, Device>::orth_projection(
165165
{
166166
ct::EinsumOption option(
167167
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in);
168-
hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
168+
// hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
169169

170170
// this->orth_projection(this->psi, this->hsub, this->grad);
171-
// gemm: hsub_in(n_band x n_band) = grad_out^T(n_band x n_basis) * psi_in(n_basis x n_band)
172-
// gemm_op()(this->ctx,
173-
// 'C',
174-
// 'N',
175-
// this->n_band, //m
176-
// this->n_band, //n
177-
// this->n_dim, //k
178-
// this->one,
179-
// grad_out.data<T>(),
180-
// this->n_basis, //lda
181-
// psi_in.data<T>(),
182-
// this->n_basis, //ldb
183-
// this->zero,
184-
// hsub_in.data<T>(),
185-
// this->n_band); //ldc
171+
// gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band)
172+
gemm_op()(this->ctx,
173+
'C',
174+
'N',
175+
this->n_band, //m
176+
this->n_band, //n
177+
this->n_dim, //k
178+
this->one,
179+
psi_in.data<T>(),
180+
this->n_basis, //lda
181+
grad_out.data<T>(),
182+
this->n_basis, //ldb
183+
this->zero,
184+
hsub_in.data<T>(),
185+
this->n_band); //ldc
186186

187187
// set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
188188
option = ct::EinsumOption(
189189
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
190-
grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
190+
// grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
191191

192192
// grad_out(n_basis x n_band) = psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193-
// gemm_op()(this->ctx,
194-
// 'N',
195-
// 'N',
196-
// this->n_basis, //m
197-
// this->n_band, //n
198-
// this->n_band, //k
199-
// this->one,
200-
// psi_in.data<T>(),
201-
// this->n_basis, //lda
202-
// hsub_in.data<T>(),
203-
// this->n_band, //ldb
204-
// this->zero,
205-
// grad_out.data<T>(),
206-
// this->n_basis); //ldc
193+
gemm_op()(this->ctx,
194+
'N',
195+
'N',
196+
this->n_dim, //m
197+
this->n_band, //n
198+
this->n_band, //k
199+
this->neg_one,
200+
psi_in.data<T>(),
201+
this->n_basis, //lda
202+
hsub_in.data<T>(),
203+
this->n_band, //ldb
204+
this->one,
205+
grad_out.data<T>(),
206+
this->n_basis); //ldc
207207

208208
return;
209209
}

source/module_hsolver/test/diago_bpcg_test.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,15 @@ TEST(DiagoBPCGTest, readH)
249249
// read Hamilt matrix from file data-H
250250
std::vector<std::complex<double>> hm;
251251
std::ifstream ifs;
252-
ifs.open("H-KPoints-Si64.dat");
252+
std::string filename = "H-KPoints-Si64.dat"; //"H-small-6x6.dat";
253+
std::cout << "Reading file " << filename << std::endl;
254+
ifs.open(filename);
255+
// open file and check status
256+
if (!ifs.is_open())
257+
{
258+
std::cout << "Error opening file " << filename << std::endl;
259+
exit(1);
260+
}
253261
DIAGOTEST::readh(ifs, hm);
254262
ifs.close();
255263
int dim = DIAGOTEST::npw;

0 commit comments

Comments
 (0)