Skip to content

Commit ab08f46

Browse files
committed
Revert last commit, substitute gemm for einsum
This reverts commit a3350b7.
1 parent a3350b7 commit ab08f46

File tree

1 file changed

+60
-60
lines changed

1 file changed

+60
-60
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 60 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -97,23 +97,23 @@ void DiagoBPCG<T, Device>::orth_cholesky(
9797
// hsub_out = psi_out * transc(psi_out)
9898
ct::EinsumOption option(
9999
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
100-
hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
100+
// hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
101101

102102
// gemm: hsub_out(n_band x n_band) = psi_out^H(n_band x n_basis) * psi_out(n_basis x n_band), ^H = conjugate transpose ('C')
103-
// gemm_op()(this->ctx,
104-
// 'C',
105-
// 'N',
106-
// this->n_band, //m
107-
// this->n_band, //n
108-
// this->n_dim, //k
109-
// this->one, //1.0
110-
// psi_out.data<T>(),
111-
// this->n_basis, //lda
112-
// psi_out.data<T>(),
113-
// this->n_basis, //ldb
114-
// this->zero, //0.0
115-
// hsub_out.data<T>(),
116-
// this->n_band); //ldc
103+
gemm_op()(this->ctx,
104+
'C',
105+
'N',
106+
this->n_band, //m
107+
this->n_band, //n
108+
this->n_dim, //k
109+
this->one, //1.0
110+
psi_out.data<T>(),
111+
this->n_basis, //lda
112+
psi_out.data<T>(),
113+
this->n_basis, //ldb
114+
this->zero, //0.0
115+
hsub_out.data<T>(),
116+
this->n_band); //ldc
117117

118118
// set hsub matrix to lower format;
119119
ct::kernels::set_matrix<T, ct_Device>()(
@@ -165,45 +165,45 @@ void DiagoBPCG<T, Device>::orth_projection(
165165
{
166166
ct::EinsumOption option(
167167
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in);
168-
hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
168+
// hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
169169

170170
// this->orth_projection(this->psi, this->hsub, this->grad);
171171
// gemm: hsub_in(n_band x n_band) = psi_in^H(n_band x n_basis) * grad_out(n_basis x n_band), ^H = conjugate transpose ('C')
172-
// gemm_op()(this->ctx,
173-
// 'C',
174-
// 'N',
175-
// this->n_band, //m
176-
// this->n_band, //n
177-
// this->n_dim, //k
178-
// this->one, //1.0
179-
// psi_in.data<T>(),
180-
// this->n_basis, //lda
181-
// grad_out.data<T>(),
182-
// this->n_basis, //ldb
183-
// this->zero, //0.0
184-
// hsub_in.data<T>(),
185-
// this->n_band); //ldc
172+
gemm_op()(this->ctx,
173+
'C',
174+
'N',
175+
this->n_band, //m
176+
this->n_band, //n
177+
this->n_dim, //k
178+
this->one, //1.0
179+
psi_in.data<T>(),
180+
this->n_basis, //lda
181+
grad_out.data<T>(),
182+
this->n_basis, //ldb
183+
this->zero, //0.0
184+
hsub_in.data<T>(),
185+
this->n_band); //ldc
186186

187187
// set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
188188
option = ct::EinsumOption(
189189
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
190-
grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
190+
// grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
191191

192192
// grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193-
// gemm_op()(this->ctx,
194-
// 'N',
195-
// 'N',
196-
// this->n_dim, //m
197-
// this->n_band, //n
198-
// this->n_band, //k
199-
// this->neg_one, //-1.0
200-
// psi_in.data<T>(),
201-
// this->n_basis, //lda
202-
// hsub_in.data<T>(),
203-
// this->n_band, //ldb
204-
// this->one, //1.0
205-
// grad_out.data<T>(),
206-
// this->n_basis); //ldc
193+
gemm_op()(this->ctx,
194+
'N',
195+
'N',
196+
this->n_dim, //m
197+
this->n_band, //n
198+
this->n_band, //k
199+
this->neg_one, //-1.0
200+
psi_in.data<T>(),
201+
this->n_basis, //lda
202+
hsub_in.data<T>(),
203+
this->n_band, //ldb
204+
this->one, //1.0
205+
grad_out.data<T>(),
206+
this->n_basis); //ldc
207207

208208
return;
209209
}
@@ -263,23 +263,23 @@ void DiagoBPCG<T, Device>::diag_hsub(
263263
// it controls the ops to use the corresponding device to calculate results
264264
ct::EinsumOption option(
265265
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
266-
hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
266+
// hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
267267

268268
// gemm: hsub_out(n_band x n_band) = hpsi_in^H(n_band x n_basis) * psi_in(n_basis x n_band), ^H = conjugate transpose ('C')
269-
// gemm_op()(this->ctx,
270-
// 'C',
271-
// 'N',
272-
// this->n_band, //m
273-
// this->n_band, //n
274-
// this->n_dim, //k
275-
// this->one, //1.0
276-
// hpsi_in.data<T>(),
277-
// this->n_basis, //lda
278-
// psi_in.data<T>(),
279-
// this->n_basis, //ldb
280-
// this->zero, //0.0
281-
// hsub_out.data<T>(),
282-
// this->n_band); //ldc
269+
gemm_op()(this->ctx,
270+
'C',
271+
'N',
272+
this->n_band, //m
273+
this->n_band, //n
274+
this->n_dim, //k
275+
this->one, //1.0
276+
hpsi_in.data<T>(),
277+
this->n_basis, //lda
278+
psi_in.data<T>(),
279+
this->n_basis, //ldb
280+
this->zero, //0.0
281+
hsub_out.data<T>(),
282+
this->n_band); //ldc
283283

284284
ct::kernels::lapack_dnevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
285285

0 commit comments

Comments
 (0)