Skip to content

Commit 0bac03f

Browse files
maki49mohanchen
andauthored
Feature: LR-TDDFT for open-shell systems (#5312)
* do not pass matrix in OperatorLRDiag * spin-up/down nocc, nvirt, npairs * spin up/down PotHxc * move the psi wrapper into LR_Util * refactor hsolver_lr * change X from Psi to pointer in dm_trans and AX * refactor HSolverLR and remove the inheritance of HamiltLR * read/write value tool funcs * store X with ct::Tensor instead of Psi * use const ref instead of pointer for Parallel_2D * key: ULR Hamilt & tear down HSolverLR * pass spin-type to PotHxc * rebase develop and update DM * traverse states outside of act() * fix a fatal bug of AX index * remove nspin_solve * openshell support for spectrum * add parameter lr_unrestricted * fix the parallel copy of eigenvectors * enable complex spin2 and fix a parameter bug * restrict DM size within orb_rcut * remove band-traverse and recover RI-benchmark * update LR cases: cover spin2 and dav_subspace * add a test case for open-shell solver * fix compile error * minor fixes --------- Co-authored-by: Mohan Chen <[email protected]>
1 parent 3203ec6 commit 0bac03f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+1465
-1254
lines changed

docs/advanced/input_files/input-main.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3957,6 +3957,13 @@ Currently supported: `RPA`, `LDA`, `PBE`, `HSE`, `HF`.
39573957
- **Description**: The number of 2-particle states to be solved
39583958
- **Default**: 0
39593959

3960+
### lr_unrestricted
3961+
- **Type**: Boolean
3962+
- **Description**: Whether to use unrestricted construction for LR-TDDFT (the matrix size will be doubled).
3963+
- True: Always use unrestricted LR-TDDFT.
3964+
- False: Use unrestricted LR-TDDFT only when the system is open-shell.
3965+
- **Default**: False
3966+
39603967
### abs_wavelen_range
39613968

39623969
- **Type**: Real Real

source/Makefile.Objects

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,6 @@ OBJS_TENSOR=tensor.o\
728728
operator_lr_exx.o\
729729
kernel_xc.o\
730730
pot_hxc_lrtd.o\
731-
hsolver_lrtd.o\
732731
lr_spectrum.o\
733732
hamilt_casida.o\
734733
esolver_lrtd_lcao.o\

source/module_basis/module_ao/parallel_2d.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Parallel_2D
1515
~Parallel_2D() = default;
1616

1717
Parallel_2D& operator=(Parallel_2D&& rhs) = default;
18+
Parallel_2D(Parallel_2D&& rhs) = default;
1819

1920
/// number of local rows
2021
int get_row_size() const

source/module_esolver/esolver.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,8 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
187187
else if (esolver_type == "lr_lcao")
188188
{
189189
// use constructor rather than Init function to initialize reference (instead of pointers) to ucell
190-
if (PARAM.globalv.gamma_only_local){
191-
return new LR::ESolver_LR<double, double>(inp, ucell);
192-
} else if (PARAM.inp.nspin < 2) {
193-
return new LR::ESolver_LR<std::complex<double>, double>(inp, ucell);
194-
} else {
195-
throw std::runtime_error("LR-TDDFT is not implemented for spin polarized case");
196-
}
190+
if (PARAM.globalv.gamma_only_local) { return new LR::ESolver_LR<double, double>(inp, ucell); }
191+
else { return new LR::ESolver_LR<std::complex<double>, double>(inp, ucell); }
197192
}
198193
else if (esolver_type == "ksdft_lr_lcao")
199194
{

source/module_io/read_input_item_tddft.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,12 @@ void ReadInput::item_lr_tddft()
327327
read_sync_bool(input.out_wfc_lr);
328328
this->add_item(item);
329329
}
330+
{
331+
Input_Item item("lr_unrestricted");
332+
item.annotation = "Whether to use unrestricted construction for LR-TDDFT";
333+
read_sync_bool(input.lr_unrestricted);
334+
this->add_item(item);
335+
}
330336
{
331337
Input_Item item("abs_wavelen_range");
332338
item.annotation = "the range of wavelength(nm) to output the absorption spectrum ";
@@ -337,10 +343,6 @@ void ReadInput::item_lr_tddft()
337343
para.input.abs_wavelen_range.push_back(std::stod(item.str_values[i]));
338344
}
339345
};
340-
item.check_value = [](const Input_Item& item, const Parameter& para) {
341-
auto& awr = para.input.abs_wavelen_range;
342-
if (awr.size() < 2) { ModuleBase::WARNING_QUIT("ReadInput", "abs_wavelen_range must have two values"); }
343-
};
344346
sync_doublevec(input.abs_wavelen_range, 2, 0.0);
345347
this->add_item(item);
346348
}

