diff --git a/source/module_hsolver/diag_const_nums.cpp b/source/module_hsolver/diag_const_nums.cpp index 2c9f926c46..8b459cbf7c 100644 --- a/source/module_hsolver/diag_const_nums.cpp +++ b/source/module_hsolver/diag_const_nums.cpp @@ -9,27 +9,60 @@ template class const_nums>; // Specialize templates to support double types template <> -const_nums::const_nums() : zero(0.0), one(1.0), neg_one(-1.0) +const_nums::const_nums() { + base_device::memory::resize_memory_op()( + this->cpu_ctx, this->zero, 1); + this->zero[0] = 0.0; + base_device::memory::resize_memory_op()( + this->cpu_ctx, this->one, 1); + this->one[0] = 1.0; + base_device::memory::resize_memory_op()( + this->cpu_ctx, this->neg_one, 1); + this->neg_one[0] = -1.0; } // Specialize templates to support double types template <> -const_nums::const_nums() : zero(0.0), one(1.0), neg_one(-1.0) +const_nums::const_nums() { + base_device::memory::resize_memory_op()( + this->cpu_ctx, this->zero, 1); + this->zero[0] = 0.0; + base_device::memory::resize_memory_op()( + this->cpu_ctx, this->one, 1); + this->one[0] = 1.0; + base_device::memory::resize_memory_op()( + this->cpu_ctx, this->neg_one, 1); + this->neg_one[0] = -1.0; } // Specialized templates to support std:: complextypes template <> const_nums>::const_nums() - : zero(std::complex(0.0, 0.0)), one(std::complex(1.0, 0.0)), - neg_one(std::complex(-1.0, 0.0)) { + base_device::memory::resize_memory_op, base_device::DEVICE_CPU>()( + this->cpu_ctx, this->zero, 1); + this->zero[0] = std::complex(0.0, 0.0); + base_device::memory::resize_memory_op, base_device::DEVICE_CPU>()( + this->cpu_ctx, this->one, 1); + this->one[0] = std::complex(1.0, 0.0); + base_device::memory::resize_memory_op, base_device::DEVICE_CPU>()( + this->cpu_ctx, this->neg_one, 1); + this->neg_one[0] = std::complex(-1.0, 0.0); } // Specialized templates to support std:: complextypes template <> const_nums>::const_nums() - : zero(std::complex(0.0, 0.0)), one(std::complex(1.0, 0.0)), neg_one(std::complex(-1.0, 0.0)) { -} + base_device::memory::resize_memory_op, base_device::DEVICE_CPU>()( + this->cpu_ctx, this->zero, 1); + this->zero[0] = std::complex(0.0, 0.0); + base_device::memory::resize_memory_op, base_device::DEVICE_CPU>()( + this->cpu_ctx, this->one, 1); + this->one[0] = std::complex(1.0, 0.0); + base_device::memory::resize_memory_op, base_device::DEVICE_CPU>()( + this->cpu_ctx, this->neg_one, 1); + this->neg_one[0] = std::complex(-1.0, 0.0); +} \ No newline at end of file diff --git a/source/module_hsolver/diag_const_nums.h b/source/module_hsolver/diag_const_nums.h index c5a97d4d61..24a33194e1 100644 --- a/source/module_hsolver/diag_const_nums.h +++ b/source/module_hsolver/diag_const_nums.h @@ -1,13 +1,15 @@ #ifndef DIAG_CONST_NUMS #define DIAG_CONST_NUMS +#include "module_base/module_device/memory_op.h" template struct const_nums { const_nums(); - T zero; - T one; - T neg_one; + base_device::DEVICE_CPU* cpu_ctx = {}; + T* zero = nullptr; + T* one = nullptr; + T* neg_one = nullptr; }; #endif \ No newline at end of file diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index 28d64ef033..1bfd0a73a1 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -25,9 +25,9 @@ Diago_DavSubspace::Diago_DavSubspace(const std::vector& precond { this->device = base_device::get_device_type(this->ctx); - this->one = &this->cs.one; - this->zero = &this->cs.zero; - this->neg_one = &this->cs.neg_one; + this->one = this->cs.one; + this->zero = this->cs.zero; + this->neg_one = this->cs.neg_one; assert(david_ndim_in > 1); assert(david_ndim_in * nband_in < nbasis_in * this->diag_comm.nproc); @@ -534,8 +534,8 @@ void Diago_DavSubspace::diag_zhegvx(const int& nbase, } else { - std::vector> h_diag(nbase, std::vector(nbase, cs.zero)); - std::vector> s_diag(nbase, std::vector(nbase, cs.zero)); + std::vector> h_diag(nbase, std::vector(nbase, cs.zero[0])); + std::vector> s_diag(nbase, std::vector(nbase, cs.zero[0])); for (size_t i = 0; i < nbase; i++) { @@ -564,10 +564,10 @@ void Diago_DavSubspace::diag_zhegvx(const int& nbase, for (size_t j = nbase; j < this->nbase_x; j++) { - hcc[i * this->nbase_x + j] = cs.zero; - hcc[j * this->nbase_x + i] = cs.zero; - scc[i * this->nbase_x + j] = cs.zero; - scc[j * this->nbase_x + i] = cs.zero; + hcc[i * this->nbase_x + j] = cs.zero[0]; + hcc[j * this->nbase_x + i] = cs.zero[0]; + scc[i * this->nbase_x + j] = cs.zero[0]; + scc[j * this->nbase_x + i] = cs.zero[0]; } } }