Skip to content

Commit ea6dbc7

Browse files
authored
Refactor: change the readin directory of deepks_projdm.dat (#6673)
* Fix a bug. * Move deepks_projdm.dat out from OUT.autotest/. * Add a input shape check for DeePKS model.
1 parent 9ce0542 commit ea6dbc7

File tree

6 files changed

+14
-6
lines changed

6 files changed

+14
-6
lines changed

source/source_lcao/module_deepks/deepks_basic.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,15 @@ void DeePKS_domain::cal_edelta_gedm(const int nat,
197197
// input_dim:(natom, des_per_atom)
198198
inputs.push_back(torch::cat(descriptor, 0).reshape({1, nat, des_per_atom}));
199199
std::vector<torch::Tensor> ec;
200-
ec.push_back(model_deepks.forward(inputs).toTensor()); // Hartree
200+
try
201+
{
202+
ec.push_back(model_deepks.forward(inputs).toTensor()); // Hartree
203+
}
204+
catch (const c10::Error& e)
205+
{
206+
ModuleBase::WARNING_QUIT("DeePKS_domain::cal_edelta_gedm", "Please check whether the input shape required by model file matches the descriptor!");
207+
throw;
208+
}
201209
E_delta = ec[0].item<double>() * 2; // Ry; *2 is for Hartree to Ry
202210

203211
// cal gedm

source/source_lcao/module_deepks/deepks_pdm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ void DeePKS_domain::read_pdm(bool read_pdm_file,
3737
{
3838
if (read_pdm_file && !init_pdm) // for DeePKS NSCF calculation
3939
{
40-
const std::string file_projdm = PARAM.globalv.global_out_dir + "deepks_projdm.dat";
40+
const std::string file_projdm = PARAM.globalv.global_readin_dir + "deepks_projdm.dat";
4141
std::ifstream ifs(file_projdm.c_str());
4242

4343
if (!ifs)

source/source_lcao/spar_hsr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ void sparse_format::cal_HSR(const UnitCell& ucell,
9292
// cal_STN_R_sparse(current_spin, sparse_thr);
9393
if (nspin == 1 || nspin == 2)
9494
{
95-
hamilt::HamiltLCAO<std::complex<double>, double>* p_ham_lcao
96-
= dynamic_cast<hamilt::HamiltLCAO<std::complex<double>, double>*>(p_ham);
95+
hamilt::HamiltLCAO<TK, double>* p_ham_lcao
96+
= dynamic_cast<hamilt::HamiltLCAO<TK, double>*>(p_ham);
9797

9898
HS_Arrays.all_R_coor = get_R_range(*(p_ham_lcao->getHR()));
9999

source/source_lcao/spar_st.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ void sparse_format::cal_SR(
2828
// cal_STN_R_sparse(current_spin, sparse_thr);
2929
if (nspin == 1 || nspin == 2)
3030
{
31-
hamilt::HamiltLCAO<std::complex<double>, double>* p_ham_lcao
32-
= dynamic_cast<hamilt::HamiltLCAO<std::complex<double>, double>*>(p_ham);
31+
hamilt::HamiltLCAO<TK, double>* p_ham_lcao
32+
= dynamic_cast<hamilt::HamiltLCAO<TK, double>*>(p_ham);
3333
const int cspin = 0;
3434
sparse_format::cal_HContainer<double>(pv, sparse_thr, *(p_ham_lcao->getSR()), SR_sparse);
3535
}

0 commit comments

Comments
 (0)