Skip to content

Commit 8822aad

Browse files
committed
add ngk_vector
1 parent ec692be commit 8822aad

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed

source/module_psi/psi.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ Psi<T, Device>::Psi(const int nk_in,
5959
this->allocate_inside = true;
6060

6161
this->ngk = ngk_in.data(); // modify later
62+
ngk_vector = ngk_in;
63+
6264
// This function will delete the psi array first(if psi exist), then malloc a new memory for it.
6365
resize_memory_op()(this->psi, nk_in * static_cast<std::size_t>(nbd_in) * nbs_in, "no_record");
6466

@@ -80,7 +82,7 @@ Psi<T, Device>::Psi(const int nk_in,
8082
sizeof(T) * nk_in * nbd_in * nbs_in);
8183
}
8284

83-
// Constructor 3-1: 2D Psi version
85+
// Constructor 2-1: 2D Psi version
8486
template <typename T, typename Device>
8587
Psi<T, Device>::Psi(T* psi_pointer,
8688
const int nk_in,
@@ -96,6 +98,9 @@ Psi<T, Device>::Psi(T* psi_pointer,
9698
this->allocate_inside = false;
9799

98100
this->ngk = nullptr;
101+
ngk_vector = std::vector<int>(nk_in, current_nbasis_in);
102+
103+
99104
this->psi = psi_pointer;
100105

101106
this->nk = nk_in;
@@ -112,7 +117,7 @@ Psi<T, Device>::Psi(T* psi_pointer,
112117
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
113118
}
114119

115-
// Constructor 3-2: 2D Psi version
120+
// Constructor 2-2: 2D Psi version
116121
template <typename T, typename Device>
117122
Psi<T, Device>::Psi(const int nk_in,
118123
const int nbd_in,
@@ -127,6 +132,8 @@ Psi<T, Device>::Psi(const int nk_in,
127132
this->allocate_inside = true;
128133

129134
this->ngk = nullptr;
135+
ngk_vector = std::vector<int>(nk_in, current_nbasis_in);
136+
130137
assert(nk_in > 0 && nbd_in >= 0 && nbs_in > 0);
131138
resize_memory_op()(this->psi, nk_in * static_cast<std::size_t>(nbd_in) * nbs_in, "no_record");
132139

@@ -148,12 +155,14 @@ Psi<T, Device>::Psi(const int nk_in,
148155
sizeof(T) * nk_in * nbd_in * nbs_in);
149156
}
150157

151-
// Constructor 2-1:
158+
// Copy Constructor 3-1:
152159
template <typename T, typename Device>
153160
Psi<T, Device>::Psi(const Psi& psi_in)
154161
{
155162

156163
this->ngk = psi_in.ngk;
164+
this->ngk_vector = psi_in.ngk_vector;
165+
157166
this->nk = psi_in.get_nk();
158167
this->nbands = psi_in.get_nbands();
159168
this->nbasis = psi_in.get_nbasis();
@@ -172,13 +181,15 @@ Psi<T, Device>::Psi(const Psi& psi_in)
172181
}
173182

174183

175-
// Constructor 2-2:
184+
// Copy Constructor 3-2:
176185
template <typename T, typename Device>
177186
template <typename T_in, typename Device_in>
178187
Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
179188
{
180189

181190
this->ngk = psi_in.get_ngk_pointer();
191+
this->ngk_vector = psi_in.get_ngk_vector();
192+
182193
this->nk = psi_in.get_nk();
183194
this->nbands = psi_in.get_nbands();
184195
this->nbasis = psi_in.get_nbasis();
@@ -276,6 +287,13 @@ const int* Psi<T, Device>::get_ngk_pointer() const
276287
return this->ngk;
277288
}
278289

290+
template <typename T, typename Device>
291+
const std::vector<int>& Psi<T, Device>::get_ngk_vector() const
292+
{
293+
return this->ngk_vector;
294+
}
295+
296+
279297
template <typename T, typename Device>
280298
const int& Psi<T, Device>::get_psi_bias() const
281299
{

source/module_psi/psi.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,26 +42,26 @@ class Psi
4242
// Constructor 1:
4343
Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector<int>& ngk_in, const bool k_first_in);
4444

45-
// Constructor 2-1: initialize a new psi from the given psi_in
46-
Psi(const Psi& psi_in);
47-
48-
// Constructor 2-2: initialize a new psi from the given psi_in with a different class template
49-
// in this case, psi_in may have a different device type.
50-
template <typename T_in, typename Device_in = Device>
51-
Psi(const Psi<T_in, Device_in>& psi_in);
52-
53-
// Constructor 3-1: 2D Psi version
54-
// used in hsolver-pw function pointer and somewhere.
45+
// Constructor 2-1: 2D Psi version, used in hsolver-pw function pointer and somewhere.
5546
Psi(T* psi_pointer,
5647
const int nk_in,
5748
const int nbd_in,
5849
const int nbs_in,
5950
const int current_nbasis_in,
6051
const bool k_first_in = true);
6152

62-
// Constructor 3-2: 2D Psi version
53+
// Constructor 2-2: 2D Psi version
6354
Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in);
6455

56+
// Copy Constructor 3-1: initialize a new psi from the given psi_in
57+
Psi(const Psi& psi_in);
58+
59+
// Copy Constructor 3-2: initialize a new psi from the given psi_in with a different class template
60+
// in this case, psi_in may have a different device type.
61+
template <typename T_in, typename Device_in = Device>
62+
Psi(const Psi<T_in, Device_in>& psi_in);
63+
64+
6565
// Destructor for deleting the psi array manually
6666
~Psi();
6767

@@ -120,6 +120,8 @@ class Psi
120120

121121
const int* get_ngk_pointer() const;
122122

123+
const std::vector<int>& get_ngk_vector() const;
124+
123125
// return k_first
124126
const bool& get_k_first() const;
125127

@@ -156,6 +158,7 @@ class Psi
156158
mutable int psi_bias = 0;
157159

158160
const int* ngk = nullptr;
161+
std::vector<int> ngk_vector;
159162

160163
bool k_first = true;
161164

0 commit comments

Comments
 (0)