Skip to content

Commit c60e93c

Browse files
committed
fix compile
1 parent 8725264 commit c60e93c

File tree

5 files changed

+22
-13
lines changed

5 files changed

+22
-13
lines changed

source/module_base/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ add_library(
6161
${LIBM_SRC}
6262
)
6363

64+
target_link_libraries(base PUBLIC container)
65+
6466
add_subdirectory(module_container)
6567

6668
if(ENABLE_COVERAGE)

source/module_base/module_container/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ if(USE_ROCM)
1919
set(ATen_ROCM_DEPENDENCY_LIBS container_rocm)
2020
endif()
2121

22-
add_library(container OBJECT ${ATen_CPU_SRCS} ${ATen_CUDA_SRCS})
22+
add_library(container STATIC ${ATen_CPU_SRCS} ${ATen_CUDA_SRCS})
2323

2424
target_link_libraries(container PUBLIC
2525
${ATen_CPU_DEPENDENCY_LIBS} ${ATen_CUDA_DEPENDENCY_LIBS} ${ATen_ROCM_DEPENDENCY_LIBS})

source/module_base/test/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ AddTest(
134134

135135
AddTest(
136136
TARGET base_math_chebyshev
137-
LIBS parameter ${math_libs}
138-
SOURCES math_chebyshev_test.cpp ../blas_connector.cpp ../math_chebyshev.cpp ../tool_quit.cpp ../global_variable.cpp ../timer.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp
137+
LIBS parameter ${math_libs} base device
138+
SOURCES math_chebyshev_test.cpp
139139
)
140140

141141
AddTest(

source/module_elecstate/elecstate_pw_sdft.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,7 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
4242

4343
// template class ElecStatePW_SDFT<std::complex<float>, base_device::DEVICE_CPU>;
4444
template class ElecStatePW_SDFT<std::complex<double>, base_device::DEVICE_CPU>;
45+
#if ((defined __CUDA) || (defined __ROCM))
4546
template class ElecStatePW_SDFT<std::complex<double>, base_device::DEVICE_GPU>;
47+
#endif
4648
} // namespace elecstate

source/module_hsolver/test/test_hsolver_sdft.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "module_base/global_variable.h"
1313
#include "module_hsolver/hsolver_pw.h"
1414
#include "module_hsolver/hsolver_pw_sdft.h"
15+
#include "module_elecstate/elecstate_pw.h"
1516
#undef private
1617
#undef protected
1718

@@ -20,16 +21,20 @@ template <typename REAL>
2021
Sto_Func<REAL>::Sto_Func()
2122
{
2223
}
23-
2424
template class Sto_Func<double>;
2525

26-
template <typename REAL>
27-
StoChe<REAL>::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto)
26+
template<>
27+
void elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::init_rho_data()
28+
{
29+
}
30+
31+
template <typename REAL, typename Device>
32+
StoChe<REAL, Device>::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto)
2833
{
2934
this->nche = nche;
3035
}
31-
template <typename REAL>
32-
StoChe<REAL>::~StoChe()
36+
template <typename REAL, typename Device>
37+
StoChe<REAL, Device>::~StoChe()
3338
{
3439
}
3540

@@ -51,7 +56,7 @@ template <typename T, typename Device>
5156
void Stochastic_Iter<T, Device>::init(K_Vectors* pkv_in,
5257
ModulePW::PW_Basis_K* wfc_basis,
5358
Stochastic_WF<T, Device>& stowf,
54-
StoChe<double>& stoche,
59+
StoChe<Real, Device>& stoche,
5560
hamilt::HamiltSdftPW<T, Device>* p_hamilt_sto)
5661
{
5762
this->nchip = stowf.nchip;
@@ -108,7 +113,7 @@ void Stochastic_Iter<T, Device>::calHsqrtchi(Stochastic_WF<T, Device>& stowf)
108113

109114
template <typename T, typename Device>
110115
void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
111-
elecstate::ElecStatePW* pes,
116+
elecstate::ElecStatePW<T, Device>* pes,
112117
hamilt::Hamilt<T, Device>* pHamilt,
113118
ModulePW::PW_Basis_K* wfc_basis)
114119
{
@@ -193,7 +198,7 @@ TEST_F(TestHSolverPW_SDFT, solve)
193198
int istep = 0;
194199
int iter = 0;
195200

196-
this->hs_d.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, &pwbk, stowf, istep, iter, false);
201+
this->hs_d.solve(&hamilt_test_d, psi_test_cd, psi_test_cd, &elecstate_test, &pwbk, stowf, istep, iter, false);
197202
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<double>>::avg_iter, 0.0);
198203
EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[0], 4.0);
199204
EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[1], 7.0);
@@ -237,7 +242,7 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge)
237242
int istep = 0;
238243
int iter = 0;
239244

240-
this->hs_d.solve(&hamilt_test_d, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, false);
245+
this->hs_d.solve(&hamilt_test_d, psi_test_no, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, false);
241246
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<double>>::avg_iter, 0.0);
242247
EXPECT_EQ(stowf.nbands_diag, 2);
243248
EXPECT_EQ(stowf.nbands_total, 1);
@@ -251,7 +256,7 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge)
251256
std::cout<<__FILE__<<__LINE__<<" "<<elecstate_test.f_en.eband<<std::endl;*/
252257

253258
// test for skip charge
254-
this->hs_d.solve(&hamilt_test_d, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, true);
259+
this->hs_d.solve(&hamilt_test_d, psi_test_no, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, true);
255260
EXPECT_EQ(stowf.nbands_diag, 4);
256261
EXPECT_EQ(stowf.nbands_total, 1);
257262
EXPECT_EQ(stowf.nchi, 4);

0 commit comments

Comments
 (0)