Skip to content

Commit f9b1a7e

Browse files
author
dyzheng
committed
Fix: init_chg hr support nspin=2
1 parent f4bc269 commit f9b1a7e

File tree

6 files changed

+40
-10
lines changed

6 files changed

+40
-10
lines changed

docs/advanced/scf/initialization.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ In LCAO basis, wavefunction can be read to calculate initial charge density. The
1111
- `file` : initial charge density from files produced by previous calculations with [`out_chg 1`](../elec_properties/charge.md).
1212
- `auto`: Abacus first attempts to read the density from a file; if not found, it defaults to using atomic density.
1313
- `dm` (LCAO only): initial charge density from density matrix files in CSR format. For `nspin=1`, reads `dmrs1_nao.csr`. For `nspin=2` (spin-polarized), reads both `dmrs1_nao.csr` (spin-up) and `dmrs2_nao.csr` (spin-down). These files are generated by previous calculations with [`out_dmr 1`](../elec_properties/density_matrix.md). This method is particularly useful for restarting spin-polarized calculations.
14+
- `hr` (LCAO only): initial charge density from Hamiltonian matrix files in CSR format. The Hamiltonian is read from file, then diagonalized to obtain wavefunctions and charge density. For `nspin=1`, reads `hrs1_nao.csr`. For `nspin=2` (spin-polarized), reads both `hrs1_nao.csr` (spin-up) and `hrs2_nao.csr` (spin-down). These files are generated by previous calculations with [`out_mat_hs2 1`](../input_files/input-main.md).
1415

1516
## Wave function
1617
`init_wfc` is used for choosing the method of wavefunction coefficient initialization.

source/source_esolver/esolver_ks_lcao.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,26 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
182182
if(PARAM.inp.init_chg == "hr")
183183
{
184184
//! 13.1.2) init HR from file
185-
std::string hrfile = PARAM.globalv.global_readin_dir + "hrs1_nao.csr";
186-
LCAO_domain::init_hr_from_file<TR>(
187-
hrfile,
188-
dynamic_cast<hamilt::HamiltLCAO<TK, TR>*>(this->p_hamilt)->getHR(),
189-
ucell,
190-
&(this->pv)
191-
);
185+
if (PARAM.inp.nspin == 2)
186+
{
187+
// nspin=2: load spin-up into first half of hRS2, spin-down into second half
188+
const std::string hrfile_up = PARAM.globalv.global_readin_dir + "/hrs1_nao.csr";
189+
LCAO_domain::init_hr_from_file<TR>(hrfile_up, hamilt_lcao->getHR(), ucell, &(this->pv));
190+
191+
// switch hR data pointer to spin-down half, then read hrs2
192+
auto& hRS2 = hamilt_lcao->getHRS2();
193+
hamilt_lcao->getHR()->allocate(hRS2.data() + hRS2.size() / 2, 0);
194+
const std::string hrfile_down = PARAM.globalv.global_readin_dir + "/hrs2_nao.csr";
195+
LCAO_domain::init_hr_from_file<TR>(hrfile_down, hamilt_lcao->getHR(), ucell, &(this->pv));
196+
197+
// restore hR to spin-up half (refresh(false) will also do this, but be explicit)
198+
hamilt_lcao->getHR()->allocate(hRS2.data(), 0);
199+
}
200+
else
201+
{
202+
const std::string hrfile = PARAM.globalv.global_readin_dir + "/hrs1_nao.csr";
203+
LCAO_domain::init_hr_from_file<TR>(hrfile, hamilt_lcao->getHR(), ucell, &(this->pv));
204+
}
192205
this->p_hamilt->refresh(false);
193206
hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver);
194207
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, *this->dmat.dm,

source/source_lcao/hamilt_lcao.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,10 @@ void HamiltLCAO<TK, TR>::refresh(bool yes)
537537
this->refresh_times = 0;
538538
if (PARAM.inp.nspin == 2)
539539
{
540-
ModuleBase::WARNING_QUIT("HamiltLCAO::refresh",
541-
"When turning off the refresh flag, the nspin==2 case is not supported yet.");
540+
// HR has been loaded from file into both halves of hRS2.
541+
// Reset to spin-up; updateHk will switch pointers as needed.
542+
this->current_spin = 0;
543+
this->hR->allocate(this->hRS2.data(), 0);
542544
}
543545
}
544546
}

source/source_lcao/hamilt_lcao.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ class HamiltLCAO : public Hamilt<TK>
123123
}
124124
#endif
125125

126+
/// get hRS2 buffer for NSPIN=2 case (spin-up in first half, spin-down in second half)
127+
std::vector<TR>& getHRS2() { return this->hRS2; }
128+
126129
/// refresh the status of HR
127130
void refresh(bool yes) override;
128131

source/source_lcao/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ remove_definitions(-D__ROCM)
55
if(ENABLE_LCAO)
66
AddTest(
77
TARGET init_dm_from_file_test
8-
LIBS parameter ${math_libs} base device cell_info
8+
LIBS parameter ${math_libs} base device
99
SOURCES test_init_dm_from_file.cpp tmp_mocks.cpp
1010
${ABACUS_SOURCE_DIR}/source_estate/module_dm/density_matrix.cpp
1111
${ABACUS_SOURCE_DIR}/source_estate/module_dm/density_matrix_io.cpp
@@ -15,6 +15,7 @@ AddTest(
1515
${ABACUS_SOURCE_DIR}/source_lcao/module_hcontainer/read_hcontainer.cpp
1616
${ABACUS_SOURCE_DIR}/source_lcao/module_hcontainer/func_transfer.cpp
1717
${ABACUS_SOURCE_DIR}/source_lcao/module_hcontainer/func_folding.cpp
18+
${ABACUS_SOURCE_DIR}/source_lcao/module_hcontainer/transfer.cpp
1819
${ABACUS_SOURCE_DIR}/source_basis/module_ao/parallel_orbitals.cpp
1920
${ABACUS_SOURCE_DIR}/source_io/module_output/sparse_matrix.cpp
2021
${ABACUS_SOURCE_DIR}/source_io/module_output/csr_reader.cpp

source/source_lcao/test/tmp_mocks.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
#include "source_cell/unitcell.h"
3+
#include "source_cell/module_neighbor/sltk_grid_driver.h"
34

45
// constructor of Atom
56
Atom::Atom() {}
@@ -46,3 +47,12 @@ void UnitCell::set_iat2iwt(const int& npol_in)
4647
}
4748
return;
4849
}
50+
51+
// stub for Grid_Driver::Find_atom (used by density_matrix_io.cpp but not exercised in test)
52+
void Grid_Driver::Find_atom(const UnitCell& ucell,
53+
const ModuleBase::Vector3<double>& tau,
54+
const int& T,
55+
const int& I,
56+
AdjacentAtomInfo* adjs) const
57+
{
58+
}

0 commit comments

Comments
 (0)