Skip to content

Commit f1cd921

Browse files
committed
Update namespace ModuleGint
1 parent 491b324 commit f1cd921

File tree

5 files changed

+124
-93
lines changed

5 files changed

+124
-93
lines changed

source/module_hamilt_lcao/module_gint/temp_gint/gint_common.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,11 @@ void transfer_hr_gint_to_hR(std::shared_ptr<const HContainer<T>> hr_gint, HConta
124124

125125
// gint_info should not have been a parameter, but it was added to initialize dm_gint_full
126126
// In the future, we might try to remove the gint_info parameter
127+
template<typename T>
127128
void transfer_dm_2d_to_gint(
128129
std::shared_ptr<const GintInfo> gint_info,
129-
std::vector<HContainer<double>*> dm,
130-
std::vector<std::shared_ptr<HContainer<double>>> dm_gint)
130+
std::vector<HContainer<T>*> dm,
131+
std::vector<std::shared_ptr<HContainer<T>>> dm_gint)
131132
{
132133
// To check whether input parameter dm_2d has been initialized
133134
#ifdef __DEBUG
@@ -150,12 +151,12 @@ void transfer_dm_2d_to_gint(
150151
{
151152
#ifdef __MPI
152153
const int npol = 2;
153-
std::shared_ptr<HContainer<double>> dm_full = gint_info->get_hr<double>(npol);
154+
std::shared_ptr<HContainer<T>> dm_full = gint_info->get_hr<T>(npol);
154155
hamilt::transferParallels2Serials(*dm[0], dm_full.get());
155156
#else
156-
HContainer<double>* dm_full = dm[0];
157+
HContainer<T>* dm_full = dm[0];
157158
#endif
158-
std::vector<double*> tmp_pointer(4, nullptr);
159+
std::vector<T*> tmp_pointer(4, nullptr);
159160
for (int iap = 0; iap < dm_full->size_atom_pairs(); iap++)
160161
{
161162
auto& ap = dm_full->get_atom_pair(iap);
@@ -169,7 +170,7 @@ void transfer_dm_2d_to_gint(
169170
tmp_pointer[is] =
170171
dm_gint[is]->find_matrix(iat1, iat2, r_index)->get_pointer();
171172
}
172-
double* data_full = ap.get_pointer(ir);
173+
T* data_full = ap.get_pointer(ir);
173174
for (int irow = 0; irow < ap.get_row_size(); irow += 2)
174175
{
175176
for (int icol = 0; icol < ap.get_col_size(); icol += 2)
@@ -191,6 +192,18 @@ void transfer_dm_2d_to_gint(
191192
}
192193

193194

194-
template void transfer_hr_gint_to_hR(std::shared_ptr<const HContainer<double>> hr_gint, HContainer<double>* hR);
195-
template void transfer_hr_gint_to_hR(std::shared_ptr<const HContainer<std::complex<double>>> hr_gint, HContainer<std::complex<double>>* hR);
195+
template void transfer_hr_gint_to_hR(
196+
std::shared_ptr<const HContainer<double>> hr_gint,
197+
HContainer<double>* hR);
198+
template void transfer_hr_gint_to_hR(
199+
std::shared_ptr<const HContainer<std::complex<double>>> hr_gint,
200+
HContainer<std::complex<double>>* hR);
201+
template void transfer_dm_2d_to_gint(
202+
std::shared_ptr<const GintInfo> gint_info,
203+
std::vector<HContainer<double>*> dm,
204+
std::vector<std::shared_ptr<HContainer<double>>> dm_gint);
205+
template void transfer_dm_2d_to_gint(
206+
std::shared_ptr<const GintInfo> gint_info,
207+
std::vector<HContainer<std::complex<double>>*> dm,
208+
std::vector<std::shared_ptr<HContainer<std::complex<double>>>> dm_gint);
196209
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_common.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ namespace ModuleGint
1313
template <typename T>
1414
void transfer_hr_gint_to_hR(std::shared_ptr<const HContainer<T>> hr_gint, HContainer<T>* hR);
1515

16+
template<typename T>
1617
void transfer_dm_2d_to_gint(
1718
std::shared_ptr<const GintInfo> gint_info,
18-
std::vector<HContainer<double>*> dm,
19-
std::vector<std::shared_ptr<HContainer<double>>> dm_gint);
19+
std::vector<HContainer<T>*> dm,
20+
std::vector<std::shared_ptr<HContainer<T>>> dm_gint);
2021

2122
}

source/module_hamilt_lcao/module_gint/temp_gint/phi_operator.cpp

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -66,75 +66,6 @@ void PhiOperator::set_ddphi(
6666
}
6767
}
6868

69-
void PhiOperator::phi_mul_dm(
70-
const double* phi,
71-
const HContainer<double>& dm,
72-
const bool is_symm, double* phi_dm) const
73-
{
74-
ModuleBase::GlobalFunc::ZEROS(phi_dm, rows_ * cols_);
75-
// parameters for lapack subroutines
76-
constexpr char side = 'L';
77-
constexpr char uplo = 'U';
78-
const char trans = 'N';
79-
const double alpha = 1.0;
80-
const double beta = 1.0;
81-
const double alpha1 = is_symm ? 2.0 : 1.0;
82-
83-
for(int i = 0; i < biggrid_->get_atoms_num(); ++i)
84-
{
85-
const auto atom_i = biggrid_->get_atom(i);
86-
const auto r_i = atom_i->get_R();
87-
88-
if(is_symm)
89-
{
90-
const auto dm_mat = dm.find_matrix(atom_i->get_iat(), atom_i->get_iat(), 0, 0, 0);
91-
dsymm_(&side, &uplo, &atoms_phi_len_[i], &rows_, &alpha, dm_mat->get_pointer(), &atoms_phi_len_[i],
92-
&phi[0 * cols_ + atoms_startidx_[i]], &cols_, &beta, &phi_dm[0 * cols_ + atoms_startidx_[i]], &cols_);
93-
}
94-
95-
const int start = is_symm ? i + 1 : 0;
96-
97-
for(int j = start; j < biggrid_->get_atoms_num(); ++j)
98-
{
99-
const auto atom_j = biggrid_->get_atom(j);
100-
const auto r_j = atom_j->get_R();
101-
// FIXME may be r = r_j - r_i
102-
const auto dm_mat = dm.find_matrix(atom_i->get_iat(), atom_j->get_iat(), r_i-r_j);
103-
104-
// if dm_mat is nullptr, it means this atom pair does not affect any meshgrid in the unitcell
105-
if(dm_mat == nullptr)
106-
{
107-
continue;
108-
}
109-
110-
int start_idx = get_atom_pair_start_end_idx_(i, j).first;
111-
int end_idx = get_atom_pair_start_end_idx_(i, j).second;
112-
const int len = end_idx - start_idx + 1;
113-
114-
// if len<=0, it means this atom pair does not affect any meshgrid in this biggrid
115-
if(len <= 0)
116-
{
117-
continue;
118-
}
119-
120-
dgemm_(&trans, &trans, &atoms_phi_len_[j], &len, &atoms_phi_len_[i], &alpha1, dm_mat->get_pointer(), &atoms_phi_len_[j],
121-
&phi[start_idx * cols_ + atoms_startidx_[i]], &cols_, &beta, &phi_dm[start_idx * cols_ + atoms_startidx_[j]], &cols_);
122-
}
123-
}
124-
}
125-
126-
void PhiOperator::phi_dot_phi(
127-
const double*const phi_1, // phi_1(igrid,iwt)
128-
const double*const phi_2, // phi_2(igrid,iwt)
129-
double*const rho) const // rho(igrid)
130-
{
131-
const int inc = 1;
132-
for(int i = 0; i < biggrid_->get_mgrids_num(); ++i)
133-
{
134-
rho[meshgrids_local_idx_[i]] += ddot_(&cols_, phi_1+i*cols_, &inc, phi_2+i*cols_, &inc);
135-
}
136-
}
137-
13869
void PhiOperator::phi_dot_dphi(
13970
const double* phi,
14071
const double* dphi_x,

source/module_hamilt_lcao/module_gint/temp_gint/phi_operator.h

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,30 +47,37 @@ class PhiOperator
4747
double* ddphi_xx, double* ddphi_xy, double* ddphi_xz,
4848
double* ddphi_yy, double* ddphi_yz, double* ddphi_zz) const;
4949

50+
// phi_dm(ir,iwt_2) = \sum_{iwt_1} phi(ir,iwt_1) * dm(iwt_1,iwt_2)
51+
template<typename T>
5052
void phi_mul_dm(
51-
const double* phi,
52-
const HContainer<double>& dm,
53-
const bool is_symm, double* phi_dm) const;
53+
const T*const phi,
54+
const HContainer<T>& dm,
55+
const bool is_symm,
56+
T*const phi_dm) const;
5457

58+
// result(ir) = phi(ir) * vl(ir)
5559
template<typename T>
5660
void phi_mul_vldr3(
5761
const T*const vl,
5862
const T dr3,
5963
const T*const phi,
6064
T*const result) const;
6165

66+
// hr(iwt_i,iwt_j) = \sum_{ir} phi_i(ir,iwt_i) * phi_i(ir,iwt_j)
6267
// this is a thread-safe function
6368
template<typename T>
6469
void phi_mul_phi(
65-
const T*const phi_i, // phi_i(igrid,iwt_i)
66-
const T*const phi_j, // phi_j(igrid,iwt_j)
70+
const T*const phi_i, // phi_i(ir,iwt_i)
71+
const T*const phi_j, // phi_j(ir,iwt_j)
6772
HContainer<T>& hr, // hr(iwt_i,iwt_j)
6873
const Triangular_Matrix triangular_matrix) const;
6974

75+
// rho(ir) = \sum_{iwt} \phi_i(ir,iwt) * \phi_j(ir,iwt)
76+
template<typename T>
7077
void phi_dot_phi(
71-
const double*const phi_1, // phi_1(igrid,iwt)
72-
const double*const phi_2, // phi_2(igrid,iwt)
73-
double*const rho) const; // rho(igrid)
78+
const T*const phi_i, // phi_i(ir,iwt)
79+
const T*const phi_j, // phi_j(ir,iwt)
80+
T*const rho) const; // rho(ir)
7481

