Skip to content

Commit b9d0160

Browse files
committed
refactor psi code
1 parent a920924 commit b9d0160

File tree

2 files changed

+50
-50
lines changed

2 files changed

+50
-50
lines changed

source/module_psi/psi.cpp

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_
2828
range_2 = range_2_in;
2929
}
3030

31+
// Constructor 0: basic
3132
template <typename T, typename Device>
3233
Psi<T, Device>::Psi()
3334
{
@@ -43,16 +44,31 @@ Psi<T, Device>::~Psi()
4344
}
4445
}
4546

47+
// Constructor 1-1:
4648
template <typename T, typename Device>
4749
Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in)
4850
{
51+
assert(nk_in > 0);
52+
assert(nbd_in > 0);
53+
assert(nbs_in > 0);
54+
4955
this->k_first = k_first_in;
50-
this->ngk = ngk_in;
51-
this->current_b = 0;
52-
this->current_k = 0;
5356
this->npol = PARAM.globalv.npol;
57+
this->allocate_inside = true;
5458

55-
this->resize(nk_in, nbd_in, nbs_in);
59+
this->ngk = ngk_in; // modify later
60+
// This function will delete the psi array first(if psi exist), then malloc a new memory for it.
61+
resize_memory_op()(this->ctx, this->psi, nk_in * static_cast<std::size_t>(nbd_in) * nbs_in, "no_record");
62+
63+
this->nk = nk_in;
64+
this->nbands = nbd_in;
65+
this->nbasis = nbs_in;
66+
67+
this->current_b = 0;
68+
this->current_k = 0;
69+
this->current_nbasis = nbs_in;
70+
this->psi_current = this->psi;
71+
this->psi_bias = 0;
5672

5773
// Currently only GPU's implementation is supported for device recording!
5874
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
@@ -62,20 +78,35 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
6278
sizeof(T) * nk_in * nbd_in * nbs_in);
6379
}
6480

81+
// Constructor 1-2:
6582
template <typename T, typename Device>
6683
Psi<T, Device>::Psi(const int nk_in,
6784
const int nbd_in,
6885
const int nbs_in,
6986
const std::vector<int>& ngk_in,
7087
const bool k_first_in)
7188
{
89+
assert(nk_in > 0);
90+
assert(nbd_in > 0);
91+
assert(nbs_in > 0);
92+
7293
this->k_first = k_first_in;
73-
this->ngk = ngk_in.data();
74-
this->current_b = 0;
75-
this->current_k = 0;
7694
this->npol = PARAM.globalv.npol;
95+
this->allocate_inside = true;
96+
97+
this->ngk = ngk_in.data(); // modify later
98+
// This function will delete the psi array first(if psi exist), then malloc a new memory for it.
99+
resize_memory_op()(this->ctx, this->psi, nk_in * static_cast<std::size_t>(nbd_in) * nbs_in, "no_record");
77100

78-
this->resize(nk_in, nbd_in, nbs_in);
101+
this->nk = nk_in;
102+
this->nbands = nbd_in;
103+
this->nbasis = nbs_in;
104+
105+
this->current_b = 0;
106+
this->current_k = 0;
107+
this->current_nbasis = nbs_in;
108+
this->psi_current = this->psi;
109+
this->psi_bias = 0;
79110

80111
// Currently only GPU's implementation is supported for device recording!
81112
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
@@ -85,7 +116,7 @@ Psi<T, Device>::Psi(const int nk_in,
85116
sizeof(T) * nk_in * nbd_in * nbs_in);
86117
}
87118

