Skip to content

Commit 5095627

Browse files
committed
optimize cal_DMR
1 parent d240434 commit 5095627

File tree

1 file changed

+128
-119
lines changed

1 file changed

+128
-119
lines changed

source/module_elecstate/module_dm/density_matrix.cpp

Lines changed: 128 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR(const int ik_in)
6464

6565
ModuleBase::timer::tick("DensityMatrix", "cal_DMR");
6666
int ld_hk = this->_paraV->nrow;
67-
int ld_hk2 = 2 * ld_hk;
6867
for (int is = 1; is <= this->_nspin; ++is)
6968
{
7069
int ik_begin = this->_nk * (is - 1); // jump this->_nk for spin_down if nspin==2
@@ -84,6 +83,7 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR(const int ik_in)
8483
int col_ap = this->_paraV->atom_begin_col[iat2];
8584
const int row_size = this->_paraV->get_row_size(iat1);
8685
const int col_size = this->_paraV->get_col_size(iat2);
86+
const int mat_size = row_size * col_size;
8787
const int r_size = target_ap.get_R_size();
8888
if (row_ap == -1 || col_ap == -1)
8989
{
@@ -92,9 +92,13 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR(const int ik_in)
9292
std::vector<std::complex<double>> tmp_DMR;
9393
if (PARAM.inp.nspin == 4)
9494
{
95-
tmp_DMR.resize(target_ap.get_size());
95+
tmp_DMR.resize(mat_size * r_size, 0);
9696
}
97-
for (int ir = 0; ir < r_size; ++ir)
97+
98+
// calculate kphase and target_mat_ptr
99+
std::vector<std::complex<double>> kphase_vec(r_size * this->_nk);
100+
std::vector<double*> target_DMR_mat_vec(r_size);
101+
for(int ir = 0; ir < r_size; ++ir)
98102
{
99103
const ModuleBase::Vector3<int> r_index = target_ap.get_R_index(ir);
100104
hamilt::BaseMatrix<double>* target_mat = target_ap.find_matrix(r_index);
@@ -105,118 +109,112 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR(const int ik_in)
105109
continue;
106110
}
107111
#endif
108-
// loop over k-points
109-
if (PARAM.inp.nspin != 4)
112+
target_DMR_mat_vec[ir] = target_mat->get_pointer();
113+
for(int ik = 0; ik < this->_nk; ++ik)
114+
{
115+
if(ik_in >= 0 && ik_in != ik)
116+
{
117+
continue;
118+
}
119+
// cal k_phase
120+
// if TK==std::complex<double>, kphase is e^{ikR}
121+
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
122+
const double arg = (this->_kvec_d[ik] * dR) * ModuleBase::TWO_PI;
123+
double sinp, cosp;
124+
ModuleBase::libm::sincos(arg, &sinp, &cosp);
125+
kphase_vec[ik * r_size + ir] = std::complex<double>(cosp, sinp);
126+
}
127+
}
128+
129+
std::vector<std::complex<double>> tmp_DMK_mat(mat_size);
130+
// step_trace = 0 for NSPIN=1,2; ={0, 1, local_col, local_col+1} for NSPIN=4
131+
// step_trace is used when nspin = 4;
132+
int step_trace[4]{};
133+
if(PARAM.inp.nspin == 4)
134+
{
135+
const int npol = 2;
136+
for (int is = 0; is < npol; is++)
110137
{
111-
for (int ik = 0; ik < this->_nk; ++ik)
138+
for (int is2 = 0; is2 < npol; is2++)
112139
{
113-
if(ik_in >= 0 && ik_in != ik) { continue;
114-
}
115-
// cal k_phase
116-
// if TK==std::complex<double>, kphase is e^{ikR}
117-
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
118-
const double arg = (this->_kvec_d[ik] * dR) * ModuleBase::TWO_PI;
119-
double sinp, cosp;
120-
ModuleBase::libm::sincos(arg, &sinp, &cosp);
121-
std::complex<double> kphase = std::complex<double>(cosp, sinp);
122-
// set DMR element
123-
double* target_DMR_ptr = target_mat->get_pointer();
124-
std::complex<double>* DMK_ptr = this->_DMK[ik + ik_begin].data();
125-
double* DMK_real_ptr = nullptr;
126-
double* DMK_imag_ptr = nullptr;
127-
// jump DMK to fill DMR
128-
// DMR is row-major, DMK is column-major
129-
DMK_ptr += col_ap * this->_paraV->nrow + row_ap;
130-
for (int mu = 0; mu < row_size; ++mu)
131-
{
132-
DMK_real_ptr = (double*)DMK_ptr;
133-
DMK_imag_ptr = DMK_real_ptr + 1;
134-
BlasConnector::axpy(col_size,
135-
kphase.real(),
136-
DMK_real_ptr,
137-
ld_hk2,
138-
target_DMR_ptr,
139-
1);
140-
// "-" since i^2 = -1
141-
BlasConnector::axpy(col_size,
142-
-kphase.imag(),
143-
DMK_imag_ptr,
144-
ld_hk2,
145-
target_DMR_ptr,
146-
1);
147-
DMK_ptr += 1;
148-
target_DMR_ptr += col_size;
149-
}
140+
step_trace[is * npol + is2] = target_ap.get_col_size() * is + is2;
150141
}
151142
}
143+
}
144+
for(int ik = 0; ik < this->_nk; ++ik)
145+
{
146+
if(ik_in >= 0 && ik_in != ik)
147+
{
148+
continue;
149+
}
152150

153-
// treat DMR as pauli matrix when NSPIN=4
154-
if (PARAM.inp.nspin == 4)
151+
// copy column-major DMK to row-major tmp_DMK_mat (for the purpose of computational efficiency)
152+
const std::complex<double>* DMK_mat_ptr = this->_DMK[ik + ik_begin].data() + col_ap * this->_paraV->nrow + row_ap;
153+
for(int icol = 0; icol < col_size; ++icol)
155154
{
156-
tmp_DMR.assign(target_ap.get_size(), std::complex<double>(0.0, 0.0));
157-
for (int ik = 0; ik < this->_nk; ++ik)
155+
for(int irow = 0; irow < row_size; ++irow)
158156
{
159-
if(ik_in >= 0 && ik_in != ik) { continue;
160-
}
161-
// cal k_phase
162-
// if TK==std::complex<double>, kphase is e^{ikR}
163-
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
164-
const double arg = (this->_kvec_d[ik] * dR) * ModuleBase::TWO_PI;
165-
double sinp, cosp;
166-
ModuleBase::libm::sincos(arg, &sinp, &cosp);
167-
std::complex<double> kphase = std::complex<double>(cosp, sinp);
168-
// set DMR element
169-
std::complex<double>* tmp_DMR_ptr = tmp_DMR.data();
170-
std::complex<double>* DMK_ptr = this->_DMK[ik + ik_begin].data();
171-
double* DMK_real_ptr = nullptr;
172-
double* DMK_imag_ptr = nullptr;
173-
// jump DMK to fill DMR
174-
// DMR is row-major, DMK is column-major
175-
DMK_ptr += col_ap * this->_paraV->nrow + row_ap;
176-
for (int mu = 0; mu < target_ap.get_row_size(); ++mu)
177-
{
178-
BlasConnector::axpy(target_ap.get_col_size(),
179-
kphase,
180-
DMK_ptr,
181-
ld_hk,
182-
tmp_DMR_ptr,
183-
1);
184-
DMK_ptr += 1;
185-
tmp_DMR_ptr += target_ap.get_col_size();
186-
}
157+
tmp_DMK_mat[irow * col_size + icol] = DMK_mat_ptr[icol * ld_hk + irow];
187158
}
188-
int npol = 2;
189-
// step_trace = 0 for NSPIN=1,2; ={0, 1, local_col, local_col+1} for NSPIN=4
190-
int step_trace[4];
191-
for (int is = 0; is < npol; is++)
159+
}
160+
161+
// if nspin != 4, fill DMR
162+
// if nspin == 4, fill tmp_DMR
163+
for(int ir = 0; ir < r_size; ++ir)
164+
{
165+
std::complex<double> kphase = kphase_vec[ik * r_size + ir];
166+
if(PARAM.inp.nspin != 4)
192167
{
193-
for (int is2 = 0; is2 < npol; is2++)
168+
double* target_DMR_mat = target_DMR_mat_vec[ir];
169+
for(int irow = 0; irow < row_size; ++irow)
194170
{
195-
step_trace[is * npol + is2] = target_ap.get_col_size() * is + is2;
171+
for(int icol = 0; icol < col_size; ++icol)
172+
{
173+
target_DMR_mat[irow * col_size + icol] += kphase.real() * tmp_DMK_mat[irow * col_size + icol].real()
174+
- kphase.imag() * tmp_DMK_mat[irow * col_size + icol].imag();
175+
}
196176
}
177+
} else if(PARAM.inp.nspin == 4)
178+
{
179+
std::complex<double>* tmp_DMR_mat = &tmp_DMR[ir * mat_size];
180+
BlasConnector::axpy(mat_size,
181+
kphase,
182+
tmp_DMK_mat.data(),
183+
1,
184+
tmp_DMR_mat,
185+
1);
197186
}
198-
std::complex<double> tmp[4];
199-
double* target_DMR = target_mat->get_pointer();
200-
std::complex<double>* tmp_DMR_ptr = tmp_DMR.data();
201-
for (int irow = 0; irow < target_ap.get_row_size(); irow += 2)
187+
}
188+
}
189+
190+
// if nspin == 4
191+
// copy tmp_DMR to fill target_DMR
192+
if(PARAM.inp.nspin == 4)
193+
{
194+
std::complex<double> tmp[4]{};
195+
for(int ir = 0; ir < r_size; ++ir)
196+
{
197+
std::complex<double>* tmp_DMR_mat = &tmp_DMR[ir * mat_size];
198+
double* target_DMR_mat = target_DMR_mat_vec[ir];
199+
for (int irow = 0; irow < row_size; irow += 2)
202200
{
203-
for (int icol = 0; icol < target_ap.get_col_size(); icol += 2)
201+
for (int icol = 0; icol < col_size; icol += 2)
204202
{
205203
// catch the 4 spin component value of one orbital pair
206-
tmp[0] = tmp_DMR_ptr[icol + step_trace[0]];
207-
tmp[1] = tmp_DMR_ptr[icol + step_trace[1]];
208-
tmp[2] = tmp_DMR_ptr[icol + step_trace[2]];
209-
tmp[3] = tmp_DMR_ptr[icol + step_trace[3]];
204+
tmp[0] = tmp_DMR_mat[icol + step_trace[0]];
205+
tmp[1] = tmp_DMR_mat[icol + step_trace[1]];
206+
tmp[2] = tmp_DMR_mat[icol + step_trace[2]];
207+
tmp[3] = tmp_DMR_mat[icol + step_trace[3]];
210208
// transfer to Pauli matrix and save the real part
211209
// save them back to the target_mat
212-
target_DMR[icol + step_trace[0]] = tmp[0].real() + tmp[3].real();
213-
target_DMR[icol + step_trace[1]] = tmp[1].real() + tmp[2].real();
214-
target_DMR[icol + step_trace[2]]
210+
target_DMR_mat[icol + step_trace[0]] = tmp[0].real() + tmp[3].real();
211+
target_DMR_mat[icol + step_trace[1]] = tmp[1].real() + tmp[2].real();
212+
target_DMR_mat[icol + step_trace[2]]
215213
= -tmp[1].imag() + tmp[2].imag(); // (i * (rho_updown - rho_downup)).real()
216-
target_DMR[icol + step_trace[3]] = tmp[0].real() - tmp[3].real();
214+
target_DMR_mat[icol + step_trace[3]] = tmp[0].real() - tmp[3].real();
217215
}
218-
tmp_DMR_ptr += target_ap.get_col_size() * 2;
219-
target_DMR += target_ap.get_col_size() * 2;
216+
tmp_DMR_mat += col_size * 2;
217+
target_DMR_mat += col_size * 2;
220218
}
221219
}
222220
}
@@ -252,49 +250,60 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR_full(hamilt::HContaine
252250
int col_ap = this->_paraV->atom_begin_col[iat2];
253251
const int row_size = this->_paraV->get_row_size(iat1);
254252
const int col_size = this->_paraV->get_col_size(iat2);
253+
const int mat_size = row_size * col_size;
255254
const int r_size = target_ap.get_R_size();
256-
for (int ir = 0; ir < r_size; ++ir)
255+
256+
// calculate kphase and target_mat_ptr
257+
std::vector<std::complex<double>> kphase_vec(r_size * this->_nk);
258+
std::vector<std::complex<double>*> target_DMR_mat_vec(r_size);
259+
for(int ir = 0; ir < r_size; ++ir)
257260
{
258261
const ModuleBase::Vector3<int> r_index = target_ap.get_R_index(ir);
259-
auto* target_mat = target_ap.find_matrix(r_index);
262+
hamilt::BaseMatrix<std::complex<double>>* target_mat = target_ap.find_matrix(r_index);
260263
#ifdef __DEBUG
261264
if (target_mat == nullptr)
262265
{
263266
std::cout << "target_mat is nullptr" << std::endl;
264267
continue;
265268
}
266269
#endif
267-
// loop over k-points
268-
// calculate full matrix for complex density matrix
269-
for (int ik = 0; ik < this->_nk; ++ik)
270+
target_DMR_mat_vec[ir] = target_mat->get_pointer();
271+
for(int ik = 0; ik < this->_nk; ++ik)
270272
{
271273
// cal k_phase
272274
// if TK==std::complex<double>, kphase is e^{ikR}
273275
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
274276
const double arg = (this->_kvec_d[ik] * dR) * ModuleBase::TWO_PI;
275277
double sinp, cosp;
276278
ModuleBase::libm::sincos(arg, &sinp, &cosp);
277-
std::complex<double> kphase = std::complex<double>(cosp, sinp);
278-
// set DMR element
279-
std::complex<double>* target_DMR_ptr = target_mat->get_pointer();
280-
const std::complex<double>* DMK_ptr = this->_DMK[ik].data();
281-
double* DMK_real_ptr = nullptr;
282-
double* DMK_imag_ptr = nullptr;
283-
// jump DMK to fill DMR
284-
// DMR is row-major, DMK is column-major
285-
DMK_ptr += col_ap * this->_paraV->nrow + row_ap;
286-
for (int mu = 0; mu < row_size; ++mu)
279+
kphase_vec[ik * r_size + ir] = std::complex<double>(cosp, sinp);
280+
}
281+
}
282+
283+
std::vector<std::complex<double>> tmp_DMK_mat(mat_size);
284+
for(int ik = 0; ik < this->_nk; ++ik)
285+
{
286+
// copy column-major DMK to row-major tmp_DMK_mat (for the purpose of computational efficiency)
287+
const std::complex<double>* DMK_mat_ptr = this->_DMK[ik].data() + col_ap * this->_paraV->nrow + row_ap;
288+
for(int icol = 0; icol < col_size; ++icol)
289+
{
290+
for(int irow = 0; irow < row_size; ++irow)
287291
{
288-
BlasConnector::axpy(col_size,
289-
kphase,
290-
DMK_ptr,
291-
ld_hk,
292-
target_DMR_ptr,
293-
1);
294-
DMK_ptr += 1;
295-
target_DMR_ptr += col_size;
292+
tmp_DMK_mat[irow * col_size + icol] = DMK_mat_ptr[icol * ld_hk + irow];
296293
}
297294
}
295+
296+
for(int ir = 0; ir < r_size; ++ir)
297+
{
298+
std::complex<double> kphase = kphase_vec[ik * r_size + ir];
299+
std::complex<double>* target_DMR_mat = target_DMR_mat_vec[ir];
300+
BlasConnector::axpy(mat_size,
301+
kphase,
302+
tmp_DMK_mat.data(),
303+
1,
304+
target_DMR_mat,
305+
1);
306+
}
298307
}
299308
}
300309
ModuleBase::timer::tick("DensityMatrix", "cal_DMR_full");

0 commit comments

Comments
 (0)