Skip to content

Commit 2fead84

Browse files
authored
Refactor: reduce the memory consumption of cal_gint_vlocal (#6069)
* reduce the memory consumption of cal_gint_vlocal * fix a bug
1 parent 5b19d01 commit 2fead84

File tree

10 files changed

+56
-132
lines changed

10 files changed

+56
-132
lines changed

source/module_hamilt_lcao/module_gint/gint.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ class Gint {
140140
//! psir_ylm: dim is [bxyz][LD_pool]
141141
//! psir_vlbr3: dim is [bxyz][LD_pool]
142142
//! hR: HContainer for storing the <phi_0|V|phi_R> matrix elements
143+
//! cal_meshball_vlocal is thread-safe!
143144
void cal_meshball_vlocal(
144145
const int na_grid,
145146
const int LD_pool,

source/module_hamilt_lcao/module_gint/gint_vl.cpp

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include <mkl_service.h>
2020
#endif
2121

22-
22+
// this is a thread-safe function
2323
void Gint::cal_meshball_vlocal(
2424
const int na_grid, // how many atoms on this (i,j,k) grid
2525
const int LD_pool,
@@ -36,6 +36,7 @@ void Gint::cal_meshball_vlocal(
3636
const int lgd_now = this->gridt->lgd;
3737

3838
const int mcell_index = this->gridt->bcell_start[grid_index];
39+
std::vector<double> hr_tmp;
3940
for(int ia1=0; ia1<na_grid; ++ia1)
4041
{
4142
const int bcell1 = mcell_index + ia1;
@@ -80,33 +81,14 @@ void Gint::cal_meshball_vlocal(
8081
}
8182
const int m = tmp_matrix->get_row_size();
8283
const int n = tmp_matrix->get_col_size();
83-
84-
int cal_pair_num=0;
85-
for(int ib=first_ib;ib<last_ib; ++ib)
86-
{
87-
cal_pair_num += cal_flag[ib][ia1] && cal_flag[ib][ia2];
88-
}
89-
if(cal_pair_num>ib_length/4)
90-
{
91-
dgemm_(&transa, &transb, &n, &m, &ib_length, &alpha,
92-
&psir_vlbr3[first_ib][block_index[ia2]], &LD_pool,
93-
&psir_ylm[first_ib][block_index[ia1]], &LD_pool,
94-
&beta, tmp_matrix->get_pointer(), &n);
95-
}
96-
else
97-
{
98-
for(int ib=first_ib; ib<last_ib; ++ib)
99-
{
100-
if(cal_flag[ib][ia1] && cal_flag[ib][ia2])
101-
{
102-
int k=1;
103-
dgemm_(&transa, &transb, &n, &m, &k, &alpha,
104-
&psir_vlbr3[ib][block_index[ia2]], &LD_pool,
105-
&psir_ylm[ib][block_index[ia1]], &LD_pool,
106-
&beta, tmp_matrix->get_pointer(), &n);
107-
}
108-
}
109-
}
84+
hr_tmp.resize(m * n);
85+
ModuleBase::GlobalFunc::ZEROS(hr_tmp.data(), m*n);
86+
87+
dgemm_(&transa, &transb, &n, &m, &ib_length, &alpha,
88+
&psir_vlbr3[first_ib][block_index[ia2]], &LD_pool,
89+
&psir_ylm[first_ib][block_index[ia1]], &LD_pool,
90+
&beta, hr_tmp.data(), &n);
91+
tmp_matrix->add_array_ts(hr_tmp.data());
11092
}
11193
}
11294
}

source/module_hamilt_lcao/module_gint/gint_vl_cpu_interface.cpp

Lines changed: 8 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ void Gint::gint_kernel_vlocal(Gint_inout* inout) {
1919
{ /**
2020
* @brief When in OpenMP, it points to a newly allocated memory,
2121
*/
22-
hamilt::HContainer<double> hRGint_thread(*hRGint_kernel);
2322
std::vector<int> block_iw(max_size,0);
2423
std::vector<int> block_index(max_size+1,0);
2524
std::vector<int> block_size(max_size,0);
@@ -74,19 +73,8 @@ void Gint::gint_kernel_vlocal(Gint_inout* inout) {
7473
this->cal_meshball_vlocal(
7574
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index,
7675
cal_flag.get_ptr_2D(),psir_ylm.get_ptr_2D(), psir_vlbr3.get_ptr_2D(),
77-
&hRGint_thread);
76+
hRGint_kernel);
7877
}
79-
80-
#pragma omp critical
81-
{
82-
BlasConnector::axpy(hRGint_thread.get_nnr(),
83-
1.0,
84-
hRGint_thread.get_wrapper(),
85-
1,
86-
hRGint_kernel->get_wrapper(),
87-
1);
88-
}
89-
9078
ModuleBase::TITLE("Gint_interface", "cal_gint_vlocal");
9179
ModuleBase::timer::tick("Gint_interface", "cal_gint_vlocal");
9280
}
@@ -112,9 +100,6 @@ void Gint::gint_kernel_dvlocal(Gint_inout* inout) {
112100

113101
#pragma omp parallel
114102
{
115-
hamilt::HContainer<double> pvdpRx_thread(pvdpRx_reduced[inout->ispin]);
116-
hamilt::HContainer<double> pvdpRy_thread(pvdpRy_reduced[inout->ispin]);
117-
hamilt::HContainer<double> pvdpRz_thread(pvdpRz_reduced[inout->ispin]);
118103
std::vector<int> block_iw(max_size,0);
119104
std::vector<int> block_index(max_size+1,0);
120105
std::vector<int> block_size(max_size,0);
@@ -159,34 +144,13 @@ void Gint::gint_kernel_dvlocal(Gint_inout* inout) {
159144
//and accumulates to the corresponding element in Hamiltonian
160145
this->cal_meshball_vlocal(na_grid, LD_pool, block_size.data(), block_index.data(),
161146
grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
162-
dpsir_ylm_x.get_ptr_2D(), &pvdpRx_thread);
147+
dpsir_ylm_x.get_ptr_2D(), &this->pvdpRx_reduced[inout->ispin]);
163148
this->cal_meshball_vlocal(na_grid, LD_pool, block_size.data(), block_index.data(),
164149
grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
165-
dpsir_ylm_y.get_ptr_2D(), &pvdpRy_thread);
150+
dpsir_ylm_y.get_ptr_2D(), &this->pvdpRy_reduced[inout->ispin]);
166151
this->cal_meshball_vlocal(na_grid, LD_pool, block_size.data(), block_index.data(),
167152
grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
168-
dpsir_ylm_z.get_ptr_2D(), &pvdpRz_thread);
169-
}
170-
#pragma omp critical(gint_k)
171-
{
172-
BlasConnector::axpy(nnrg,
173-
1.0,
174-
pvdpRx_thread.get_wrapper(),
175-
1,
176-
this->pvdpRx_reduced[inout->ispin].get_wrapper(),
177-
1);
178-
BlasConnector::axpy(nnrg,
179-
1.0,
180-
pvdpRy_thread.get_wrapper(),
181-
1,
182-
this->pvdpRy_reduced[inout->ispin].get_wrapper(),
183-
1);
184-
BlasConnector::axpy(nnrg,
185-
1.0,
186-
pvdpRz_thread.get_wrapper(),
187-
1,
188-
this->pvdpRz_reduced[inout->ispin].get_wrapper(),
189-
1);
153+
dpsir_ylm_z.get_ptr_2D(), &this->pvdpRz_reduced[inout->ispin]);
190154
}
191155
}
192156
ModuleBase::TITLE("Gint_interface", "cal_gint_dvlocal");
@@ -210,7 +174,6 @@ void Gint::gint_kernel_vlocal_meta(Gint_inout* inout) {
210174
{
211175
// define HContainer here to reference.
212176
//Under the condition of gamma_only, hRGint will be instantiated.
213-
hamilt::HContainer<double> hRGint_thread(*hRGint_kernel);
214177
std::vector<int> block_iw(max_size,0);
215178
std::vector<int> block_index(max_size+1,0);
216179
std::vector<int> block_size(max_size,0);
@@ -282,28 +245,18 @@ void Gint::gint_kernel_vlocal_meta(Gint_inout* inout) {
282245
//and accumulates to the corresponding element in Hamiltonian
283246
this->cal_meshball_vlocal(
284247
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
285-
psir_ylm.get_ptr_2D(), psir_vlbr3.get_ptr_2D(), &hRGint_thread);
248+
psir_ylm.get_ptr_2D(), psir_vlbr3.get_ptr_2D(), hRGint_kernel);
286249
//integrate (d/dx_i psi_mu*vk(r)*dv) * (d/dx_i psi_nu) on grid (x_i=x,y,z)
287250
//and accumulates to the corresponding element in Hamiltonian
288251
this->cal_meshball_vlocal(
289252
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
290-
dpsir_ylm_x.get_ptr_2D(), dpsix_vlbr3.get_ptr_2D(), &hRGint_thread);
253+
dpsir_ylm_x.get_ptr_2D(), dpsix_vlbr3.get_ptr_2D(), hRGint_kernel);
291254
this->cal_meshball_vlocal(
292255
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
293-
dpsir_ylm_y.get_ptr_2D(), dpsiy_vlbr3.get_ptr_2D(), &hRGint_thread);
256+
dpsir_ylm_y.get_ptr_2D(), dpsiy_vlbr3.get_ptr_2D(), hRGint_kernel);
294257
this->cal_meshball_vlocal(
295258
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
296-
dpsir_ylm_z.get_ptr_2D(), dpsiz_vlbr3.get_ptr_2D(), &hRGint_thread);
297-
}
298-
299-
#pragma omp critical
300-
{
301-
BlasConnector::axpy(nnrg,
302-
1.0,
303-
hRGint_thread.get_wrapper(),
304-
1,
305-
hRGint_kernel->get_wrapper(),
306-
1);
259+
dpsir_ylm_z.get_ptr_2D(), dpsiz_vlbr3.get_ptr_2D(), hRGint_kernel);
307260
}
308261
}
309262

