Skip to content

Commit f256cc8

Browse files
committed
Fix wrong function position.
1 parent 8fccc4e commit f256cc8

File tree

4 files changed

+94
-92
lines changed

4 files changed

+94
-92
lines changed

source/module_hamilt_lcao/module_deepks/deepks_orbital.cpp

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -69,73 +69,6 @@ void DeePKS_domain::cal_o_delta(const std::vector<TH>& dm_hl,
6969
return;
7070
}
7171

72-
template <typename TK, typename TH>
73-
void DeePKS_domain::collect_h_mat(const Parallel_Orbitals& pv,
74-
const std::vector<std::vector<TK>>& h_in,
75-
std::vector<TH>& h_out,
76-
const int nlocal,
77-
const int nks)
78-
{
79-
ModuleBase::TITLE("DeePKS_domain", "collect_h_tot");
80-
81-
// construct the total H matrix
82-
for (int k = 0; k < nks; k++)
83-
{
84-
#ifdef __MPI
85-
int ir = 0;
86-
int ic = 0;
87-
for (int i = 0; i < nlocal; i++)
88-
{
89-
std::vector<TK> lineH(nlocal - i, TK(0.0));
90-
91-
ir = pv.global2local_row(i);
92-
if (ir >= 0)
93-
{
94-
// data collection
95-
for (int j = i; j < nlocal; j++)
96-
{
97-
ic = pv.global2local_col(j);
98-
if (ic >= 0)
99-
{
100-
int iic = 0;
101-
if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver))
102-
{
103-
iic = ir + ic * pv.nrow;
104-
}
105-
else
106-
{
107-
iic = ir * pv.ncol + ic;
108-
}
109-
lineH[j - i] = h_in[k][iic];
110-
}
111-
}
112-
}
113-
else
114-
{
115-
// do nothing
116-
}
117-
118-
Parallel_Reduce::reduce_all(lineH.data(), nlocal - i);
119-
120-
for (int j = i; j < nlocal; j++)
121-
{
122-
h_out[k](i, j) = lineH[j - i];
123-
h_out[k](j, i) = h_out[k](i, j); // H is a symmetric matrix
124-
}
125-
}
126-
#else
127-
for (int i = 0; i < nlocal; i++)
128-
{
129-
for (int j = i; j < nlocal; j++)
130-
{
131-
h_out[k](i, j) = h_in[k][i * nlocal + j];
132-
h_out[k](j, i) = h_out[k](i, j); // H is a symmetric matrix
133-
}
134-
}
135-
#endif
136-
}
137-
}
138-
13972
template void DeePKS_domain::cal_o_delta<double, ModuleBase::matrix>(const std::vector<ModuleBase::matrix>& dm_hl,
14073
const std::vector<std::vector<double>>& h_delta,
14174
// std::vector<double>& o_delta,
@@ -151,18 +84,4 @@ template void DeePKS_domain::cal_o_delta<std::complex<double>, ModuleBase::Compl
15184
const Parallel_Orbitals& pv,
15285
const int nks);
15386

154-
template void DeePKS_domain::collect_h_mat<double, ModuleBase::matrix>(
155-
const Parallel_Orbitals& pv,
156-
const std::vector<std::vector<double>>& h_in,
157-
std::vector<ModuleBase::matrix>& h_out,
158-
const int nlocal,
159-
const int nks);
160-
161-
template void DeePKS_domain::collect_h_mat<std::complex<double>, ModuleBase::ComplexMatrix>(
162-
const Parallel_Orbitals& pv,
163-
const std::vector<std::vector<std::complex<double>>>& h_in,
164-
std::vector<ModuleBase::ComplexMatrix>& h_out,
165-
const int nlocal,
166-
const int nks);
167-
16887
#endif

source/module_hamilt_lcao/module_deepks/deepks_orbital.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ namespace DeePKS_domain
2020
// which is defind as sum_mu,nu rho^{hl}_mu,nu <chi_mu|alpha>V(D)<alpha|chi_nu>
2121
// where rho^{hl}_mu,nu = C_{L\mu}C_{L\nu} - C_{H\mu}C_{H\nu}, L for LUMO, H for HOMO
2222

23-
// There are 2 subroutines in this file:
23+
// There are 1 subroutines in this file:
2424
// 1. cal_o_delta, which is used for O_delta calculation
25-
// 2. collect_h_mat, which collect H(k) data from different processes
2625

2726
template <typename TK, typename TH>
2827
void cal_o_delta(const std::vector<TH>& dm_hl,
@@ -31,14 +30,6 @@ void cal_o_delta(const std::vector<TH>& dm_hl,
3130
ModuleBase::matrix& o_delta,
3231
const Parallel_Orbitals& pv,
3332
const int nks);
34-
35-
// Collect data in h_in to matrix h_out. Note that left lower trianger in h_out is filled
36-
template <typename TK, typename TH>
37-
void collect_h_mat(const Parallel_Orbitals& pv,
38-
const std::vector<std::vector<TK>>& h_in,
39-
std::vector<TH>& h_out,
40-
const int nlocal,
41-
const int nks);
4233
} // namespace DeePKS_domain
4334

