Skip to content

Commit d7c66b1

Browse files
committed
Add utils for hsovler gemm_op
1 parent 16a26a3 commit d7c66b1

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ DiagoBPCG<T, Device>::DiagoBPCG(const Real* precondition_in)
2222
this->device_type = ct::DeviceTypeToEnum<Device>::value;
2323

2424
this->h_prec = std::move(ct::TensorMap((void *) precondition_in, r_type, device_type, {this->n_basis}));
25+
26+
this->one = &one_;
27+
this->zero = &zero_;
28+
this->neg_one = &neg_one_;
2529
}
2630

2731
template<typename T, typename Device>

source/module_hsolver/diago_bpcg.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,13 @@ class DiagoBPCG
110110
/// work for some calculations within this class, including rotate_wf call
111111
ct::Tensor work = {};
112112

113+
// These are for hsolver gemm_op use
114+
/// ctx is nothing but the devices used in gemm_op (Device * ctx = nullptr;),
115+
Device * ctx = {};
116+
// Pointer to objects of 1 and 0 for gemm
117+
const T *one = nullptr, *zero = nullptr, *neg_one = nullptr;
118+
const T one_ = static_cast<T>(1.0), zero_ = static_cast<T>(0.0), neg_one_ = static_cast<T>(-1.0);
119+
113120
/**
114121
* @brief Update the precondition array.
115122
*

0 commit comments

Comments
 (0)