Skip to content

Commit b36db7f

Browse files
committed
update psi code
1 parent bec10a2 commit b36db7f

File tree

2 files changed

+23
-25
lines changed

2 files changed

+23
-25
lines changed

source/module_psi/psi.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_
3131
// Constructor 0: basic
3232
template <typename T, typename Device>
3333
Psi<T, Device>::Psi()
34-
{
35-
this->npol = PARAM.globalv.npol;
36-
}
34+
{}
3735

3836
template <typename T, typename Device>
3937
Psi<T, Device>::~Psi()
@@ -53,7 +51,6 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
5351
assert(nbs_in > 0);
5452

5553
this->k_first = k_first_in;
56-
this->npol = PARAM.globalv.npol;
5754
this->allocate_inside = true;
5855

5956
this->ngk = ngk_in; // modify later
@@ -91,7 +88,6 @@ Psi<T, Device>::Psi(const int nk_in,
9188
assert(nbs_in > 0);
9289

9390
this->k_first = k_first_in;
94-
this->npol = PARAM.globalv.npol;
9591
this->allocate_inside = true;
9692

9793
this->ngk = ngk_in.data(); // modify later
@@ -129,7 +125,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
129125
// assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func
130126

131127
this->k_first = k_first_in;
132-
this->npol = PARAM.globalv.npol;
133128
this->allocate_inside = false;
134129

135130
this->ngk = nullptr;
@@ -161,7 +156,6 @@ Psi<T, Device>::Psi(const int nk_in,
161156
assert(nk_in == 1);
162157

163158
this->k_first = k_first_in;
164-
this->npol = PARAM.globalv.npol;
165159
this->allocate_inside = true;
166160

167161
this->ngk = nullptr;
@@ -191,7 +185,6 @@ template <typename T, typename Device>
191185
Psi<T, Device>::Psi(const Psi& psi_in)
192186
{
193187
this->ngk = psi_in.ngk;
194-
this->npol = PARAM.globalv.npol;
195188
this->nk = psi_in.get_nk();
196189
this->nbands = psi_in.get_nbands();
197190
this->nbasis = psi_in.get_nbasis();
@@ -218,7 +211,6 @@ template <typename T_in, typename Device_in>
218211
Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
219212
{
220213
this->ngk = psi_in.get_ngk_pointer();
221-
this->npol = PARAM.globalv.npol;
222214
this->nk = psi_in.get_nk();
223215
this->nbands = psi_in.get_nbands();
224216
this->nbasis = psi_in.get_nbasis();
@@ -331,7 +323,7 @@ const int& Psi<T, Device>::get_psi_bias() const
331323
template <typename T, typename Device>
332324
const int& Psi<T, Device>::get_current_ngk() const
333325
{
334-
if (this->npol == 1)
326+
if (PARAM.inp.nspin != 4)
335327
{
336328
return this->current_nbasis;
337329
}
@@ -519,13 +511,13 @@ std::tuple<const T*, int> Psi<T, Device>::to_range(const Range& range) const
519511
else if (i1 < 0) // [r1, r2] is the range of index1 with length m
520512
{
521513
const T* p = &this->psi[r1 * (k_first ? this->nbands : this->nk) * this->nbasis];
522-
int m = (r2 - r1 + 1) * this->npol;
514+
int m = (r2 - r1 + 1) * this->get_npol();
523515
return std::tuple<const T*, int>(p, m);
524516
}
525517
else // [r1, r2] is the range of index2 with length m
526518
{
527519
const T* p = &this->psi[(i1 * (k_first ? this->nbands : this->nk) + r1) * this->nbasis];
528-
int m = (r2 - r1 + 1) * this->npol;
520+
int m = (r2 - r1 + 1) * this->get_npol();
529521
return std::tuple<const T*, int>(p, m);
530522
}
531523
}

source/module_psi/psi.h

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "module_base/module_device/memory_op.h"
55
#include "module_base/module_device/types.h"
6+
#include "module_parameter/parameter.h"
67

78
#include <tuple>
89
#include <vector>
@@ -134,7 +135,17 @@ class Psi
134135

135136
const int& get_current_ngk() const;
136137

137-
const int& get_npol() const {return this->npol;}
138+
const int& get_npol() const
139+
{
140+
if (PARAM.inp.nspin == 4)
141+
{
142+
return 2;
143+
}
144+
else
145+
{
146+
return 1;
147+
}
148+
}
138149

139150
// solve Range: return(pointer of begin, number of bands or k-points)
140151
std::tuple<const T*, int> to_range(const Range& range) const;
@@ -143,27 +154,22 @@ class Psi
143154
T* psi = nullptr; // avoid using C++ STL
144155

145156
Device* ctx = {}; // an context identifier for obtaining the device variable
146-
int npol = 1;
157+
bool allocate_inside = true; ///< whether allocate psi inside Psi class
158+
bool k_first = true;
159+
160+
const int* ngk = nullptr;
147161

148162
// dimensions
149163
int nk = 1; // number of k points
150164
int nbands = 1; // number of bands
151165
int nbasis = 1; // number of basis
152166

167+
// mutable values
153168
mutable int current_k = 0; // current k point
154169
mutable int current_b = 0; // current band index
155170
mutable int current_nbasis = 1; // current number of basis of current_k
156-
157-
// current pointer for getting the psi
158-
mutable T* psi_current = nullptr;
159-
// psi_current = psi + psi_bias;
160-
mutable int psi_bias = 0;
161-
162-
const int* ngk = nullptr;
163-
164-
bool k_first = true;
165-
166-
bool allocate_inside = true; ///< whether allocate psi inside Psi class
171+
mutable T* psi_current = nullptr; // current pointer for getting the psi
172+
mutable int psi_bias = 0; // psi_current = psi + psi_bias;
167173

168174
#ifdef __DSP
169175
using delete_memory_op = base_device::memory::delete_memory_op_mt<T, Device>;

0 commit comments

Comments
 (0)