Skip to content

Commit c716bb7

Browse files
committed
update Psi(const Psi& psi_in, const int nk_in, int nband_in)
1 parent 8e3a58f commit c716bb7

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

source/module_psi/psi.cpp

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,31 +105,44 @@ Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int
105105
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
106106
}
107107

108-
109108
template <typename T, typename Device>
110109
Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, int nband_in)
111110
{
112111
assert(nk_in <= psi_in.get_nk() && nk_in > 0);
113112
assert(nband_in <= psi_in.get_nbands() && nband_in > 0);
114113

115114
this->k_first = psi_in.get_k_first();
116-
this->resize(nk_in, nband_in, psi_in.get_nbasis());
117-
this->ngk = psi_in.ngk;
118115
this->npol = psi_in.npol;
119-
if (nband_in <= psi_in.get_nbands())
116+
this->allocate_inside = true;
117+
118+
this->nk = nk_in;
119+
this->nbands = nband_in;
120+
this->nbasis = psi_in.get_nbasis();
121+
122+
// This function will delete the psi array first(if psi exist), then malloc a new memory for it.
123+
resize_memory_op()(this->ctx,
124+
this->psi,
125+
(static_cast<std::size_t>(this->nk) * static_cast<std::size_t>(this->nbands)
126+
* static_cast<std::size_t>(this->nbasis)),
127+
"no_record");
128+
synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size());
129+
130+
this->current_k = 0;
131+
this->current_b = 0;
132+
this->current_nbasis = this->nbasis;
133+
this->psi_current = this->psi;
134+
this->psi_bias = 0;
135+
136+
if (this->nk != psi_in.get_nk())
120137
{
121-
// copy from Psi from psi_in(current_k, 0, 0),
122-
// if size of k is 1, current_k in new Psi is psi_in.current_k
123-
if (nk_in == 1)
124-
{
125-
// current_k for this Psi only keep the spin index same as the copied Psi
126-
this->current_k = psi_in.get_current_k();
127-
}
128-
synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size());
138+
this->ngk = nullptr;
139+
}
140+
else
141+
{
142+
this->ngk = psi_in.ngk;
129143
}
130144
}
131145

132-
133146
template <typename T, typename Device>
134147
Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in)
135148
{

source/module_psi/psi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class Psi
131131

132132
// solve Range: return(pointer of begin, number of bands or k-points)
133133
std::tuple<const T*, int> to_range(const Range& range) const;
134+
134135
int npol = 1;
135136

136137
private:

0 commit comments

Comments
 (0)