source/module_io/test/read_input_ptest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ TEST_F(InputParaTest, ParaRead)
419419
EXPECT_EQ(param.inp.xc_kernel, "LDA");
420420
EXPECT_EQ(param.inp.lr_solver, "dav");
421421
EXPECT_DOUBLE_EQ(param.inp.lr_thr, 1e-2);
422+
EXPECT_FALSE(param.inp.lr_unrestricted);
422423
EXPECT_FALSE(param.inp.out_wfc_lr);
423424
EXPECT_EQ(param.inp.abs_wavelen_range.size(), 2);
424425
EXPECT_DOUBLE_EQ(param.inp.abs_wavelen_range[0], 0.0);

source/module_lr/AX/AX.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,25 @@ namespace LR
1313
const psi::Psi<double>& c,
1414
const int& nocc,
1515
const int& nvirt,
16-
psi::Psi<double>& AX_istate);
16+
double* AX_istate);
1717
void cal_AX_blas(
1818
const std::vector<container::Tensor>& V_istate,
1919
const psi::Psi<double>& c,
2020
const int& nocc,
2121
const int& nvirt,
22-
psi::Psi<double>& AX_istate,
22+
double* AX_istate,
2323
const bool add_on = true);
2424
#ifdef __MPI
2525
void cal_AX_pblas(
2626
const std::vector<container::Tensor>& V_istate,
2727
const Parallel_2D& pmat,
2828
const psi::Psi<double>& c,
2929
const Parallel_2D& pc,
30-
int naos,
31-
int nocc,
32-
int nvirt,
33-
Parallel_2D& pX,
34-
psi::Psi<double>& AX_istate,
30+
const int& naos,
31+
const int& nocc,
32+
const int& nvirt,
33+
const Parallel_2D& pX,
34+
double* AX_istate,
3535
const bool add_on=true);
3636
#endif
3737
// complex
@@ -40,13 +40,13 @@ namespace LR
4040
const psi::Psi<std::complex<double>>& c,
4141
const int& nocc,
4242
const int& nvirt,
43-
psi::Psi<std::complex<double>>& AX_istate);
43+
std::complex<double>* AX_istate);
4444
void cal_AX_blas(
4545
const std::vector<container::Tensor>& V_istate,
4646
const psi::Psi<std::complex<double>>& c,
4747
const int& nocc,
4848
const int& nvirt,
49-
psi::Psi<std::complex<double>>& AX_istate,
49+
std::complex<double>* AX_istate,
5050
const bool add_on = true);
5151

5252
#ifdef __MPI
@@ -55,11 +55,11 @@ namespace LR
5555
const Parallel_2D& pmat,
5656
const psi::Psi<std::complex<double>>& c,
5757
const Parallel_2D& pc,
58-
int naos,
59-
int nocc,
60-
int nvirt,
61-
Parallel_2D& pX,
62-
psi::Psi<std::complex<double>>& AX_istate,
58+
const int& naos,
59+
const int& nocc,
60+
const int& nvirt,
61+
const Parallel_2D& pX,
62+
std::complex<double>* AX_istate,
6363
const bool add_on = true);
6464
#endif
6565
}

