@@ -53,7 +53,7 @@ void PSIInit<T, Device>::prepare_init(const int& random_seed)
5353 this ->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_random<T>());
5454 }
5555 else if (this ->init_wfc == " atomic"
56- || (this ->init_wfc == " atomic+random" && this ->ucell .natomwfc != PARAM.inp .nbands ))
56+ || (this ->init_wfc == " atomic+random" && this ->ucell .natomwfc < PARAM.inp .nbands ))
5757 {
5858 this ->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_atomic<T>());
5959 }
@@ -99,17 +99,30 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
9999 const int nbands_start = this ->psi_initer ->nbands_start ();
100100 const int nbands = psi->get_nbands ();
101101 const int nbasis = psi->get_nbasis ();
102- const bool another_psi_space = (nbands_start != nbands || PARAM. inp . precision == " single " );
102+ const bool not_equal = (nbands_start != nbands);
103103
104104 Psi<T>* psi_cpu = reinterpret_cast <psi::Psi<T>*>(psi);
105105 Psi<T, Device>* psi_device = kspw_psi;
106106
107- if (another_psi_space )
107+ if (not_equal )
108108 {
109109 psi_cpu = new Psi<T>(1 , nbands_start, nbasis, nullptr );
110110 psi_device = PARAM.inp .device == " gpu" ? new psi::Psi<T, Device>(psi_cpu[0 ])
111111 : reinterpret_cast <psi::Psi<T, Device>*>(psi_cpu);
112112 }
113+ else if (PARAM.inp .precision == " single" )
114+ {
115+ if (PARAM.inp .device == " cpu" )
116+ {
117+ psi_cpu = reinterpret_cast <psi::Psi<T>*>(kspw_psi);
118+ psi_device = kspw_psi;
119+ }
120+ else
121+ {
122+ psi_cpu = new Psi<T>(1 , nbands_start, nbasis, nullptr );
123+ psi_device = kspw_psi;
124+ }
125+ }
113126
114127 // loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
115128 // like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
@@ -126,16 +139,16 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
126139 this ->psi_initer ->init_psig (psi_cpu->get_pointer (), ik);
127140 if (psi_device->get_pointer () != psi_cpu->get_pointer ())
128141 {
129- castmem_h2d_op ()(ctx, cpu_ctx, psi_device->get_pointer (), psi_cpu->get_pointer (), nbands_start * nbasis);
142+ syncmem_h2d_op ()(ctx, cpu_ctx, psi_device->get_pointer (), psi_cpu->get_pointer (), nbands_start * nbasis);
130143 }
131144
132145 std::vector<typename GetTypeReal<T>::type> etatom (nbands_start, 0.0 );
133146
134147 if (this ->ks_solver == " cg" )
135148 {
136- if (another_psi_space )
149+ if (not_equal )
137150 {
138- // for diagH_subspace_init, psi_cpu ->get_pointer() and kspw_psi->get_pointer() should be different
151+ // for diagH_subspace_init, psi_device ->get_pointer() and kspw_psi->get_pointer() should be different
139152 hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init (p_hamilt,
140153 psi_device->get_pointer (),
141154 nbands_start,
@@ -145,7 +158,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
145158 }
146159 else
147160 {
148- // for diagH_subspace_init, psi_cpu ->get_pointer() and kspw_psi->get_pointer() can be the same
161+ // for diagH_subspace, psi_device ->get_pointer() and kspw_psi->get_pointer() can be the same
149162 hsolver::DiagoIterAssist<T, Device>::diagH_subspace (p_hamilt,
150163 *psi_device,
151164 *kspw_psi,
@@ -155,21 +168,25 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
155168 }
156169 else // dav, bpcg
157170 {
158- if (another_psi_space )
171+ if (psi_device-> get_pointer () != kspw_psi-> get_pointer () )
159172 {
160173 syncmem_complex_op ()(ctx, ctx, kspw_psi->get_pointer (), psi_device->get_pointer (), nbands * nbasis);
161174 }
162175 }
163176 } // end k-point loop
164177
165- if (another_psi_space )
178+ if (not_equal )
166179 {
167180 delete psi_cpu;
168181 if (PARAM.inp .device == " gpu" )
169182 {
170183 delete psi_device;
171184 }
172185 }
186+ else if (PARAM.inp .precision == " single" && PARAM.inp .device == " gpu" )
187+ {
188+ delete psi_cpu;
189+ }
173190
174191 ModuleBase::timer::tick (" PSIInit" , " initialize_psi" );
175192}
0 commit comments