source/module_hamilt_lcao/module_gint/temp_gint/gint_vl.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ void Gint_vl::cal_hr_gint_()
3333
PhiOperator phi_op;
3434
std::vector<double> phi;
3535
std::vector<double> phi_vldr3;
36-
HContainer<double> hr_gint_local(*hr_gint_);
3736
#pragma omp for schedule(dynamic)
3837
for(const auto& biggrid: gint_info_->get_biggrids())
3938
{
@@ -47,12 +46,7 @@ void Gint_vl::cal_hr_gint_()
4746
phi_vldr3.resize(phi_len);
4847
phi_op.set_phi(phi.data());
4948
phi_op.phi_mul_vldr3(vr_eff_, dr3_, phi.data(), phi_vldr3.data());
50-
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), &hr_gint_local);
51-
}
52-
#pragma omp critical
53-
{
54-
BlasConnector::axpy(hr_gint_local.get_nnr(), 1.0, hr_gint_local.get_wrapper(),
55-
1, hr_gint_->get_wrapper(), 1);
49+
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), *hr_gint_);
5650
}
5751
}
5852
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_vl_metagga.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ void Gint_vl_metagga::cal_hr_gint_()
3939
std::vector<double> dphi_x_vldr3;
4040
std::vector<double> dphi_y_vldr3;
4141
std::vector<double> dphi_z_vldr3;
42-
HContainer<double> hr_gint_local(*hr_gint_);
4342
#pragma omp for schedule(dynamic)
4443
for(const auto& biggrid: gint_info_->get_biggrids())
4544
{
@@ -62,15 +61,10 @@ void Gint_vl_metagga::cal_hr_gint_()
6261
phi_op.phi_mul_vldr3(vofk_, dr3_, dphi_x.data(), dphi_x_vldr3.data());
6362
phi_op.phi_mul_vldr3(vofk_, dr3_, dphi_y.data(), dphi_y_vldr3.data());
6463
phi_op.phi_mul_vldr3(vofk_, dr3_, dphi_z.data(), dphi_z_vldr3.data());
65-
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), &hr_gint_local);
66-
phi_op.phi_mul_phi_vldr3(dphi_x.data(), dphi_x_vldr3.data(), &hr_gint_local);
67-
phi_op.phi_mul_phi_vldr3(dphi_y.data(), dphi_y_vldr3.data(), &hr_gint_local);
68-
phi_op.phi_mul_phi_vldr3(dphi_z.data(), dphi_z_vldr3.data(), &hr_gint_local);
69-
}
70-
#pragma omp critical
71-
{
72-
BlasConnector::axpy(hr_gint_local.get_nnr(), 1.0, hr_gint_local.get_wrapper(),
73-
1, hr_gint_->get_wrapper(), 1);
64+
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), *hr_gint_);
65+
phi_op.phi_mul_phi_vldr3(dphi_x.data(), dphi_x_vldr3.data(), *hr_gint_);
66+
phi_op.phi_mul_phi_vldr3(dphi_y.data(), dphi_y_vldr3.data(), *hr_gint_);
67+
phi_op.phi_mul_phi_vldr3(dphi_z.data(), dphi_z_vldr3.data(), *hr_gint_);
7468
}
7569
}
7670
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_vl_metagga_nspin4.cpp

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ void Gint_vl_metagga_nspin4::cal_hr_gint_()
4141
std::vector<double> dphi_x_vldr3;
4242
std::vector<double> dphi_y_vldr3;
4343
std::vector<double> dphi_z_vldr3;
44-
std::vector<HContainer<double>> hr_gint_part_thread(nspin_, *hr_gint_part_[0]);
4544
#pragma omp for schedule(dynamic)
4645
for(const auto& biggrid: gint_info_->get_biggrids())
4746
{
@@ -66,20 +65,10 @@ void Gint_vl_metagga_nspin4::cal_hr_gint_()
6665
phi_op.phi_mul_vldr3(vofk_[is], dr3_, dphi_x.data(), dphi_x_vldr3.data());
6766
phi_op.phi_mul_vldr3(vofk_[is], dr3_, dphi_y.data(), dphi_y_vldr3.data());
6867
phi_op.phi_mul_vldr3(vofk_[is], dr3_, dphi_z.data(), dphi_z_vldr3.data());
69-
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), &hr_gint_part_thread[is]);
70-
phi_op.phi_mul_phi_vldr3(dphi_x.data(), dphi_x_vldr3.data(), &hr_gint_part_thread[is]);
71-
phi_op.phi_mul_phi_vldr3(dphi_y.data(), dphi_y_vldr3.data(), &hr_gint_part_thread[is]);
72-
phi_op.phi_mul_phi_vldr3(dphi_z.data(), dphi_z_vldr3.data(), &hr_gint_part_thread[is]);
73-
}
74-
}
75-
#pragma omp critical
76-
{
77-
for(int is = 0; is < nspin_; is++)
78-
{
79-
{
80-
BlasConnector::axpy(hr_gint_part_thread[is].get_nnr(), 1.0, hr_gint_part_thread[is].get_wrapper(),
81-
1, hr_gint_part_[is]->get_wrapper(), 1);
82-
}
68+
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), *hr_gint_part_[is]);
69+
phi_op.phi_mul_phi_vldr3(dphi_x.data(), dphi_x_vldr3.data(), *hr_gint_part_[is]);
70+
phi_op.phi_mul_phi_vldr3(dphi_y.data(), dphi_y_vldr3.data(), *hr_gint_part_[is]);
71+
phi_op.phi_mul_phi_vldr3(dphi_z.data(), dphi_z_vldr3.data(), *hr_gint_part_[is]);
8372
}
8473
}
8574
}