7582
void phi_dot_dphi(
7683
const double* phi,

source/module_hamilt_lcao/module_gint/temp_gint/phi_operator.hpp

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,72 @@ void PhiOperator::set_phi(T* phi) const
1818
}
1919
}
2020

21+
// phi_dm(ir,iwt_2) = \sum_{iwt_1} phi(ir,iwt_1) * dm(iwt_1,iwt_2)
22+
template<typename T>
23+
void PhiOperator::phi_mul_dm(
24+
const T*const phi,
25+
const HContainer<T>& dm,
26+
const bool is_symm,
27+
T*const phi_dm) const
28+
{
29+
ModuleBase::GlobalFunc::ZEROS(phi_dm, rows_ * cols_);
30+
31+
for(int i = 0; i < biggrid_->get_atoms_num(); ++i)
32+
{
33+
const auto atom_i = biggrid_->get_atom(i);
34+
const auto r_i = atom_i->get_R();
35+
36+
if(is_symm)
37+
{
38+
const auto dm_mat = dm.find_matrix(atom_i->get_iat(), atom_i->get_iat(), 0, 0, 0);
39+
constexpr T alpha = 1.0;
40+
constexpr T beta = 1.0;
41+
BlasConnector::symm_cm(
42+
'L', 'U',
43+
atoms_phi_len_[i], rows_,
44+
alpha, dm_mat->get_pointer(), atoms_phi_len_[i],
45+
&phi[0 * cols_ + atoms_startidx_[i]], cols_,
46+
beta, &phi_dm[0 * cols_ + atoms_startidx_[i]], cols_);
47+
}
48+
49+
const int start = is_symm ? i + 1 : 0;
50+
51+
for(int j = start; j < biggrid_->get_atoms_num(); ++j)
52+
{
53+
const auto atom_j = biggrid_->get_atom(j);
54+
const auto r_j = atom_j->get_R();
55+
// FIXME may be r = r_j - r_i
56+
const auto dm_mat = dm.find_matrix(atom_i->get_iat(), atom_j->get_iat(), r_i-r_j);
57+
58+
// if dm_mat is nullptr, it means this atom pair does not affect any meshgrid in the unitcell
59+
if(dm_mat == nullptr)
60+
{
61+
continue;
62+
}
63+
64+
const int start_idx = get_atom_pair_start_end_idx_(i, j).first;
65+
const int end_idx = get_atom_pair_start_end_idx_(i, j).second;
66+
const int len = end_idx - start_idx + 1;
67+
68+
// if len<=0, it means this atom pair does not affect any meshgrid in this biggrid
69+
if(len <= 0)
70+
{
71+
continue;
72+
}
73+
74+
const T alpha = is_symm ? 2.0 : 1.0;
75+
constexpr T beta = 1.0;
76+
BlasConnector::gemm(
77+
'N', 'N',
78+
len, atoms_phi_len_[j], atoms_phi_len_[i],
79+
alpha, &phi[start_idx * cols_ + atoms_startidx_[i]], cols_,
80+
dm_mat->get_pointer(), atoms_phi_len_[j],
81+
beta, &phi_dm[start_idx * cols_ + atoms_startidx_[j]], cols_);
82+
}
83+
}
84+
}
85+
86+
// result(ir) = phi(ir) * vl(ir)
2187
template<typename T>
2288
void PhiOperator::phi_mul_vldr3(
2389
const T*const vl,
@@ -37,17 +103,15 @@ void PhiOperator::phi_mul_vldr3(
37103
}
38104
}
39105

