Skip to content

Commit 8b85311

Browse files
committed
fix compile
1 parent 767e4d4 commit 8b85311

File tree

2 files changed

+23
-22
lines changed

2 files changed

+23
-22
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1-
#include <limits>
2-
31
#include "module_hsolver/diago_bpcg.h"
42

5-
#include <ATen/kernels/blas.h>
6-
#include <ATen/kernels/lapack.h>
7-
8-
#include <ATen/ops/einsum_op.h>
9-
103
#include "diago_iter_assist.h"
114
#include "module_base/blas_connector.h"
125
#include "module_base/global_function.h"
136
#include "module_base/kernels/math_kernel_op.h"
7+
#include "module_hsolver/kernels/bpcg_kernel_op.h"
148
#include "para_linear_transform.h"
159

10+
#include <ATen/kernels/blas.h>
11+
#include <ATen/kernels/lapack.h>
12+
#include <ATen/ops/einsum_op.h>
13+
#include <limits>
14+
1615
namespace hsolver {
1716

1817
template<typename T, typename Device>
@@ -100,7 +99,13 @@ void DiagoBPCG<T, Device>::line_minimize(
10099
ct::Tensor& psi_out,
101100
ct::Tensor& hpsi_out)
102101
{
103-
line_minimize_with_block_op()(grad_in.data<T>(), hgrad_in.data<T>(), psi_out.data<T>(), hpsi_out.data<T>(), this->n_dim, this->n_basis, this->n_band_l);
102+
line_minimize_with_block_op<T, Device>()(grad_in.data<T>(),
103+
hgrad_in.data<T>(),
104+
psi_out.data<T>(),
105+
hpsi_out.data<T>(),
106+
this->n_dim,
107+
this->n_basis,
108+
this->n_band_l);
104109
}
105110

106111

@@ -138,17 +143,16 @@ void DiagoBPCG<T, Device>::calc_grad_with_block(
138143
ct::Tensor& grad_out,
139144
ct::Tensor& grad_old_out)
140145
{
141-
calc_grad_with_block_op()(
142-
prec_in.data<Real>(),
143-
err_out.data<Real>(),
144-
beta_out.data<Real>(),
145-
psi_in.data<T>(),
146-
hpsi_in.data<T>(),
147-
grad_out.data<T>(),
148-
grad_old_out.data<T>(),
149-
this->n_dim,
150-
this->n_basis,
151-
this->n_band_l);
146+
calc_grad_with_block_op<T, Device>()(prec_in.data<Real>(),
147+
err_out.data<Real>(),
148+
beta_out.data<Real>(),
149+
psi_in.data<T>(),
150+
hpsi_in.data<T>(),
151+
grad_out.data<T>(),
152+
grad_old_out.data<T>(),
153+
this->n_dim,
154+
this->n_basis,
155+
this->n_band_l);
152156
}
153157

154158
template<typename T, typename Device>

source/module_hsolver/diago_bpcg.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "module_base/para_gemm.h"
88
#include "module_hamilt_general/hamilt.h"
99
#include "module_hamilt_pw/hamilt_pwdft/structure_factor.h"
10-
#include "module_hsolver/kernels/bpcg_kernel_op.h"
1110
#include "module_hsolver/kernels/dngvd_op.h"
1211
#include "module_hsolver/para_linear_transform.h"
1312

@@ -350,8 +349,6 @@ class DiagoBPCG
350349
// note: these operators use template parameter base_device::Device_*
351350
// defined in module_base/module_device/types.h
352351
// different from ct_Device!
353-
using calc_grad_with_block_op = calc_grad_with_block_op<T, Device>;
354-
using line_minimize_with_block_op = line_minimize_with_block_op<T, Device>;
355352
using gemm_op = ModuleBase::gemm_op<T, Device>;
356353

357354
};

0 commit comments

Comments
 (0)