Skip to content

Commit 41b5045

Browse files
authored
Accelerate cal_overlap. (#6460)
1 parent a7cff68 commit 41b5045

File tree

5 files changed

+64
-45
lines changed

5 files changed

+64
-45
lines changed

source/source_base/vector3.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -170,11 +170,11 @@ template <class T> class Vector3
170170
}
171171

172172
/**
173-
* @brief Get the square of nomr of a Vector3
173+
* @brief Get the square of norm of a Vector3
174174
*
175175
* @return T
176176
*/
177-
T norm2(void) const
177+
inline T norm2(void) const
178178
{
179179
return x * x + y * y + z * z;
180180
}
@@ -184,7 +184,7 @@ template <class T> class Vector3
184184
*
185185
* @return T
186186
*/
187-
T norm(void) const
187+
inline T norm(void) const
188188
{
189189
return sqrt(norm2());
190190
}

source/source_basis/module_ao/ORB_gaunt_table.cpp

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -154,13 +154,6 @@ void ORB_gaunt_table::init_Ylm_Gaunt
154154
}
155155
*/
156156

157-
int ORB_gaunt_table::get_lm_index(
158-
const int l,
159-
const int m)
160-
{
161-
return l*l+m;
162-
}
163-
164157

165158
///effective pointers
166159
int ORB_gaunt_table::EP_EL(const int& L)

source/source_basis/module_ao/ORB_gaunt_table.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,10 @@ class ORB_gaunt_table
8686
/// ------------------------------------
8787
void init_Gaunt(const int &lmax);
8888

89-
static int get_lm_index(const int l, const int m);
89+
static inline int get_lm_index(const int l, const int m)
90+
{
91+
return l*l+m;
92+
}
9093

9194
static int Index_M(const int& m);
9295

source/source_lcao/center2_orb-orb11.cpp

Lines changed: 41 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -58,17 +58,20 @@ void Center2_Orb::Orb11::init_radial_table(const std::set<size_t>& radials)
5858
const size_t rmesh = Center2_Orb::get_rmesh(this->nA.getRcut(), this->nB.getRcut(), dr_);
5959

6060
std::set<size_t> radials_used;
61-
for (const size_t& ir: radials) {
62-
if (ir < rmesh) {
61+
for (const size_t& ir: radials)
62+
{
63+
if (ir < rmesh)
64+
{
6365
radials_used.insert(ir);
64-
}
65-
}
66+
}
67+
}
6668

