Skip to content

Commit 42492ee

Browse files
authored
Add function to support reading cereal binary format HexxR files. (#5453)
* update the default setting for symmetry * modify the unit test for read_input_item_system.cpp * Add support for reading old version HexxR* files. * Code standardization for #include.
1 parent 67e6cbb commit 42492ee

File tree

3 files changed

+74
-9
lines changed

3 files changed

+74
-9
lines changed

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.hpp

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,40 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
104104
if (PARAM.inp.calculation == "nscf" && GlobalC::exx_info.info_global.cal_exx)
105105
{ // if nscf, read HexxR first and reallocate hR according to the read-in HexxR
106106
const std::string file_name_exx = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
107-
if (GlobalC::exx_info.info_ri.real_number)
107+
bool all_exist = true;
108+
for (int is=0;is<PARAM.inp.nspin;++is)
109+
{
110+
std::ifstream ifs(file_name_exx + "_" + std::to_string(is) + ".csr");
111+
if (!ifs) { all_exist = false; break; }
112+
}
113+
if (all_exist)
108114
{
109-
ModuleIO::read_Hexxs_csr(file_name_exx, GlobalC::ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
110-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
115+
// Read HexxR in CSR format
116+
if (GlobalC::exx_info.info_ri.real_number)
117+
{
118+
ModuleIO::read_Hexxs_csr(file_name_exx, GlobalC::ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
119+
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
120+
}
121+
else
122+
{
123+
ModuleIO::read_Hexxs_csr(file_name_exx, GlobalC::ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
124+
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
125+
}
111126
}
112127
else
113128
{
114-
ModuleIO::read_Hexxs_csr(file_name_exx, GlobalC::ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
115-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
129+
// Read HexxR in binary format (old version)
130+
const std::string file_name_exx_cereal = PARAM.globalv.global_readin_dir + "HexxR_" + std::to_string(GlobalV::MY_RANK);
131+
if (GlobalC::exx_info.info_ri.real_number)
132+
{
133+
ModuleIO::read_Hexxs_cereal(file_name_exx_cereal, *Hexxd);
134+
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
135+
}
136+
else
137+
{
138+
ModuleIO::read_Hexxs_cereal(file_name_exx_cereal, *Hexxc);
139+
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
140+
}
116141
}
117142
this->use_cell_nearest = false;
118143
}
@@ -181,11 +206,32 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
181206
{
182207
// read in Hexx(R)
183208
const std::string restart_HR_path = GlobalC::restart.folder + "HexxR" + std::to_string(GlobalV::MY_RANK);
184-
if (GlobalC::exx_info.info_ri.real_number) {
185-
ModuleIO::read_Hexxs_csr(restart_HR_path, GlobalC::ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
209+
bool all_exist = true;
210+
for (int is = 0; is < PARAM.inp.nspin; ++is)
211+
{
212+
std::ifstream ifs(restart_HR_path + "_" + std::to_string(is) + ".csr");
213+
if (!ifs) { all_exist = false; break; }
214+
}
215+
if (all_exist)
216+
{
217+
// Read HexxR in CSR format
218+
if (GlobalC::exx_info.info_ri.real_number) {
219+
ModuleIO::read_Hexxs_csr(restart_HR_path, GlobalC::ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
220+
}
221+
else {
222+
ModuleIO::read_Hexxs_csr(restart_HR_path, GlobalC::ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
223+
}
186224
}
187-
else {
188-
ModuleIO::read_Hexxs_csr(restart_HR_path, GlobalC::ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
225+
else
226+
{
227+
// Read HexxR in binary format (old version)
228+
const std::string restart_HR_path_cereal = GlobalC::restart.folder + "HexxR_" + std::to_string(GlobalV::MY_RANK);
229+
if (GlobalC::exx_info.info_ri.real_number) {
230+
ModuleIO::read_Hexxs_cereal(restart_HR_path_cereal, *Hexxd);
231+
}
232+
else {
233+
ModuleIO::read_Hexxs_cereal(restart_HR_path_cereal, *Hexxc);
234+
}
189235
}
190236
}
191237

source/module_io/restart_exx_csr.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22
#include "module_base/abfs-vector3_order.h"
33
#include "module_cell/unitcell.h"
4+
#include "module_ri/serialization_cereal.h"
45
#include <RI/global/Tensor.h>
56
#include <map>
67

@@ -15,6 +16,11 @@ namespace ModuleIO
1516
const int nspin, const int nbasis,
1617
std::vector<std::map<int, std::map<TAC, RI::Tensor<Tdata>>>>& Hexxs);
1718

19+
/// read Hexxs in cereal format
20+
template<typename Tdata>
21+
void read_Hexxs_cereal(const std::string& file_name,
22+
std::vector<std::map<int, std::map<TAC, RI::Tensor<Tdata>>>>& Hexxs);
23+
1824
/// write Hexxs in CSR format
1925
template<typename Tdata>
2026
void write_Hexxs_csr(const std::string& file_name, const UnitCell& ucell,

source/module_io/restart_exx_csr.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "module_cell/unitcell.h"
44
#include "module_io/csr_reader.h"
55
#include "module_io/write_HS_sparse.h"
6+
#include "module_ri/serialization_cereal.h"
67
#include <RI/global/Tensor.h>
78
#include <map>
89

@@ -49,6 +50,18 @@ namespace ModuleIO
4950
}
5051
}
5152

53+
template<typename Tdata>
54+
void read_Hexxs_cereal(const std::string& file_name,
55+
std::vector<std::map<int, std::map<TAC, RI::Tensor<Tdata>>>>& Hexxs)
56+
{
57+
ModuleBase::TITLE("Exx_LRI", "read_Hexxs_cereal");
58+
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
59+
std::ifstream ifs(file_name, std::ios::binary);
60+
cereal::BinaryInputArchive iar(ifs);
61+
iar(Hexxs);
62+
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
63+
}
64+
5265
template<typename Tdata>
5366
std::map<Abfs::Vector3_Order<int>, std::map<size_t, std::map<size_t, Tdata>>>
5467
calculate_RI_Tensor_sparse(const double& sparse_threshold,

0 commit comments

Comments
 (0)