source/module_lr/AX/AX_parallel.cpp

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,25 @@ namespace LR
1414
const Parallel_2D& pmat,
1515
const psi::Psi<double>& c,
1616
const Parallel_2D& pc,
17-
int naos,
18-
int nocc,
19-
int nvirt,
20-
Parallel_2D& pX,
21-
psi::Psi<double>& AX_istate,
17+
const int& naos,
18+
const int& nocc,
19+
const int& nvirt,
20+
const Parallel_2D& pX,
21+
double* AX_istate,
2222
const bool add_on)
2323
{
2424
ModuleBase::TITLE("hamilt_lrtd", "cal_AX_pblas");
25-
assert(pmat.comm() == pc.comm());
26-
assert(pmat.blacs_ctxt == pc.blacs_ctxt);
27-
28-
if (pX.comm() != pmat.comm() || pX.blacs_ctxt != pmat.blacs_ctxt)
29-
LR_Util::setup_2d_division(pX, pmat.get_block_size(), nvirt, nocc, pmat.blacs_ctxt);
30-
else assert(pX.get_local_size() > 0 && AX_istate.get_nbasis() == pX.get_local_size());
25+
assert(pmat.comm() == pc.comm() && pmat.comm() == pX.comm());
26+
assert(pmat.blacs_ctxt == pc.blacs_ctxt && pmat.blacs_ctxt == pX.blacs_ctxt);
27+
assert(pX.get_local_size() > 0);
3128

3229
const int nks = V_istate.size();
3330

3431
Parallel_2D pVc; // for intermediate Vc
3532
LR_Util::setup_2d_division(pVc, pmat.get_block_size(), naos, nocc, pmat.blacs_ctxt);
3633
for (int isk = 0;isk < nks;++isk)
3734
{
38-
AX_istate.fix_k(isk);
35+
const int ax_start = isk * pX.get_local_size();
3936
c.fix_k(isk);
4037

4138
//Vc
@@ -60,7 +57,7 @@ namespace LR
6057
pdgemm_(&transa, &transb, &nvirt, &nocc, &naos,
6158
&alpha, c.get_pointer(), &i1, &ivirt, pc.desc,
6259
Vc.data<double>(), &i1, &i1, pVc.desc,
63-
&beta, AX_istate.get_pointer(), &i1, &i1, pX.desc);
60+
&beta, AX_istate + ax_start, &i1, &i1, pX.desc);
6461

6562
}
6663
}
@@ -70,28 +67,25 @@ namespace LR
7067
const Parallel_2D& pmat,
7168
const psi::Psi<std::complex<double>>& c,
7269
const Parallel_2D& pc,
73-
int naos,
74-
int nocc,
75-
int nvirt,
76-
Parallel_2D& pX,
77-
psi::Psi<std::complex<double>>& AX_istate,
70+
const int& naos,
71+
const int& nocc,
72+
const int& nvirt,
73+
const Parallel_2D& pX,
74+
std::complex<double>* AX_istate,
7875
const bool add_on)
7976
{
8077
ModuleBase::TITLE("hamilt_lrtd", "cal_AX_plas");
81-
assert(pmat.comm() == pc.comm());
82-
assert(pmat.blacs_ctxt == pc.blacs_ctxt);
83-
84-
if (pX.comm() != pmat.comm() || pX.blacs_ctxt != pmat.blacs_ctxt)
85-
LR_Util::setup_2d_division(pX, pmat.get_block_size(), nvirt, nocc, pmat.blacs_ctxt);
86-
else assert(pX.get_local_size() > 0 && AX_istate.get_nbasis() == pX.get_local_size());
78+
assert(pmat.comm() == pc.comm() && pmat.comm() == pX.comm());
79+
assert(pmat.blacs_ctxt == pc.blacs_ctxt && pmat.blacs_ctxt == pX.blacs_ctxt);
80+
assert(pX.get_local_size() > 0);
8781

8882
const int nks = V_istate.size();
8983

9084
Parallel_2D pVc; // for intermediate Vc
9185
LR_Util::setup_2d_division(pVc, pmat.get_block_size(), naos, nocc, pmat.blacs_ctxt);
9286
for (int isk = 0;isk < nks;++isk)
9387
{
94-
AX_istate.fix_k(isk);
88+
const int ax_start = isk * pX.get_local_size();
9589
c.fix_k(isk);
9690

9791
//Vc
@@ -116,7 +110,7 @@ namespace LR
116110
pzgemm_(&transa, &transb, &nvirt, &nocc, &naos,
117111
&alpha, c.get_pointer(), &i1, &ivirt, pc.desc,
118112
Vc.data<std::complex<double>>(), &i1, &i1, pVc.desc,
119-
&beta, AX_istate.get_pointer(), &i1, &i1, pX.desc);
113+
&beta, AX_istate + ax_start, &i1, &i1, pX.desc);
120114
}
121115
}
122116
}

