@@ -28,6 +28,7 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_
2828 range_2 = range_2_in;
2929}
3030
31+ // Constructor 0: basic
3132template <typename T, typename Device>
3233Psi<T, Device>::Psi()
3334{
@@ -43,16 +44,31 @@ Psi<T, Device>::~Psi()
4344 }
4445}
4546
47+ // Constructor 1-1:
4648template <typename T, typename Device>
4749Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int * ngk_in, const bool k_first_in)
4850{
51+ assert (nk_in > 0 );
52+ assert (nbd_in > 0 );
53+ assert (nbs_in > 0 );
54+
4955 this ->k_first = k_first_in;
50- this ->ngk = ngk_in;
51- this ->current_b = 0 ;
52- this ->current_k = 0 ;
5356 this ->npol = PARAM.globalv .npol ;
57+ this ->allocate_inside = true ;
5458
55- this ->resize (nk_in, nbd_in, nbs_in);
59+ this ->ngk = ngk_in; // modify later
60+ // This function will delete the psi array first(if psi exist), then malloc a new memory for it.
61+ resize_memory_op ()(this ->ctx , this ->psi , nk_in * static_cast <std::size_t >(nbd_in) * nbs_in, " no_record" );
62+
63+ this ->nk = nk_in;
64+ this ->nbands = nbd_in;
65+ this ->nbasis = nbs_in;
66+
67+ this ->current_b = 0 ;
68+ this ->current_k = 0 ;
69+ this ->current_nbasis = nbs_in;
70+ this ->psi_current = this ->psi ;
71+ this ->psi_bias = 0 ;
5672
5773 // Currently only GPU's implementation is supported for device recording!
5874 base_device::information::print_device_info<Device>(this ->ctx , GlobalV::ofs_device);
@@ -62,20 +78,35 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
6278 sizeof (T) * nk_in * nbd_in * nbs_in);
6379}
6480
81+ // Constructor 1-2:
6582template <typename T, typename Device>
6683Psi<T, Device>::Psi(const int nk_in,
6784 const int nbd_in,
6885 const int nbs_in,
6986 const std::vector<int >& ngk_in,
7087 const bool k_first_in)
7188{
89+ assert (nk_in > 0 );
90+ assert (nbd_in > 0 );
91+ assert (nbs_in > 0 );
92+
7293 this ->k_first = k_first_in;
73- this ->ngk = ngk_in.data ();
74- this ->current_b = 0 ;
75- this ->current_k = 0 ;
7694 this ->npol = PARAM.globalv .npol ;
95+ this ->allocate_inside = true ;
96+
97+ this ->ngk = ngk_in.data (); // modify later
98+ // This function will delete the psi array first(if psi exist), then malloc a new memory for it.
99+ resize_memory_op ()(this ->ctx , this ->psi , nk_in * static_cast <std::size_t >(nbd_in) * nbs_in, " no_record" );
77100
78- this ->resize (nk_in, nbd_in, nbs_in);
101+ this ->nk = nk_in;
102+ this ->nbands = nbd_in;
103+ this ->nbasis = nbs_in;
104+
105+ this ->current_b = 0 ;
106+ this ->current_k = 0 ;
107+ this ->current_nbasis = nbs_in;
108+ this ->psi_current = this ->psi ;
109+ this ->psi_bias = 0 ;
79110
80111 // Currently only GPU's implementation is supported for device recording!
81112 base_device::information::print_device_info<Device>(this ->ctx , GlobalV::ofs_device);
@@ -85,7 +116,7 @@ Psi<T, Device>::Psi(const int nk_in,
85116 sizeof (T) * nk_in * nbd_in * nbs_in);
86117}
87118
88- // Constructor 8 -1:
119+ // Constructor 3 -1: 2D Psi version
89120template <typename T, typename Device>
90121Psi<T, Device>::Psi(T* psi_pointer,
91122 const int nk_in,
@@ -94,7 +125,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
94125 const int current_nbasis_in,
95126 const bool k_first_in)
96127{
97-
98128 // Currently this function only supports nk_in == 1 when called within diagH_subspace_init.
99129 // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func
100130
@@ -103,7 +133,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
103133 this ->allocate_inside = false ;
104134
105135 this ->ngk = nullptr ;
106-
107136 this ->psi = psi_pointer;
108137
109138 this ->nk = nk_in;
@@ -120,15 +149,14 @@ Psi<T, Device>::Psi(T* psi_pointer,
120149 base_device::information::print_device_info<Device>(this ->ctx , GlobalV::ofs_device);
121150}
122151
123- // Constructor 8-3 : 2D Psi version 3
152+ // Constructor 3-2 : 2D Psi version
124153template <typename T, typename Device>
125154Psi<T, Device>::Psi(const int nk_in,
126155 const int nbd_in,
127156 const int nbs_in,
128157 const int current_nbasis_in,
129158 const bool k_first_in)
130159{
131-
132160 // Currently this function only supports nk_in == 1 when called within diagH_subspace_init.
133161 assert (nk_in == 1 );
134162
@@ -158,37 +186,7 @@ Psi<T, Device>::Psi(const int nk_in,
158186 sizeof (T) * nk_in * nbd_in * nbs_in);
159187}
160188
161- // template <typename T, typename Device>
162- // Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, const int nband_in)
163- // {
164- // assert(nk_in == 1);
165- // assert(nband_in <= psi_in.get_nbands() && nband_in > 0);
166-
167- // this->k_first = psi_in.get_k_first();
168- // this->npol = psi_in.npol;
169- // this->allocate_inside = true;
170-
171- // this->nk = nk_in;
172- // this->nbands = nband_in;
173- // this->nbasis = psi_in.get_nbasis();
174-
175- // // This function will delete the psi array first(if psi exist), then malloc a new memory for it.
176- // resize_memory_op()(this->ctx,
177- // this->psi,
178- // (static_cast<std::size_t>(this->nk) * static_cast<std::size_t>(this->nbands)
179- // * static_cast<std::size_t>(this->nbasis)),
180- // "no_record");
181- // synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size());
182-
183- // this->current_k = 0;
184- // this->current_b = 0;
185- // this->current_nbasis = this->nbasis;
186- // this->psi_current = this->psi;
187- // this->psi_bias = 0;
188-
189- // this->ngk = nullptr;
190- // }
191-
189+ // Constructor 2-1:
192190template <typename T, typename Device>
193191Psi<T, Device>::Psi(const Psi& psi_in)
194192{
@@ -213,6 +211,8 @@ Psi<T, Device>::Psi(const Psi& psi_in)
213211 this ->psi_current = this ->psi + psi_in.get_psi_bias ();
214212}
215213
214+
215+ // Constructor 2-2:
216216template <typename T, typename Device>
217217template <typename T_in, typename Device_in>
218218Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
0 commit comments