|
1 | | -#include <limits> |
2 | | - |
3 | 1 | #include "module_hsolver/diago_bpcg.h" |
4 | 2 |
|
5 | | -#include <ATen/kernels/blas.h> |
6 | | -#include <ATen/kernels/lapack.h> |
7 | | - |
8 | | -#include <ATen/ops/einsum_op.h> |
9 | | - |
10 | 3 | #include "diago_iter_assist.h" |
11 | 4 | #include "module_base/blas_connector.h" |
12 | 5 | #include "module_base/global_function.h" |
13 | 6 | #include "module_base/kernels/math_kernel_op.h" |
| 7 | +#include "module_hsolver/kernels/bpcg_kernel_op.h" |
14 | 8 | #include "para_linear_transform.h" |
15 | 9 |
|
| 10 | +#include <ATen/kernels/blas.h> |
| 11 | +#include <ATen/kernels/lapack.h> |
| 12 | +#include <ATen/ops/einsum_op.h> |
| 13 | +#include <limits> |
| 14 | + |
16 | 15 | namespace hsolver { |
17 | 16 |
|
18 | 17 | template<typename T, typename Device> |
@@ -100,7 +99,13 @@ void DiagoBPCG<T, Device>::line_minimize( |
100 | 99 | ct::Tensor& psi_out, |
101 | 100 | ct::Tensor& hpsi_out) |
102 | 101 | { |
103 | | - line_minimize_with_block_op()(grad_in.data<T>(), hgrad_in.data<T>(), psi_out.data<T>(), hpsi_out.data<T>(), this->n_dim, this->n_basis, this->n_band_l); |
| 102 | + line_minimize_with_block_op<T, Device>()(grad_in.data<T>(), |
| 103 | + hgrad_in.data<T>(), |
| 104 | + psi_out.data<T>(), |
| 105 | + hpsi_out.data<T>(), |
| 106 | + this->n_dim, |
| 107 | + this->n_basis, |
| 108 | + this->n_band_l); |
104 | 109 | } |
105 | 110 |
|
106 | 111 |
|
@@ -138,17 +143,16 @@ void DiagoBPCG<T, Device>::calc_grad_with_block( |
138 | 143 | ct::Tensor& grad_out, |
139 | 144 | ct::Tensor& grad_old_out) |
140 | 145 | { |
141 | | - calc_grad_with_block_op()( |
142 | | - prec_in.data<Real>(), |
143 | | - err_out.data<Real>(), |
144 | | - beta_out.data<Real>(), |
145 | | - psi_in.data<T>(), |
146 | | - hpsi_in.data<T>(), |
147 | | - grad_out.data<T>(), |
148 | | - grad_old_out.data<T>(), |
149 | | - this->n_dim, |
150 | | - this->n_basis, |
151 | | - this->n_band_l); |
| 146 | + calc_grad_with_block_op<T, Device>()(prec_in.data<Real>(), |
| 147 | + err_out.data<Real>(), |
| 148 | + beta_out.data<Real>(), |
| 149 | + psi_in.data<T>(), |
| 150 | + hpsi_in.data<T>(), |
| 151 | + grad_out.data<T>(), |
| 152 | + grad_old_out.data<T>(), |
| 153 | + this->n_dim, |
| 154 | + this->n_basis, |
| 155 | + this->n_band_l); |
152 | 156 | } |
153 | 157 |
|
154 | 158 | template<typename T, typename Device> |
|
0 commit comments