source/module_lr/AX/AX_serial.cpp

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,17 @@ namespace LR
99
const psi::Psi<double>& c,
1010
const int& nocc,
1111
const int& nvirt,
12-
psi::Psi<double>& AX_istate)
12+
double* AX_istate)
1313
{
1414
ModuleBase::TITLE("hamilt_lrtd", "cal_AX_forloop");
1515
const int nks = V_istate.size();
1616
int naos = c.get_nbasis();
17-
AX_istate.fix_k(0);
18-
ModuleBase::GlobalFunc::ZEROS(AX_istate.get_pointer(), nks * nocc * nvirt);
17+
ModuleBase::GlobalFunc::ZEROS(AX_istate, nks * nocc * nvirt);
1918

2019
for (int isk = 0;isk < nks;++isk)
2120
{
2221
c.fix_k(isk);
23-
AX_istate.fix_k(isk);
22+
const int ax_start = isk * nocc * nvirt;
2423
for (int i = 0;i < nocc;++i)
2524
{
2625
for (int a = 0;a < nvirt;++a)
@@ -29,7 +28,7 @@ namespace LR
2928
{
3029
for (int mu = 0;mu < naos;++mu)
3130
{
32-
AX_istate(i * nvirt + a) += c(nocc + a, mu) * V_istate[isk].data<double>()[nu * naos + mu] * c(i, nu);
31+
AX_istate[ax_start + i * nvirt + a] += c(nocc + a, mu) * V_istate[isk].data<double>()[nu * naos + mu] * c(i, nu);
3332
}
3433
}
3534
}
@@ -41,18 +40,17 @@ namespace LR
4140
const psi::Psi<std::complex<double>>& c,
4241
const int& nocc,
4342
const int& nvirt,
44-
psi::Psi<std::complex<double>>& AX_istate)
43+
std::complex<double>* AX_istate)
4544
{
4645
ModuleBase::TITLE("hamilt_lrtd", "cal_AX_forloop");
4746
const int nks = V_istate.size();
4847
int naos = c.get_nbasis();
49-
AX_istate.fix_k(0);
50-
ModuleBase::GlobalFunc::ZEROS(AX_istate.get_pointer(), nks * nocc * nvirt);
48+
ModuleBase::GlobalFunc::ZEROS(AX_istate, nks * nocc * nvirt);
5149

5250
for (int isk = 0;isk < nks;++isk)
5351
{
5452
c.fix_k(isk);
55-
AX_istate.fix_k(isk);
53+
const int ax_start = isk * nocc * nvirt;
5654
for (int i = 0;i < nocc;++i)
5755
{
5856
for (int a = 0;a < nvirt;++a)
@@ -61,7 +59,7 @@ namespace LR
6159
{
6260
for (int mu = 0;mu < naos;++mu)
6361
{
64-
AX_istate(i * nvirt + a) += std::conj(c(nocc + a, mu)) * V_istate[isk].data<std::complex<double>>()[nu * naos + mu] * c(i, nu);
62+
AX_istate[ax_start + i * nvirt + a] += std::conj(c(nocc + a, mu)) * V_istate[isk].data<std::complex<double>>()[nu * naos + mu] * c(i, nu);
6563
}
6664
}
6765
}
@@ -74,7 +72,7 @@ namespace LR
7472
const psi::Psi<double>& c,
7573
const int& nocc,
7674
const int& nvirt,
77-
psi::Psi<double>& AX_istate,
75+
double* AX_istate,
7876
const bool add_on)
7977
{
8078
ModuleBase::TITLE("hamilt_lrtd", "cal_AX_blas");
@@ -84,7 +82,7 @@ namespace LR
8482
for (int isk = 0;isk < nks;++isk)
8583
{
8684
c.fix_k(isk);
87-
AX_istate.fix_k(isk);
85+
const int ax_start = isk * nocc * nvirt;
8886

8987
// Vc[naos*nocc]
9088
container::Tensor Vc(DAT::DT_DOUBLE, DEV::CpuDevice, { nocc, naos });// (Vc)^T
@@ -101,15 +99,15 @@ namespace LR
10199
//AX_istate=c^TVc (nvirt major)
102100
dgemm_(&transa, &transb, &nvirt, &nocc, &naos, &alpha,
103101
c.get_pointer(nocc), &naos, Vc.data<double>(), &naos, &beta,
104-
AX_istate.get_pointer(), &nvirt);
102+
AX_istate + ax_start, &nvirt);
105103
}
106104
}
107105
void cal_AX_blas(
108106
const std::vector<container::Tensor>& V_istate,
109107
const psi::Psi<std::complex<double>>& c,
110108
const int& nocc,
111109
const int& nvirt,
112-
psi::Psi<std::complex<double>>& AX_istate,
110+
std::complex<double>* AX_istate,
113111
const bool add_on)
114112
{
115113
ModuleBase::TITLE("hamilt_lrtd", "cal_AX_blas");
@@ -119,7 +117,7 @@ namespace LR
119117
for (int isk = 0;isk < nks;++isk)
120118
{
121119
c.fix_k(isk);
122-
AX_istate.fix_k(isk);
120+
const int ax_start = isk * nocc * nvirt;
123121

124122
// Vc[naos*nocc] (V is hermitian)
125123
container::Tensor Vc(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { nocc, naos });// (Vc)^T
@@ -136,7 +134,7 @@ namespace LR
136134
//AX_istate=c^\dagger Vc (nvirt major)
137135
zgemm_(&transa, &transb, &nvirt, &nocc, &naos, &alpha,
138136
c.get_pointer(nocc), &naos, Vc.data<std::complex<double>>(), &naos, &beta,
139-
AX_istate.get_pointer(), &nvirt);
137+
AX_istate + ax_start, &nvirt);
140138
}
141139
}
142140
}

0 commit comments

Comments
 (0)