106+
// hr(iwt_i,iwt_j) = \sum_{ir} phi_i(ir,iwt_i) * phi_i(ir,iwt_j)
40107
// this is a thread-safe function
41108
template<typename T>
42109
void PhiOperator::phi_mul_phi(
43-
const T*const phi_i, // phi_i(igrid,iwt_i)
44-
const T*const phi_j, // phi_j(igrid,iwt_j)
110+
const T*const phi_i, // phi_i(ir,iwt_i)
111+
const T*const phi_j, // phi_j(ir,iwt_j)
45112
HContainer<T>& hr, // hr(iwt_i,iwt_j)
46113
const Triangular_Matrix triangular_matrix) const
47114
{
48-
const char transa='T', transb='N';
49-
const T alpha=1, beta=1;
50-
51115
std::vector<T> tmp_hr;
52116
for(int i = 0; i < biggrid_->get_atoms_num(); ++i)
53117
{
@@ -94,8 +158,9 @@ void PhiOperator::phi_mul_phi(
94158
tmp_hr.resize(n_i * n_j);
95159
ModuleBase::GlobalFunc::ZEROS(tmp_hr.data(), n_i*n_j);
96160

161+
constexpr T alpha=1, beta=1;
97162
BlasConnector::gemm(
98-
transa, transb, n_i, n_j, len,
163+
'T', 'N', n_i, n_j, len,
99164
alpha, phi_i + start_idx * cols_ + atoms_startidx_[i], cols_,
100165
phi_j + start_idx * cols_ + atoms_startidx_[j], cols_,
101166
beta, tmp_hr.data(), n_j,
@@ -106,4 +171,18 @@ void PhiOperator::phi_mul_phi(
106171
}
107172
}
108173

174+
// rho(ir) = \sum_{iwt} \phi_i(ir,iwt) * \phi_j^*(ir,iwt)
175+
template<typename T>
176+
void PhiOperator::phi_dot_phi(
177+
const T*const phi_i, // phi_i(igrid,iwt)
178+
const T*const phi_j, // phi_j(igrid,iwt)
179+
T*const rho) const // rho(igrid)
180+
{
181+
constexpr int inc = 1;
182+
for(int i = 0; i < biggrid_->get_mgrids_num(); ++i)
183+
{
184+
rho[meshgrids_local_idx_[i]] += BlasConnector::dotc(cols_, phi_j+i*cols_, inc, phi_i+i*cols_, inc);
185+
}
186+
}
187+
109188
} // namespace ModuleGint

0 commit comments

Comments
 (0)