@@ -99,17 +99,30 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
9999 const int nbands_start = this ->psi_initer ->nbands_start ();
100100 const int nbands = psi->get_nbands ();
101101 const int nbasis = psi->get_nbasis ();
102- const bool another_psi_space = (nbands_start != nbands || PARAM. inp . precision == " single " );
102+ const bool not_equal = (nbands_start != nbands);
103103
104104 Psi<T>* psi_cpu = reinterpret_cast <psi::Psi<T>*>(psi);
105105 Psi<T, Device>* psi_device = kspw_psi;
106106
107- if (another_psi_space )
107+ if (not_equal )
108108 {
109109 psi_cpu = new Psi<T>(1 , nbands_start, nbasis, nullptr );
110110 psi_device = PARAM.inp .device == " gpu" ? new psi::Psi<T, Device>(psi_cpu[0 ])
111111 : reinterpret_cast <psi::Psi<T, Device>*>(psi_cpu);
112112 }
113+ else if (PARAM.inp .precision == " single" )
114+ {
115+ if (PARAM.inp .device == " cpu" )
116+ {
117+ psi_cpu = reinterpret_cast <psi::Psi<T>*>(kspw_psi);
118+ psi_device = kspw_psi;
119+ }
120+ else
121+ {
122+ psi_cpu = new Psi<T>(1 , nbands_start, nbasis, nullptr );
123+ psi_device = kspw_psi;
124+ }
125+ }
113126
114127 // loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
115128 // like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
@@ -133,9 +146,9 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
133146
134147 if (this ->ks_solver == " cg" )
135148 {
136- if (another_psi_space )
149+ if (not_equal )
137150 {
138- // for diagH_subspace_init, psi_cpu ->get_pointer() and kspw_psi->get_pointer() should be different
151+ // for diagH_subspace_init, psi_device ->get_pointer() and kspw_psi->get_pointer() should be different
139152 hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init (p_hamilt,
140153 psi_device->get_pointer (),
141154 nbands_start,
@@ -145,7 +158,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
145158 }
146159 else
147160 {
148- // for diagH_subspace_init, psi_cpu ->get_pointer() and kspw_psi->get_pointer() can be the same
161+ // for diagH_subspace, psi_device ->get_pointer() and kspw_psi->get_pointer() can be the same
149162 hsolver::DiagoIterAssist<T, Device>::diagH_subspace (p_hamilt,
150163 *psi_device,
151164 *kspw_psi,
@@ -155,21 +168,25 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
155168 }
156169 else // dav, bpcg
157170 {
158- if (another_psi_space )
171+ if (psi_device-> get_pointer () != kspw_psi-> get_pointer () )
159172 {
160173 syncmem_complex_op ()(ctx, ctx, kspw_psi->get_pointer (), psi_device->get_pointer (), nbands * nbasis);
161174 }
162175 }
163176 } // end k-point loop
164177
165- if (another_psi_space )
178+ if (not_equal )
166179 {
167180 delete psi_cpu;
168181 if (PARAM.inp .device == " gpu" )
169182 {
170183 delete psi_device;
171184 }
172185 }
186+ else if (PARAM.inp .precision == " single" && PARAM.inp .device == " gpu" )
187+ {
188+ delete psi_cpu;
189+ }
173190
174191 ModuleBase::timer::tick (" PSIInit" , " initialize_psi" );
175192}
0 commit comments