Skip to content

Commit 7ab1a78

Browse files
committed
replace pvpr with hcontainer in gpu code of gint_vl
1 parent 368c0fe commit 7ab1a78

File tree

7 files changed

+130
-242
lines changed

7 files changed

+130
-242
lines changed

source/module_hamilt_lcao/module_gint/gint.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@ Gint::~Gint() {
2424

2525
delete this->hRGint;
2626
delete this->hRGintCd;
27+
// in gamma_only case, DMRGint.size()=0,
28+
// in multi-k case, DMRGint.size()=nspin
2729
for (int is = 0; is < this->DMRGint.size(); is++) {
2830
delete this->DMRGint[is];
2931
}
32+
for(int is = 0; is < this->hRGint_tmp.size(); is++) {
33+
delete this->hRGint_tmp[is];
34+
}
3035
#ifdef __MPI
3136
delete this->DMRGint_full;
3237
#endif
@@ -141,6 +146,7 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, Grid_Driver* gd, const int&
141146
if (this->DMRGint.size() == 0) {
142147
this->DMRGint.resize(nspin);
143148
}
149+
hRGint_tmp.resize(nspin);
144150
if (nspin != 4) {
145151
if (this->hRGint != nullptr) {
146152
delete this->hRGint;
@@ -161,9 +167,8 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, Grid_Driver* gd, const int&
161167
delete this->hRGint_tmp[is];
162168
}
163169
this->DMRGint[is] = new hamilt::HContainer<double>(ucell_in.nat);
170+
this->hRGint_tmp[is] = new hamilt::HContainer<double>(ucell_in.nat);
164171
}
165-
this->hRGint_tmp[0]
166-
= new hamilt::HContainer<double>(ucell_in.nat);
167172
#ifdef __MPI
168173
if (this->DMRGint_full != nullptr) {
169174
delete this->DMRGint_full;
@@ -177,7 +182,7 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, Grid_Driver* gd, const int&
177182
}
178183
if (npol == 1) {
179184
this->hRGint->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
180-
this->hRGint->allocate(nullptr, false);
185+
this->hRGint->allocate(nullptr, true);
181186
ModuleBase::Memory::record("Gint::hRGint",
182187
this->hRGint->get_memory_size());
183188
// initialize DMRGint with hRGint when NSPIN != 4
@@ -192,11 +197,12 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, Grid_Driver* gd, const int&
192197
* this->DMRGint.size());
193198
} else {
194199
this->hRGintCd->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
200+
this->hRGintCd->allocate(nullptr, true);
195201
for(int is = 0; is < nspin; is++) {
196202
this->hRGint_tmp[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
197203
this->DMRGint[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
198-
this->hRGint_tmp[is]->allocate(nullptr, false);
199-
this->DMRGint[is]->allocate(nullptr, false);
204+
this->hRGint_tmp[is]->allocate(nullptr, true);
205+
this->DMRGint[is]->allocate(nullptr, true);
200206
}
201207
ModuleBase::Memory::record("Gint::hRGint_tmp",
202208
this->hRGint_tmp[0]->get_memory_size()*nspin);
@@ -205,7 +211,7 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, Grid_Driver* gd, const int&
205211
* this->DMRGint.size()*nspin);
206212
#ifdef __MPI
207213
this->DMRGint_full->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
208-
this->DMRGint_full->allocate(nullptr, false);
214+
this->DMRGint_full->allocate(nullptr, true);
209215
ModuleBase::Memory::record("Gint::DMRGint_full",
210216
this->DMRGint_full->get_memory_size());
211217
#endif

source/module_hamilt_lcao/module_gint/gint_gpu_interface.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,14 @@ void Gint::gpu_vlocal_interface(Gint_inout* inout) {
1818
ylmcoef[i] = ModuleBase::Ylm::ylmcoef[i];
1919
}
2020

21-
double* pvpR = PARAM.globalv.gamma_only_local ? nullptr : this->pvpR_reduced[inout->ispin];
22-
GintKernel::gint_vl_gpu(this->hRGint,
21+
hamilt::HContainer<double>* hRGint_kernel = PARAM.inp.nspin != 4 ? this->hRGint : this->hRGint_tmp[inout->ispin];
22+
GintKernel::gint_vl_gpu(hRGint_kernel,
2323
inout->vl,
2424
ylmcoef,
2525
dr,
2626
this->gridt->rcuts.data(),
2727
*this->gridt,
28-
ucell,
29-
pvpR,
30-
PARAM.globalv.gamma_only_local);
28+
ucell);
3129

3230
ModuleBase::TITLE("Gint_interface", "cal_gint_vlocal");
3331
ModuleBase::timer::tick("Gint_interface", "cal_gint_vlocal");

source/module_hamilt_lcao/module_gint/gint_k_pvpr.cpp

Lines changed: 56 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ void Gint_k::transfer_pvpR(hamilt::HContainer<double>* hR, const UnitCell* ucell
8181
for (int ir = 0; ir < ap.get_R_size(); ir++)
8282
{
8383
auto R_index = ap.get_R_index(ir);
84-
auto upper_matrix = upper_ap->find_matrix(-R_index);
85-
auto lower_matrix = lower_ap->find_matrix(R_index);
86-
for (int irow = 0; irow < upper_matrix->get_row_size(); ++irow)
84+
auto upper_mat = upper_ap->find_matrix(-R_index);
85+
auto lower_mat = lower_ap->find_matrix(R_index);
86+
for (int irow = 0; irow < upper_mat->get_row_size(); ++irow)
8787
{
88-
for (int icol = 0; icol < upper_matrix->get_col_size(); ++icol)
88+
for (int icol = 0; icol < upper_mat->get_col_size(); ++icol)
8989
{
90-
lower_matrix->get_value(icol, irow) = upper_ap->get_value(irow, icol);
90+
lower_mat->get_value(icol, irow) = upper_ap->get_value(irow, icol);
9191
}
9292
}
9393
}
@@ -116,141 +116,68 @@ void Gint_k::transfer_pvpR(hamilt::HContainer<std::complex<double>>* hR, const U
116116
{
117117
ModuleBase::TITLE("Gint_k", "transfer_pvpR");
118118
ModuleBase::timer::tick("Gint_k", "transfer_pvpR");
119-
if (!pvpR_alloc_flag || this->hRGintCd == nullptr)
120-
{
121-
ModuleBase::WARNING_QUIT("Gint_k::destroy_pvpR", "pvpR hasnot been allocated yet!");
122-
}
123-
this->hRGintCd->set_zero();
124-
125-
const int npol = PARAM.globalv.npol;
126-
const UnitCell& ucell = *ucell_in;
127119

128-
for (int iat = 0; iat < ucell.nat; ++iat)
120+
this->hRGintCd->set_zero();
121+
122+
for (int iap = 0; iap < this->hRGintCd->size_atom_pairs(); iap++)
129123
{
130-
const int T1 = ucell.iat2it[iat];
131-
const int I1 = ucell.iat2ia[iat];
124+
auto* ap = &this->hRGintCd->get_atom_pair(iap);
125+
const int iat1 = ap->get_atom_i();
126+
const int iat2 = ap->get_atom_j();
127+
if (iat1 <= iat2)
132128
{
133-
// atom in this grid piece.
134-
if (this->gridt->in_this_processor[iat])
135-
{
136-
Atom* atom1 = &ucell.atoms[T1];
137-
138-
// get the start positions of elements.
139-
const int DM_start = this->gridt->nlocstartg[iat];
140-
141-
// get the coordinates of adjacent atoms.
142-
auto& tau1 = ucell.atoms[T1].tau[I1];
143-
// gd.Find_atom(tau1);
144-
AdjacentAtomInfo adjs;
145-
gd->Find_atom(ucell, tau1, T1, I1, &adjs);
146-
// search for the adjacent atoms.
147-
int nad = 0;
129+
hamilt::AtomPair<std::complex<double>>* upper_ap = ap;
130+
hamilt::AtomPair<std::complex<double>>* lower_ap = this->hRGintCd->find_pair(iat2, iat1);
131+
const hamilt::AtomPair<double>* ap_nspin_0 = this->hRGint_tmp[0]->find_pair(iat1, iat2);
132+
const hamilt::AtomPair<double>* ap_nspin_3 = this->hRGint_tmp[3]->find_pair(iat1, iat2);
133+
for (int ir = 0; ir < upper_ap->get_R_size(); ir++)
134+
{
135+
const auto R_index = upper_ap->get_R_index(ir);
136+
auto upper_mat = upper_ap->find_matrix(R_index);
137+
auto mat_nspin_0 = ap_nspin_0->find_matrix(R_index);
138+
auto mat_nspin_3 = ap_nspin_3->find_matrix(R_index);
148139

149-
for (int ad = 0; ad < adjs.adj_num + 1; ad++)
140+
// The row and column sizes of upper_mat are double those of mat_nspin_0
141+
for (int irow = 0; irow < mat_nspin_0->get_row_size(); ++irow)
150142
{
151-
// get iat2
152-
const int T2 = adjs.ntype[ad];
153-
const int I2 = adjs.natom[ad];
154-
const int iat2 = ucell.itia2iat(T2, I2);
155-
156-
// adjacent atom is also on the grid.
157-
if (this->gridt->in_this_processor[iat2])
143+
for (int icol = 0; icol < mat_nspin_0->get_col_size(); ++icol)
158144
{
159-
Atom* atom2 = &ucell.atoms[T2];
160-
auto dtau = adjs.adjacent_tau[ad] - tau1;
161-
double distance = dtau.norm() * ucell.lat0;
162-
double rcut = this->gridt->rcuts[T1] + this->gridt->rcuts[T2];
145+
upper_mat->get_value(2*irow, 2*icol) = mat_nspin_0->get_value(irow, icol) + mat_nspin_3->get_value(irow, icol);
146+
upper_mat->get_value(2*irow+1, 2*icol+1) = mat_nspin_0->get_value(irow, icol) - mat_nspin_3->get_value(irow, icol);
147+
}
148+
}
163149

164-
if (distance < rcut)
150+
if (PARAM.globalv.domag)
151+
{
152+
const hamilt::AtomPair<double>* ap_nspin_1 = this->hRGint_tmp[1]->find_pair(iat1, iat2);
153+
const hamilt::AtomPair<double>* ap_nspin_2 = this->hRGint_tmp[2]->find_pair(iat1, iat2);
154+
const auto mat_nspin_1 = ap_nspin_1->find_matrix(R_index);
155+
const auto mat_nspin_2 = ap_nspin_2->find_matrix(R_index);
156+
for (int irow = 0; irow < mat_nspin_1->get_row_size(); ++irow)
157+
{
158+
for (int icol = 0; icol < mat_nspin_1->get_col_size(); ++icol)
165159
{
166-
if (iat > iat2)
167-
{ // skip the lower triangle.
168-
nad++;
169-
continue;
170-
}
171-
// calculate the distance between iat1 and iat2.
172-
// ModuleBase::Vector3<double> dR = gd.getAdjacentTau(ad) - tau1;
173-
auto& dR = adjs.box[ad];
174-
// dR.x = adjs.box[ad].x;
175-
// dR.y = adjs.box[ad].y;
176-
// dR.z = adjs.box[ad].z;
177-
178-
int ixxx = DM_start + this->gridt->find_R2st[iat][nad];
160+
upper_mat->get_value(2*irow, 2*icol+1) = mat_nspin_1->get_value(irow, icol) + std::complex<double>(0.0, 1.0) * mat_nspin_2->get_value(irow, icol);
161+
upper_mat->get_value(2*irow+1, 2*icol) = mat_nspin_1->get_value(irow, icol) - std::complex<double>(0.0, 1.0) * mat_nspin_2->get_value(irow, icol); // conjugate counterpart of the (2*irow, 2*icol+1) element: spin-1 minus i * spin-2 (the imaginary part must come from mat_nspin_2, not mat_nspin_1)
162+
}
163+
}
164+
}
179165

180-
hamilt::BaseMatrix<std::complex<double>>* tmp_matrix
181-
= this->hRGintCd->find_matrix(iat, iat2, dR);
182-
#ifdef __DEBUG
183-
assert(tmp_matrix != nullptr);
184-
#endif
185-
std::complex<double>* tmp_pointer = tmp_matrix->get_pointer();
186-
std::vector<int> step_trace(4, 0);
187-
for (int is = 0; is < 2; is++)
188-
{
189-
for (int is2 = 0; is2 < 2; is2++)
190-
{
191-
step_trace[is * 2 + is2] = atom2->nw * 2 * is + is2;
192-
}
193-
}
194-
const double* vijR[4];
195-
for (int spin = 0; spin < 4; spin++)
196-
{
197-
vijR[spin] = &pvpR_reduced[spin][ixxx];
198-
}
199-
for (int iw = 0; iw < atom1->nw; iw++)
200-
{
201-
for (int iw2 = 0; iw2 < atom2->nw; ++iw2)
202-
{
203-
tmp_pointer[step_trace[0]] = *vijR[0] + *vijR[3];
204-
tmp_pointer[step_trace[3]] = *vijR[0] - *vijR[3];
205-
tmp_pointer += 2;
206-
vijR[0]++;
207-
vijR[3]++;
208-
}
209-
tmp_pointer += 2 * atom2->nw;
210-
}
211-
if (PARAM.globalv.domag)
212-
{
213-
tmp_pointer = tmp_matrix->get_pointer();
214-
for (int iw = 0; iw < atom1->nw; iw++)
215-
{
216-
for (int iw2 = 0; iw2 < atom2->nw; ++iw2)
217-
{
218-
tmp_pointer[step_trace[1]]
219-
= *vijR[1] + std::complex<double>(0.0, 1.0) * *vijR[2];
220-
tmp_pointer[step_trace[2]]
221-
= *vijR[1] - std::complex<double>(0.0, 1.0) * *vijR[2];
222-
tmp_pointer += 2;
223-
vijR[1]++;
224-
vijR[2]++;
225-
}
226-
tmp_pointer += 2 * atom2->nw;
227-
}
228-
}
229-
// save the lower triangle.
230-
if (iat < iat2)
231-
{
232-
hamilt::BaseMatrix<std::complex<double>>* conj_matrix
233-
= this->hRGintCd->find_matrix(iat2, iat, -dR);
234-
#ifdef __DEBUG
235-
assert(conj_matrix != nullptr);
236-
#endif
237-
tmp_pointer = tmp_matrix->get_pointer();
238-
for (int iw1 = 0; iw1 < atom1->nw * 2; ++iw1)
239-
{
240-
for (int iw2 = 0; iw2 < atom2->nw * 2; ++iw2)
241-
{
242-
conj_matrix->get_value(iw2, iw1) = conj(*tmp_pointer);
243-
tmp_pointer++;
244-
}
245-
}
246-
}
247-
++nad;
248-
} // end distane<rcut
166+
// fill the lower triangle matrix
167+
if (iat1 < iat2)
168+
{
169+
auto lower_mat = lower_ap->find_matrix(-R_index);
170+
for (int irow = 0; irow < upper_mat->get_row_size(); ++irow)
171+
{
172+
for (int icol = 0; icol < upper_mat->get_col_size(); ++icol)
173+
{
174+
lower_mat->get_value(icol, irow) = conj(upper_mat->get_value(irow, icol));
175+
}
249176
}
250-
} // end ad
177+
}
251178
}
252-
} // end ia
253-
} // end it
179+
}
180+
}
254181

255182
// ===================================
256183
// transfer HR from Gint to Veff<OperatorLCAO<std::complex<double>, std::complex<double>>>

0 commit comments

Comments
 (0)