@@ -59,6 +59,8 @@ Psi<T, Device>::Psi(const int nk_in,
5959 this ->allocate_inside = true ;
6060
6161 this ->ngk = ngk_in.data (); // modify later
62+ ngk_vector = ngk_in;
63+
6264 // This function will delete the psi array first(if psi exist), then malloc a new memory for it.
6365 resize_memory_op ()(this ->psi , nk_in * static_cast <std::size_t >(nbd_in) * nbs_in, " no_record" );
6466
@@ -80,7 +82,7 @@ Psi<T, Device>::Psi(const int nk_in,
8082 sizeof (T) * nk_in * nbd_in * nbs_in);
8183}
8284
83- // Constructor 3 -1: 2D Psi version
85+ // Constructor 2 -1: 2D Psi version
8486template <typename T, typename Device>
8587Psi<T, Device>::Psi(T* psi_pointer,
8688 const int nk_in,
@@ -96,6 +98,9 @@ Psi<T, Device>::Psi(T* psi_pointer,
9698 this ->allocate_inside = false ;
9799
98100 this ->ngk = nullptr ;
101+ ngk_vector = std::vector<int >(nk_in, current_nbasis_in);
102+
103+
99104 this ->psi = psi_pointer;
100105
101106 this ->nk = nk_in;
@@ -112,7 +117,7 @@ Psi<T, Device>::Psi(T* psi_pointer,
112117 base_device::information::print_device_info<Device>(this ->ctx , GlobalV::ofs_device);
113118}
114119
115- // Constructor 3 -2: 2D Psi version
120+ // Constructor 2 -2: 2D Psi version
116121template <typename T, typename Device>
117122Psi<T, Device>::Psi(const int nk_in,
118123 const int nbd_in,
@@ -127,6 +132,8 @@ Psi<T, Device>::Psi(const int nk_in,
127132 this ->allocate_inside = true ;
128133
129134 this ->ngk = nullptr ;
135+ ngk_vector = std::vector<int >(nk_in, current_nbasis_in);
136+
130137 assert (nk_in > 0 && nbd_in >= 0 && nbs_in > 0 );
131138 resize_memory_op ()(this ->psi , nk_in * static_cast <std::size_t >(nbd_in) * nbs_in, " no_record" );
132139
@@ -148,12 +155,14 @@ Psi<T, Device>::Psi(const int nk_in,
148155 sizeof (T) * nk_in * nbd_in * nbs_in);
149156}
150157
151- // Constructor 2 -1:
158+ // Copy Constructor 3 -1:
152159template <typename T, typename Device>
153160Psi<T, Device>::Psi(const Psi& psi_in)
154161{
155162
156163 this ->ngk = psi_in.ngk ;
164+ this ->ngk_vector = psi_in.ngk_vector ;
165+
157166 this ->nk = psi_in.get_nk ();
158167 this ->nbands = psi_in.get_nbands ();
159168 this ->nbasis = psi_in.get_nbasis ();
@@ -172,13 +181,15 @@ Psi<T, Device>::Psi(const Psi& psi_in)
172181}
173182
174183
175- // Constructor 2 -2:
184+ // Copy Constructor 3 -2:
176185template <typename T, typename Device>
177186template <typename T_in, typename Device_in>
178187Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
179188{
180189
181190 this ->ngk = psi_in.get_ngk_pointer ();
191+ this ->ngk_vector = psi_in.get_ngk_vector ();
192+
182193 this ->nk = psi_in.get_nk ();
183194 this ->nbands = psi_in.get_nbands ();
184195 this ->nbasis = psi_in.get_nbasis ();
@@ -276,6 +287,13 @@ const int* Psi<T, Device>::get_ngk_pointer() const
276287 return this ->ngk ;
277288}
278289
290+ template <typename T, typename Device>
291+ const std::vector<int >& Psi<T, Device>::get_ngk_vector() const
292+ {
293+ return this ->ngk_vector ;
294+ }
295+
296+
279297template <typename T, typename Device>
280298const int & Psi<T, Device>::get_psi_bias() const
281299{
0 commit comments