source/module_hamilt_lcao/module_gint/temp_gint/gint_vl_nspin4.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ void Gint_vl_nspin4::cal_hr_gint_()
3434
PhiOperator phi_op;
3535
std::vector<double> phi;
3636
std::vector<double> phi_vldr3;
37-
std::vector<HContainer<double>> hr_gint_part_thread(nspin_, *hr_gint_part_[0]);
3837
#pragma omp for schedule(dynamic)
3938
for(const auto& biggrid: gint_info_->get_biggrids())
4039
{
@@ -50,17 +49,7 @@ void Gint_vl_nspin4::cal_hr_gint_()
5049
for(int is = 0; is < nspin_; is++)
5150
{
5251
phi_op.phi_mul_vldr3(vr_eff_[is], dr3_, phi.data(), phi_vldr3.data());
53-
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), &hr_gint_part_thread[is]);
54-
}
55-
}
56-
#pragma omp critical
57-
{
58-
for(int is = 0; is < nspin_; is++)
59-
{
60-
{
61-
BlasConnector::axpy(hr_gint_part_thread[is].get_nnr(), 1.0, hr_gint_part_thread[is].get_wrapper(),
62-
1, hr_gint_part_[is]->get_wrapper(), 1);
63-
}
52+
phi_op.phi_mul_phi_vldr3(phi.data(), phi_vldr3.data(), *hr_gint_part_[is]);
6453
}
6554
}
6655
}

