@@ -104,14 +104,19 @@ namespace LR_Util
104104
105105 // / psi(nk=1, nbands=nb, nk * nbasis) -> psi(nb, nk, nbasis) without memory copy
106106 template <typename T, typename Device>
107- psi::Psi<T, Device> k1_to_bfirst_wrapper (const psi::Psi<T, Device>& psi_kfirst, int nk_in, int nbasis_in)
107+ psi::Psi<T, Device> c (const psi::Psi<T, Device>& psi_kfirst, int nk_in, int nbasis_in)
108108 {
109109 assert (psi_kfirst.get_nk () == 1 );
110110 assert (nk_in * nbasis_in == psi_kfirst.get_nbasis ());
111111
112112 int ib_now = psi_kfirst.get_current_b ();
113113 psi_kfirst.fix_b (0 ); // for get_pointer() to get the head pointer
114- psi::Psi<T, Device> psi_bfirst (psi_kfirst.get_pointer (), nk_in, psi_kfirst.get_nbands (), nbasis_in, false );
114+ psi::Psi<T, Device> psi_bfirst (psi_kfirst.get_pointer (),
115+ nk_in,
116+ psi_kfirst.get_nbands (),
117+ nbasis_in,
118+ nbasis_in,
119+ false );
115120 psi_kfirst.fix_b (ib_now);
116121 return psi_bfirst;
117122 }
@@ -124,7 +129,12 @@ namespace LR_Util
124129 int ik_now = psi_bfirst.get_current_k ();
125130
126131 psi_bfirst.fix_kb (0 , 0 ); // for get_pointer() to get the head pointer
127- psi::Psi<T, Device> psi_kfirst (psi_bfirst.get_pointer (), 1 , psi_bfirst.get_nbands (), psi_bfirst.get_nk () * psi_bfirst.get_nbasis (), true );
132+ psi::Psi<T, Device> psi_kfirst (psi_bfirst.get_pointer (),
133+ 1 ,
134+ psi_bfirst.get_nbands (),
135+ psi_bfirst.get_nk () * psi_bfirst.get_nbasis (),
136+ psi_bfirst.get_nk () * psi_bfirst.get_nbasis (),
137+ true );
128138 psi_bfirst.fix_kb (ik_now, ib_now);
129139 return psi_kfirst;
130140 }
0 commit comments