Skip to content

Commit 8e3a58f

Browse files
committed
remove device value in psi
1 parent a02b5d8 commit 8e3a58f

File tree

2 files changed

+11
-25
lines changed

2 files changed

+11
-25
lines changed

source/module_psi/psi.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ template <typename T, typename Device>
3232
Psi<T, Device>::Psi()
3333
{
3434
this->npol = PARAM.globalv.npol;
35-
this->device = base_device::get_device_type<Device>(this->ctx);
3635
}
3736

3837
template <typename T, typename Device>
@@ -52,8 +51,9 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
5251
this->current_b = 0;
5352
this->current_k = 0;
5453
this->npol = PARAM.globalv.npol;
55-
this->device = base_device::get_device_type<Device>(this->ctx);
54+
5655
this->resize(nk_in, nbd_in, nbs_in);
56+
5757
// Currently only GPU's implementation is supported for device recording!
5858
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
5959
base_device::information::record_device_memory<Device>(this->ctx,
@@ -76,7 +76,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
7676
this->current_b = 0;
7777
this->current_k = 0;
7878
this->npol = PARAM.globalv.npol;
79-
this->device = base_device::get_device_type<Device>(this->ctx);
8079
this->nk = nk_in;
8180
this->nbands = nbd_in;
8281
this->nbasis = nbs_in;
@@ -96,7 +95,6 @@ Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int
9695
this->current_b = 0;
9796
this->current_k = 0;
9897
this->npol = PARAM.globalv.npol;
99-
this->device = base_device::get_device_type<Device>(this->ctx);
10098
this->nk = nk_in;
10199
this->nbands = nbd_in;
102100
this->nbasis = nbs_in;
@@ -111,13 +109,10 @@ Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int
111109
template <typename T, typename Device>
112110
Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, int nband_in)
113111
{
114-
assert(nk_in <= psi_in.get_nk());
115-
if (nband_in == 0)
116-
{
117-
nband_in = psi_in.get_nbands();
118-
}
112+
assert(nk_in <= psi_in.get_nk() && nk_in > 0);
113+
assert(nband_in <= psi_in.get_nbands() && nband_in > 0);
114+
119115
this->k_first = psi_in.get_k_first();
120-
this->device = psi_in.device;
121116
this->resize(nk_in, nband_in, psi_in.get_nbasis());
122117
this->ngk = psi_in.ngk;
123118
this->npol = psi_in.npol;
@@ -139,8 +134,6 @@ template <typename T, typename Device>
139134
Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in)
140135
{
141136
this->k_first = psi_in.get_k_first();
142-
this->device = base_device::get_device_type<Device>(this->ctx);
143-
assert(this->device == psi_in.device);
144137
assert(nk_in <= psi_in.get_nk());
145138
if (nband_in == 0)
146139
{
@@ -168,7 +161,7 @@ Psi<T, Device>::Psi(const Psi& psi_in)
168161
this->current_b = psi_in.get_current_b();
169162
this->k_first = psi_in.get_k_first();
170163
// this function will copy psi_in.psi to this->psi no matter the device types of each other.
171-
this->device = base_device::get_device_type<Device>(this->ctx);
164+
172165
this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis());
173166
base_device::memory::synchronize_memory_op<T, Device, Device>()(this->ctx,
174167
psi_in.get_device(),
@@ -193,7 +186,7 @@ Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
193186
this->current_b = psi_in.get_current_b();
194187
this->k_first = psi_in.get_k_first();
195188
// this function will copy psi_in.psi to this->psi no matter the device types of each other.
196-
this->device = base_device::get_device_type<Device>(this->ctx);
189+
197190
this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis());
198191

199192
// Specifically, if the Device_in type is CPU and the Device type is GPU.

source/module_psi/psi.h

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@ class Psi
4343
Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true);
4444

4545
// Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in
46-
Psi(const Psi& psi_in, const int nk_in, int nband_in = 0);
47-
48-
46+
Psi(const Psi& psi_in, const int nk_in, const int nband_in);
47+
4948
// Constructor 5: a wrapper of a data pointer, used for Operator::hPsi()
5049
// in this case, fix_k can not be used
5150
Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0);
@@ -69,13 +68,8 @@ class Psi
6968

7069
// Constructor 8-2: a pointer version of constructor 3
7170
// only used in operator.cpp call_act func
72-
Psi(T* psi_pointer,
73-
const int nk_in,
74-
const int nbd_in,
75-
const int nbs_in,
76-
const bool k_first_in);
71+
Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in);
7772

78-
7973
// Destructor for deleting the psi array manually
8074
~Psi();
8175

@@ -141,8 +135,7 @@ class Psi
141135

142136
private:
143137
T* psi = nullptr; // avoid using C++ STL
144-
145-
base_device::AbacusDevice_t device = {}; // track the device type (CPU, GPU and SYCL are supported currented)
138+
146139
Device* ctx = {}; // an context identifier for obtaining the device variable
147140

148141
// dimensions

0 commit comments

Comments
 (0)