Skip to content

Commit 288d78a

Browse files
committed
Add one and zero object for gemm
1 parent 6848680 commit 288d78a

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -94,24 +94,24 @@ void DiagoBPCG<T, Device>::orth_cholesky(
9494
ct::Tensor& hsub_out)
9595
{
9696
// hsub_out = psi_out * transc(psi_out)
97-
// ct::EinsumOption option(
98-
// /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
99-
// hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
97+
ct::EinsumOption option(
98+
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
99+
hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
100100
// using gemm instead einsum for different leading dimension and nbasis
101-
gemm_op<T, Device>()(this->ctx,
102-
'N',
103-
'C',
104-
this->n_band,
105-
this->n_band,
106-
this->n_dim,
107-
this->one,
108-
psi_out.data<T>(),
109-
this->n_basis,
110-
psi_out.data<T>(),
111-
this->n_basis,
112-
this->zero,
113-
hsub_out.data<T>(),
114-
this->n_band);
101+
// gemm_op()(this->ctx,
102+
// 'N',
103+
// 'C',
104+
// this->n_band,
105+
// this->n_band,
106+
// this->n_dim,
107+
// this->one,
108+
// psi_out.data<T>(),
109+
// this->n_basis,
110+
// psi_out.data<T>(),
111+
// this->n_basis,
112+
// this->zero,
113+
// hsub_out.data<T>(),
114+
// this->n_band);
115115

116116
// set hsub matrix to lower format;
117117
ct::kernels::set_matrix<T, ct_Device>()(

source/module_hsolver/diago_bpcg.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class DiagoBPCG
8080
/// cg convergence thr
8181
Real all_band_cg_thr = 1E-5;
8282

83+
// Pointer to objects of 1 and 0 for gemm
84+
const T *one = nullptr, *zero = nullptr, *neg_one = nullptr;
85+
const T one_ = static_cast<T>(1.0), zero_ = static_cast<T>(0.0), neg_one_ = static_cast<T>(-1.0);
86+
8387
ct::DataType r_type = ct::DataType::DT_INVALID;
8488
ct::DataType t_type = ct::DataType::DT_INVALID;
8589
ct::DeviceType device_type = ct::DeviceType::UnKnown;
@@ -334,6 +338,7 @@ class DiagoBPCG
334338

335339
using calc_grad_with_block_op = hsolver::calc_grad_with_block_op<T, Device>;
336340
using line_minimize_with_block_op = hsolver::line_minimize_with_block_op<T, Device>;
341+
using gemm_op = hsolver::gemm_op<T, Device>;
337342

338343
};
339344

0 commit comments

Comments
 (0)