4435
#endif

source/module_hamilt_lcao/module_deepks/deepks_vdelta.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,73 @@ void DeePKS_domain::cal_e_delta_band(const std::vector<std::vector<TK>>& dm,
7777
return;
7878
}
7979

80+
template <typename TK, typename TH>
81+
void DeePKS_domain::collect_h_mat(const Parallel_Orbitals& pv,
82+
const std::vector<std::vector<TK>>& h_in,
83+
std::vector<TH>& h_out,
84+
const int nlocal,
85+
const int nks)
86+
{
87+
ModuleBase::TITLE("DeePKS_domain", "collect_h_tot");
88+
89+
// construct the total H matrix
90+
for (int k = 0; k < nks; k++)
91+
{
92+
#ifdef __MPI
93+
int ir = 0;
94+
int ic = 0;
95+
for (int i = 0; i < nlocal; i++)
96+
{
97+
std::vector<TK> lineH(nlocal - i, TK(0.0));
98+
99+
ir = pv.global2local_row(i);
100+
if (ir >= 0)
101+
{
102+
// data collection
103+
for (int j = i; j < nlocal; j++)
104+
{
105+
ic = pv.global2local_col(j);
106+
if (ic >= 0)
107+
{
108+
int iic = 0;
109+
if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver))
110+
{
111+
iic = ir + ic * pv.nrow;
112+
}
113+
else
114+
{
115+
iic = ir * pv.ncol + ic;
116+
}
117+
lineH[j - i] = h_in[k][iic];
118+
}
119+
}
120+
}
121+
else
122+
{
123+
// do nothing
124+
}
125+
126+
Parallel_Reduce::reduce_all(lineH.data(), nlocal - i);
127+
128+
for (int j = i; j < nlocal; j++)
129+
{
130+
h_out[k](i, j) = lineH[j - i];
131+
h_out[k](j, i) = h_out[k](i, j); // H is a symmetric matrix
132+
}
133+
}
134+
#else
135+
for (int i = 0; i < nlocal; i++)
136+
{
137+
for (int j = i; j < nlocal; j++)
138+
{
139+
h_out[k](i, j) = h_in[k][i * nlocal + j];
140+
h_out[k](j, i) = h_out[k](i, j); // H is a symmetric matrix
141+
}
142+
}
143+
#endif
144+
}
145+
}
146+
80147
template void DeePKS_domain::cal_e_delta_band<double>(const std::vector<std::vector<double>>& dm,
81148
const std::vector<std::vector<double>>& V_delta,
82149
const int nks,
@@ -89,4 +156,18 @@ template void DeePKS_domain::cal_e_delta_band<std::complex<double>>(
89156
const Parallel_Orbitals* pv,
90157
double& e_delta_band);
91158

159+
template void DeePKS_domain::collect_h_mat<double, ModuleBase::matrix>(
160+
const Parallel_Orbitals& pv,
161+
const std::vector<std::vector<double>>& h_in,
162+
std::vector<ModuleBase::matrix>& h_out,
163+
const int nlocal,
164+
const int nks);
165+
166+
template void DeePKS_domain::collect_h_mat<std::complex<double>, ModuleBase::ComplexMatrix>(
167+
const Parallel_Orbitals& pv,
168+
const std::vector<std::vector<std::complex<double>>>& h_in,
169+
std::vector<ModuleBase::ComplexMatrix>& h_out,
170+
const int nlocal,
171+
const int nks);
172+
92173
#endif

source/module_hamilt_lcao/module_deepks/deepks_vdelta.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define DEEPKS_VDELTA_H
33

44
#ifdef __DEEPKS
5+
#include "module_base/complexmatrix.h"
6+
#include "module_base/matrix.h"
57
#include "module_basis/module_ao/parallel_orbitals.h"
68

79
namespace DeePKS_domain
@@ -10,8 +12,9 @@ namespace DeePKS_domain
1012
// deepks_vdelta.cpp
1113
//------------------------
1214

13-
// This file contains 1 subroutine for calculating e_delta_bands
15+
// This file contains 2 subroutine for calculating e_delta_bands
1416
// 1. cal_e_delta_band : calculates e_delta_bands
17+
// 2. collect_h_mat, which collect H(k) data from different processes
1518

1619
/// calculate tr(\rho V_delta)
1720
template <typename TK>
@@ -20,6 +23,14 @@ void cal_e_delta_band(const std::vector<std::vector<TK>>& dm,
2023
const int nks,
2124
const Parallel_Orbitals* pv,
2225
double& e_delta_band);
26+
27+
// Collect data in h_in to matrix h_out. Note that left lower trianger in h_out is filled
28+
template <typename TK, typename TH>
29+
void collect_h_mat(const Parallel_Orbitals& pv,
30+
const std::vector<std::vector<TK>>& h_in,
31+
std::vector<TH>& h_out,
32+
const int nlocal,
33+
const int nks);
2334
} // namespace DeePKS_domain
2435
#endif
2536
#endif

0 commit comments

Comments
 (0)