source/module_hamilt_lcao/module_gint/temp_gint/phi_operator.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,34 +148,41 @@ void PhiOperator::phi_mul_vldr3(const double* vl, const double dr3, const double
148148
}
149149
}
150150

151+
// this is a thread-safe function
151152
void PhiOperator::phi_mul_phi_vldr3(
152153
const double* phi,
153154
const double* phi_vldr3,
154-
HContainer<double>* hr) const
155+
HContainer<double>& hr) const
155156
{
156157
const char transa='N', transb='T';
157158
const double alpha=1, beta=1;
158159

160+
std::vector<double> tmp_hr;
159161
for(int i = 0; i < biggrid_->get_atoms_num(); ++i)
160162
{
161163
const auto atom_i = biggrid_->get_atom(i);
162164
const auto& r_i = atom_i->get_R();
163165
const int iat_i = atom_i->get_iat();
166+
const int m = atoms_phi_len_[i];
164167

165168
for(int j = 0; j < biggrid_->get_atoms_num(); ++j)
166169
{
167170
const auto atom_j = biggrid_->get_atom(j);
168171
const auto& r_j = atom_j->get_R();
169172
const int iat_j = atom_j->get_iat();
173+
const int n = atoms_phi_len_[j];
170174

171175
// only calculate the upper triangle matrix
172176
if(iat_i > iat_j)
173177
{
174178
continue;
175179
}
176180

181+
tmp_hr.resize(m * n);
182+
ModuleBase::GlobalFunc::ZEROS(tmp_hr.data(), m*n);
183+
177184
// FIXME may be r = r_j - r_i
178-
const auto result = hr->find_matrix(iat_i, iat_j, r_i-r_j);
185+
const auto result = hr.find_matrix(iat_i, iat_j, r_i-r_j);
179186

180187
if(result == nullptr)
181188
{
@@ -191,8 +198,10 @@ void PhiOperator::phi_mul_phi_vldr3(
191198
continue;
192199
}
193200

194-
dgemm_(&transa, &transb, &atoms_phi_len_[j], &atoms_phi_len_[i], &len, &alpha, &phi_vldr3[start_idx * cols_ + atoms_startidx_[j]],
195-
&cols_,&phi[start_idx * cols_ + atoms_startidx_[i]], &cols_, &beta, result->get_pointer(), &atoms_phi_len_[j]);
201+
dgemm_(&transa, &transb, &n, &m, &len, &alpha, &phi_vldr3[start_idx * cols_ + atoms_startidx_[j]],
202+
&cols_,&phi[start_idx * cols_ + atoms_startidx_[i]], &cols_, &beta, tmp_hr.data(), &n);
203+
204+
result->add_array_ts(tmp_hr.data());
196205
}
197206
}
198207
}

source/module_hamilt_lcao/module_gint/temp_gint/phi_operator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ class PhiOperator
5555
const double* phi,
5656
double* result) const;
5757