6769
for (int LAB = std::abs(LA - LB); LAB <= LA + LB; ++LAB)
6870
{
69-
if ((LAB - std::abs(LA - LB)) % 2 == 1) { // if LA+LB-LAB == odd, then Gaunt_Coefficients = 0
71+
if ((LAB - std::abs(LA - LB)) % 2 == 1) // if LA+LB-LAB == odd, then Gaunt_Coefficients = 0
72+
{
7073
continue;
71-
}
74+
}
7275

7376
this->Table_r[LAB].resize(rmesh, 0);
7477
this->Table_dr[LAB].resize(rmesh, 0);
@@ -97,9 +100,10 @@ double Center2_Orb::Orb11::cal_overlap(const ModuleBase::Vector3<double>& RA,
97100
const double distance = (distance_true >= tiny1) ? distance_true : distance_true + tiny1;
98101
const double RcutA = this->nA.getRcut();
99102
const double RcutB = this->nB.getRcut();
100-
if (distance > (RcutA + RcutB)) {
103+
if (distance > (RcutA + RcutB))
104+
{
101105
return 0.0;
102-
}
106+
}
103107

104108
const int LA = this->nA.getL();
105109
const int LB = this->nB.getL();
@@ -112,6 +116,9 @@ double Center2_Orb::Orb11::cal_overlap(const ModuleBase::Vector3<double>& RA,
112116
rly);
113117

114118
double overlap = 0.0;
119+
const int idx1 = this->MGT.get_lm_index(LA, mA);
120+
const int idx2 = this->MGT.get_lm_index(LB, mB);
121+
const double* Gaunt_Coefficients_ptr = &(this->MGT.Gaunt_Coefficients(idx1, idx2, 0));
115122

116123
for (const auto& tb_r: this->Table_r)
117124
{
@@ -120,17 +127,18 @@ double Center2_Orb::Orb11::cal_overlap(const ModuleBase::Vector3<double>& RA,
120127
for (int mAB = 0; mAB != 2 * LAB + 1; ++mAB)
121128
// const int mAB = mA + mB;
122129
{
123-
const double Gaunt_real_A_B_AB = this->MGT.Gaunt_Coefficients(this->MGT.get_lm_index(LA, mA),
124-
this->MGT.get_lm_index(LB, mB),
125-
this->MGT.get_lm_index(LAB, mAB));
126-
if (0 == Gaunt_real_A_B_AB) {
130+
const int idx3 = this->MGT.get_lm_index(LAB, mAB);
131+
const double Gaunt_real_A_B_AB = *(Gaunt_Coefficients_ptr + idx3);
132+
if (0 == Gaunt_real_A_B_AB)
133+
{
127134
continue;
128-
}
135+
}
129136

130-
const double ylm_solid = rly[this->MGT.get_lm_index(LAB, mAB)];
131-
if (0 == ylm_solid) {
137+
const double ylm_solid = rly[idx3];
138+
if (0 == ylm_solid)
139+
{
132140
continue;
133-
}
141+
}
134142
const double ylm_real = (distance > tiny2) ? ylm_solid / pow(distance, LAB) : ylm_solid;
135143

136144
const double i_exp = std::pow(-1.0, (LA - LB - LAB) / 2);
@@ -166,18 +174,23 @@ ModuleBase::Vector3<double> Center2_Orb::Orb11::cal_grad_overlap( // caoyu add 2
166174
const double distance = (distance_true >= tiny1) ? distance_true : distance_true + tiny1;
167175
const double RcutA = this->nA.getRcut();
168176
const double RcutB = this->nB.getRcut();
169-
if (distance > (RcutA + RcutB)) {
177+
if (distance > (RcutA + RcutB))
178+
{
170179
return ModuleBase::Vector3<double>(0.0, 0.0, 0.0);
171-
}
180+
}
172181

173182
const int LA = this->nA.getL();
174183
const int LB = this->nB.getL();
184+
const int idx1 = this->MGT.get_lm_index(LA, mA);
185+
const int idx2 = this->MGT.get_lm_index(LB, mB);
186+
const double* Gaunt_Coefficients_ptr = &(this->MGT.Gaunt_Coefficients(idx1, idx2, 0));
175187

176-
std::vector<double> rly((LA + LB + 1) * (LA + LB + 1));
188+
const int LAB2 = (LA + LB + 1) * (LA + LB + 1);
189+
std::vector<double> rly(LAB2);
177190
std::vector<ModuleBase::Vector3<double>> grly;
178-
ModuleBase::Array_Pool<double> tmp_grly((LA + LB + 1) * (LA + LB + 1), 3);
191+
ModuleBase::Array_Pool<double> tmp_grly(LAB2, 3);
179192
ModuleBase::Ylm::grad_rl_sph_harm(LA + LB, delta_R.x, delta_R.y, delta_R.z, rly.data(), tmp_grly.get_ptr_2D());
180-
for (int i=0; i<(LA + LB + 1) * (LA + LB + 1); ++i)
193+
for (int i=0; i<LAB2; ++i)
181194
{
182195
ModuleBase::Vector3<double> ele(tmp_grly[i][0], tmp_grly[i][1], tmp_grly[i][2]);
183196
grly.push_back(ele);
@@ -191,17 +204,17 @@ ModuleBase::Vector3<double> Center2_Orb::Orb11::cal_grad_overlap( // caoyu add 2
191204
for (int mAB = 0; mAB != 2 * LAB + 1; ++mAB)
192205
// const int mAB = mA + mB;
193206
{
194-
const double Gaunt_real_A_B_AB = this->MGT.Gaunt_Coefficients(this->MGT.get_lm_index(LA, mA),
195-
this->MGT.get_lm_index(LB, mB),
196-
this->MGT.get_lm_index(LAB, mAB));
197-
if (0 == Gaunt_real_A_B_AB) {
207+
const int idx3 = this->MGT.get_lm_index(LAB, mAB);
208+
const double Gaunt_real_A_B_AB = *(Gaunt_Coefficients_ptr + idx3);
209+
if (0 == Gaunt_real_A_B_AB)
210+
{
198211
continue;
199-
}
212+
}
200213

201-
const double ylm_solid = rly[this->MGT.get_lm_index(LAB, mAB)];
214+
const double ylm_solid = rly[idx3];
202215
const double ylm_real = (distance > tiny2) ? ylm_solid / pow(distance, LAB) : ylm_solid;
203216

204-
const ModuleBase::Vector3<double> gylm_solid = grly[this->MGT.get_lm_index(LAB, mAB)];
217+
const ModuleBase::Vector3<double> gylm_solid = grly[idx3];
205218
const ModuleBase::Vector3<double> gylm_real
206219
= (distance > tiny2) ? gylm_solid / pow(distance, LAB) : gylm_solid;
207220

source/source_lcao/center2_orb-orb21.cpp

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,9 @@ void Center2_Orb::Orb21::init_radial_table()
3535
for (int LA = std::abs(LA1 - LA2); LA <= LA1 + LA2; ++LA)
3636
{
3737
if ((LA - std::abs(LA1 - LA2)) % 2 == 1) // if LA+LB-LAB == odd, then Gaunt_Coefficients = 0
38+
{
3839
continue;
40+
}
3941

4042
this->nA[LA].set_orbital_info(nA_short.getLabel(),
4143
nA_short.getType(),
@@ -74,7 +76,9 @@ void Center2_Orb::Orb21::init_radial_table(const std::set<size_t>& radials)
7476
for (int LA = std::abs(LA1 - LA2); LA <= LA1 + LA2; ++LA)
7577
{
7678
if ((LA - std::abs(LA1 - LA2)) % 2 == 1) // if LA+LB-LAB == odd, then Gaunt_Coefficients = 0
79+
{
7780
continue;
81+
}
7882

7983
this->nA[LA].set_orbital_info(nA_short.getLabel(),
8084
nA_short.getType(),
@@ -106,6 +110,9 @@ double Center2_Orb::Orb21::cal_overlap(const ModuleBase::Vector3<double>& RA,
106110
{
107111
const int LA1 = this->nA1.getL();
108112
const int LA2 = this->nA2.getL();
113+
const int idx1 = this->MGT.get_lm_index(LA1, mA1);
114+
const int idx2 = this->MGT.get_lm_index(LA2, mA2);
115+
const double* Gaunt_Coefficients_ptr = &(this->MGT.Gaunt_Coefficients(idx1, idx2, 0));
109116

110117
double overlap = 0.0;
111118

@@ -116,11 +123,11 @@ double Center2_Orb::Orb21::cal_overlap(const ModuleBase::Vector3<double>& RA,
116123
for (int mA = 0; mA != 2 * LA + 1; ++mA)
117124
// const int mA=mA1+mA2;
118125
{
119-
const double Gaunt_real_A1_A2_A12 = this->MGT.Gaunt_Coefficients(this->MGT.get_lm_index(LA1, mA1),
120-
this->MGT.get_lm_index(LA2, mA2),
121-
this->MGT.get_lm_index(LA, mA));
126+
const double Gaunt_real_A1_A2_A12 = *(Gaunt_Coefficients_ptr + this->MGT.get_lm_index(LA, mA));
122127
if (0 == Gaunt_real_A1_A2_A12)
128+
{
123129
continue;
130+
}
124131

125132
overlap += Gaunt_real_A1_A2_A12 * orb11.second.cal_overlap(RA, RB, mA, mB);
126133
}
@@ -137,6 +144,9 @@ ModuleBase::Vector3<double> Center2_Orb::Orb21::cal_grad_overlap(const ModuleBas
137144
{
138145
const int LA1 = this->nA1.getL();
139146
const int LA2 = this->nA2.getL();
147+
const int idx1 = this->MGT.get_lm_index(LA1, mA1);
148+
const int idx2 = this->MGT.get_lm_index(LA2, mA2);
149+
const double* Gaunt_Coefficients_ptr = &(this->MGT.Gaunt_Coefficients(idx1, idx2, 0));
140150

141151
ModuleBase::Vector3<double> grad_overlap(0.0, 0.0, 0.0);
142152

@@ -147,11 +157,11 @@ ModuleBase::Vector3<double> Center2_Orb::Orb21::cal_grad_overlap(const ModuleBas
147157
for (int mA = 0; mA != 2 * LA + 1; ++mA)
148158
// const int mA=mA1+mA2;
149159
{
150-
const double Gaunt_real_A1_A2_A12 = this->MGT.Gaunt_Coefficients(this->MGT.get_lm_index(LA1, mA1),
151-
this->MGT.get_lm_index(LA2, mA2),
152-
this->MGT.get_lm_index(LA, mA));
160+
const double Gaunt_real_A1_A2_A12 = *(Gaunt_Coefficients_ptr + this->MGT.get_lm_index(LA, mA));
153161
if (0 == Gaunt_real_A1_A2_A12)
162+
{
154163
continue;
164+
}
155165

156166
grad_overlap += Gaunt_real_A1_A2_A12 * orb11.second.cal_grad_overlap(RA, RB, mA, mB);
157167
}

0 commit comments

Comments
 (0)