@@ -76,7 +76,6 @@ template <typename T, typename Device> Psi<T, Device>::Psi(T* psi_pointer, const
7676 this ->nk = nk_in;
7777 this ->nbands = nbd_in;
7878 this ->nbasis = nbs_in;
79- this ->current_nbasis = nbs_in;
8079 this ->psi_current = this ->psi = psi_pointer;
8180 this ->allocate_inside = false ;
8281 // Currently only GPU's implementation is supported for device recording!
@@ -148,7 +147,6 @@ template <typename T, typename Device> Psi<T, Device>::Psi(const Psi& psi_in)
148147 psi_in.get_pointer () - psi_in.get_psi_bias (),
149148 psi_in.size ());
150149 this ->psi_bias = psi_in.get_psi_bias ();
151- this ->current_nbasis = psi_in.get_current_nbas ();
152150 this ->psi_current = this ->psi + psi_in.get_psi_bias ();
153151}
154152
@@ -200,7 +198,6 @@ Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
200198 psi_in.size ());
201199 }
202200 this ->psi_bias = psi_in.get_psi_bias ();
203- this ->current_nbasis = psi_in.get_current_nbas ();
204201 this ->psi_current = this ->psi + psi_in.get_psi_bias ();
205202}
206203
@@ -213,7 +210,6 @@ void Psi<T, Device>::resize(const int nks_in, const int nbands_in, const int nba
213210 this ->nk = nks_in;
214211 this ->nbands = nbands_in;
215212 this ->nbasis = nbasis_in;
216- this ->current_nbasis = nbasis_in;
217213 this ->psi_current = this ->psi ;
218214 // GlobalV::ofs_device << "allocated xxx MB memory for psi" << std::endl;
219215}
@@ -276,27 +272,26 @@ template <typename T, typename Device> std::size_t Psi<T, Device>::size() const
276272
277273template <typename T, typename Device> void Psi<T, Device>::fix_k(const int ik) const
278274{
279- assert (ik >= 0 );
280- this ->current_k = ik;
281- if (this ->ngk != nullptr && this ->npol != 2 )
282- this ->current_nbasis = this ->ngk [ik];
283- else
284- this ->current_nbasis = this ->nbasis ;
275+ assert (ik >= 0 && ik < this ->nk );
285276
286- if (this ->k_first )this ->current_b = 0 ;
287- int base = this ->current_b * this ->nk * this ->nbasis ;
288- if (ik >= this ->nk )
277+ if (this ->k_first == true )
289278 {
290- // mem_saver: fix to base
291- this ->psi_bias = base;
292- this ->psi_current = const_cast <T*>(&(this ->psi [base]));
279+ this ->current_k = ik;
280+ this ->current_b = 0 ;
281+
282+ this ->psi_bias = this ->current_k * this ->nbands * this ->nbasis ;
283+ this ->psi_current = this ->psi + this ->psi_bias ;
293284 }
294285 else
295286 {
296- this ->psi_bias = k_first ? ik * this ->nbands * this ->nbasis : base + ik * this ->nbasis ;
297- this ->psi_current = const_cast <T*>(&(this ->psi [psi_bias]));
287+ this ->current_k = ik;
288+ // this->current_b remains unchanged
289+
290+ this ->psi_bias = this ->current_b * this ->nk * this ->nbasis + this ->current_k * this ->nbasis ;
291+ this ->psi_current = this ->psi + this ->psi_bias ;
298292 }
299293}
294+
300295template <typename T, typename Device> void Psi<T, Device>::fix_b(const int ib) const
301296{
302297 assert (ib >= 0 );
@@ -350,12 +345,6 @@ template <typename T, typename Device> T& Psi<T, Device>::operator()(const int i
350345 return this ->psi_current [ikb2 * this ->nbasis + ibasis];
351346}
352347
353- template <typename T, typename Device> T& Psi<T, Device>::operator ()(const int ibasis) const
354- {
355- assert (ibasis >= 0 && ibasis < this ->nbasis );
356- return this ->psi_current [ibasis];
357- }
358-
359348template <typename T, typename Device> int Psi<T, Device>::get_current_k() const
360349{
361350 return this ->current_k ;
@@ -366,15 +355,28 @@ template <typename T, typename Device> int Psi<T, Device>::get_current_b() const
366355 return this ->current_b ;
367356}
368357
369- template <typename T, typename Device> int Psi<T, Device>::get_current_nbas() const
358+ template <typename T, typename Device> const int & Psi<T, Device>::get_current_nbas() const
370359{
371- return this ->current_nbasis ;
360+ if (this ->ngk != nullptr )
361+ {
362+ return this ->ngk [this ->current_k ];
363+ }
364+ else
365+ {
366+ return this ->nbasis ;
367+ }
372368}
373369
374- template <typename T, typename Device> const int & Psi<T, Device>::get_ngk (const int ik_in) const
370+ template <typename T, typename Device> const int & Psi<T, Device>::get_ik_nbas (const int ik_in) const
375371{
376- if (!this ->ngk ) return this ->nbasis ;
377- return this ->ngk [ik_in];
372+ if (this ->ngk != nullptr )
373+ {
374+ return this ->ngk [ik_in];
375+ }
376+ else
377+ {
378+ return this ->nbasis ;
379+ }
378380}
379381
380382template <typename T, typename Device> void Psi<T, Device>::zero_out()
0 commit comments