58+
// this is a thread-safe function
5859
void phi_mul_phi_vldr3(
5960
const double* phi,
6061
const double* phi_vldr3,
61-
HContainer<double>* hr) const;
62+
HContainer<double>& hr) const;
6263

6364
void phi_dot_phi_dm(
6465
const double* phi,

source/module_hamilt_lcao/module_hcontainer/base_matrix.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define BASE_MATRIX_H
33

44
#include <iostream>
5+
#include <mutex>
56

67
namespace hamilt
78
{
@@ -107,6 +108,15 @@ class BaseMatrix
107108
*/
108109
void set_size(const int& col_size_in, const int& row_size_in);
109110

111+
void add_array_ts(T* array)
112+
{
113+
std::lock_guard<std::mutex> lock(mtx);
114+
for (int i = 0; i < nrow_local * ncol_local; ++i)
115+
{
116+
value_begin[i] += array[i];
117+
}
118+
}
119+
110120
private:
111121
bool allocated = false;
112122

@@ -118,6 +128,8 @@ class BaseMatrix
118128

119129
// int current_multiple = 0;
120130

131+
// for thread safe
132+
mutable std::mutex mtx;
121133
// number of rows and columns
122134
int nrow_local = 0;
123135
int ncol_local = 0;

0 commit comments

Comments
 (0)