88-
// Constructor 8-1:
119+
// Constructor 3-1: 2D Psi version
89120
template <typename T, typename Device>
90121
Psi<T, Device>::Psi(T* psi_pointer,
91122
const int nk_in,
@@ -94,7 +125,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
94125
const int current_nbasis_in,
95126
const bool k_first_in)
96127
{
97-
98128
// Currently this function only supports nk_in == 1 when called within diagH_subspace_init.
99129
// assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func
100130

@@ -103,7 +133,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
103133
this->allocate_inside = false;
104134

105135
this->ngk = nullptr;
106-
107136
this->psi = psi_pointer;
108137

109138
this->nk = nk_in;
@@ -120,15 +149,14 @@ Psi<T, Device>::Psi(T* psi_pointer,
120149
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
121150
}
122151

123-
// Constructor 8-3: 2D Psi version 3
152+
// Constructor 3-2: 2D Psi version
124153
template <typename T, typename Device>
125154
Psi<T, Device>::Psi(const int nk_in,
126155
const int nbd_in,
127156
const int nbs_in,
128157
const int current_nbasis_in,
129158
const bool k_first_in)
130159
{
131-
132160
// Currently this function only supports nk_in == 1 when called within diagH_subspace_init.
133161
assert(nk_in == 1);
134162

@@ -158,37 +186,7 @@ Psi<T, Device>::Psi(const int nk_in,
158186
sizeof(T) * nk_in * nbd_in * nbs_in);
159187
}
160188

161-
// template <typename T, typename Device>
162-
// Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, const int nband_in)
163-
// {
164-
// assert(nk_in == 1);
165-
// assert(nband_in <= psi_in.get_nbands() && nband_in > 0);
166-
167-
// this->k_first = psi_in.get_k_first();
168-
// this->npol = psi_in.npol;
169-
// this->allocate_inside = true;
170-
171-
// this->nk = nk_in;
172-
// this->nbands = nband_in;
173-
// this->nbasis = psi_in.get_nbasis();
174-
175-
// // This function will delete the psi array first(if psi exist), then malloc a new memory for it.
176-
// resize_memory_op()(this->ctx,
177-
// this->psi,
178-
// (static_cast<std::size_t>(this->nk) * static_cast<std::size_t>(this->nbands)
179-
// * static_cast<std::size_t>(this->nbasis)),
180-
// "no_record");
181-
// synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size());
182-
183-
// this->current_k = 0;
184-
// this->current_b = 0;
185-
// this->current_nbasis = this->nbasis;
186-
// this->psi_current = this->psi;
187-
// this->psi_bias = 0;
188-
189-
// this->ngk = nullptr;
190-
// }
191-
189+
// Constructor 2-1:
192190
template <typename T, typename Device>
193191
Psi<T, Device>::Psi(const Psi& psi_in)
194192
{
@@ -213,6 +211,8 @@ Psi<T, Device>::Psi(const Psi& psi_in)
213211
this->psi_current = this->psi + psi_in.get_psi_bias();
214212
}
215213

214+
215+
// Constructor 2-2:
216216
template <typename T, typename Device>
217217
template <typename T_in, typename Device_in>
218218
Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)

source/module_psi/psi.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,31 +65,31 @@ class Psi
6565
// Constructor 3-2: 2D Psi version
6666
Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in);
6767

68-
// // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in
69-
// Psi(const Psi& psi_in, const int nk_in, const int nband_in);
70-
7168
// Destructor for deleting the psi array manually
7269
~Psi();
7370

71+
// set psi value func 1
7472
void set_all_psi(const T* another_pointer, const std::size_t size_in);
7573

76-
// mark
74+
// set psi value func 2
7775
void zero_out();
7876

77+
// size_t size() const {return this->psi.size();}
78+
size_t size() const;
79+
7980
// allocate psi for three dimensions
8081
void resize(const int nks_in, const int nbands_in, const int nbasis_in);
8182

8283
// get the pointer for the 1st index
8384
T* get_pointer() const;
85+
8486
// get the pointer for the 2nd index (iband for k_first = true, ik for k_first = false)
8587
T* get_pointer(const int& ikb) const;
8688

8789
// interface to get three dimension size
8890
const int& get_nk() const;
8991
const int& get_nbands() const;
9092
const int& get_nbasis() const;
91-
// size_t size() const {return this->psi.size();}
92-
size_t size() const;
9393

9494
/// if k_first=true: choose k-point index , then Psi(iband, ibasis) can reach Psi(ik, iband, ibasis)
9595
/// if k_first=false: choose k-point index, then Psi(ibasis) can reach Psi(iband, ik, ibasis)

0 commit comments

Comments
 (0)