Skip to content

Commit 74b2954

Browse files
authored
fix: memory leak when precision=single (#5839)
* fix: memory leak when precision=single
* change op
* fix wrong logic of atomic+random
1 parent 48fbc90 commit 74b2954

File tree

5 files changed

+52
-14
lines changed

5 files changed

+52
-14
lines changed

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,18 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
106106
container::kernels::destroyGpuBlasHandle();
107107
container::kernels::destroyGpuSolverHandle();
108108
#endif
109-
delete reinterpret_cast<psi::Psi<T, Device>*>(this->kspw_psi);
110109
}
111110
#ifdef __DSP
112111
std::cout << " ** Closing DSP Hardware..." << std::endl;
113112
dspDestoryHandle(GlobalV::MY_RANK);
114113
#endif
114+
if(PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
115+
{
116+
delete this->kspw_psi;
117+
}
115118
if (PARAM.inp.precision == "single")
116119
{
117-
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
120+
delete this->__kspw_psi;
118121
}
119122

120123
delete this->psi;

source/module_hamilt_pw/hamilt_pwdft/VNL_in_pw.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,12 @@ void pseudopot_cell_vnl::getvnl(Device* ctx,
532532
delmem_var_op()(ctx, ylm);
533533
delmem_var_op()(ctx, vkb1);
534534
delmem_complex_op()(ctx, sk);
535-
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
535+
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
536536
{
537537
delmem_var_op()(ctx, gk);
538+
}
539+
if (PARAM.inp.device == "gpu")
540+
{
538541
delmem_int_op()(ctx, atom_nh);
539542
delmem_int_op()(ctx, atom_nb);
540543
delmem_int_op()(ctx, atom_na);

source/module_io/read_input_item_system.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,12 +775,28 @@ void ReadInput::item_system()
775775
para.input.device=base_device::information::get_device_flag(
776776
para.inp.device, para.inp.basis_type);
777777
};
778+
item.check_value = [](const Input_Item& item, const Parameter& para) {
779+
std::vector<std::string> avail_list = {"cpu", "gpu"};
780+
if (std::find(avail_list.begin(), avail_list.end(), para.input.device) == avail_list.end())
781+
{
782+
const std::string warningstr = nofound_str(avail_list, "device");
783+
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
784+
}
785+
};
778786
this->add_item(item);
779787
}
780788
{
781789
Input_Item item("precision");
782790
item.annotation = "the computing precision for ABACUS";
783791
read_sync_string(input.precision);
792+
item.check_value = [](const Input_Item& item, const Parameter& para) {
793+
std::vector<std::string> avail_list = {"single", "double"};
794+
if (std::find(avail_list.begin(), avail_list.end(), para.input.precision) == avail_list.end())
795+
{
796+
const std::string warningstr = nofound_str(avail_list, "precision");
797+
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
798+
}
799+
};
784800
this->add_item(item);
785801
}
786802
}

source/module_psi/psi_init.cpp

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void PSIInit<T, Device>::prepare_init(const int& random_seed)
5353
this->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_random<T>());
5454
}
5555
else if (this->init_wfc == "atomic"
56-
|| (this->init_wfc == "atomic+random" && this->ucell.natomwfc != PARAM.inp.nbands))
56+
|| (this->init_wfc == "atomic+random" && this->ucell.natomwfc < PARAM.inp.nbands))
5757
{
5858
this->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_atomic<T>());
5959
}
@@ -99,17 +99,30 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
9999
const int nbands_start = this->psi_initer->nbands_start();
100100
const int nbands = psi->get_nbands();
101101
const int nbasis = psi->get_nbasis();
102-
const bool another_psi_space = (nbands_start != nbands || PARAM.inp.precision == "single");
102+
const bool not_equal = (nbands_start != nbands);
103103

104104
Psi<T>* psi_cpu = reinterpret_cast<psi::Psi<T>*>(psi);
105105
Psi<T, Device>* psi_device = kspw_psi;
106106

107-
if (another_psi_space)
107+
if (not_equal)
108108
{
109109
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nullptr);
110110
psi_device = PARAM.inp.device == "gpu" ? new psi::Psi<T, Device>(psi_cpu[0])
111111
: reinterpret_cast<psi::Psi<T, Device>*>(psi_cpu);
112112
}
113+
else if (PARAM.inp.precision == "single")
114+
{
115+
if (PARAM.inp.device == "cpu")
116+
{
117+
psi_cpu = reinterpret_cast<psi::Psi<T>*>(kspw_psi);
118+
psi_device = kspw_psi;
119+
}
120+
else
121+
{
122+
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nullptr);
123+
psi_device = kspw_psi;
124+
}
125+
}
113126

114127
// loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
115128
// like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
@@ -126,16 +139,16 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
126139
this->psi_initer->init_psig(psi_cpu->get_pointer(), ik);
127140
if (psi_device->get_pointer() != psi_cpu->get_pointer())
128141
{
129-
castmem_h2d_op()(ctx, cpu_ctx, psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis);
142+
syncmem_h2d_op()(ctx, cpu_ctx, psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis);
130143
}
131144

132145
std::vector<typename GetTypeReal<T>::type> etatom(nbands_start, 0.0);
133146

134147
if (this->ks_solver == "cg")
135148
{
136-
if (another_psi_space)
149+
if (not_equal)
137150
{
138-
// for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() should be different
151+
// for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be different
139152
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
140153
psi_device->get_pointer(),
141154
nbands_start,
@@ -145,7 +158,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
145158
}
146159
else
147160
{
148-
// for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() can be the same
161+
// for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same
149162
hsolver::DiagoIterAssist<T, Device>::diagH_subspace(p_hamilt,
150163
*psi_device,
151164
*kspw_psi,
@@ -155,21 +168,25 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
155168
}
156169
else // dav, bpcg
157170
{
158-
if (another_psi_space)
171+
if (psi_device->get_pointer() != kspw_psi->get_pointer())
159172
{
160173
syncmem_complex_op()(ctx, ctx, kspw_psi->get_pointer(), psi_device->get_pointer(), nbands * nbasis);
161174
}
162175
}
163176
} // end k-point loop
164177

165-
if (another_psi_space)
178+
if (not_equal)
166179
{
167180
delete psi_cpu;
168181
if(PARAM.inp.device == "gpu")
169182
{
170183
delete psi_device;
171184
}
172185
}
186+
else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu")
187+
{
188+
delete psi_cpu;
189+
}
173190

174191
ModuleBase::timer::tick("PSIInit", "initialize_psi");
175192
}

source/module_psi/psi_init.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ class PSIInit
8282

8383
//-------------------------OP--------------------------------------------
8484
using syncmem_complex_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
85-
using castmem_h2d_op
86-
= base_device::memory::cast_memory_op<T, T, Device, base_device::DEVICE_CPU>;
85+
using syncmem_h2d_op = base_device::memory::synchronize_memory_op<T, Device, base_device::DEVICE_CPU>;
8786
};
8887

8988
///@brief allocate the wavefunction

0 commit comments

Comments
 (0)