Skip to content

Commit 14c2036

Browse files
committed
Use memory_op to set consts
1 parent a5c35d9 commit 14c2036

File tree

3 files changed

+59
-18
lines changed

3 files changed

+59
-18
lines changed

source/module_hsolver/diag_const_nums.cpp

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,66 @@ template class const_nums<std::complex<float>>;
99

1010
// Specialize templates to support double types
1111
template <>
12-
const_nums<double>::const_nums() : zero(0.0), one(1.0), neg_one(-1.0)
12+
const_nums<double>::const_nums()
1313
{
14+
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
15+
this->cpu_ctx, this->zero, 1);
16+
base_device::memory::set_memory_op<double, base_device::DEVICE_CPU>()(
17+
this->cpu_ctx, this->zero, 0.0, 1);
18+
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
19+
this->cpu_ctx, this->one, 1);
20+
base_device::memory::set_memory_op<double, base_device::DEVICE_CPU>()(
21+
this->cpu_ctx, this->one, 1.0, 1);
22+
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
23+
this->cpu_ctx, this->neg_one, 1);
24+
base_device::memory::set_memory_op<double, base_device::DEVICE_CPU>()(
25+
this->cpu_ctx, this->neg_one, -1.0, 1);
1426
}
1527

1628
// Specialize templates to support double types
1729
template <>
18-
const_nums<float>::const_nums() : zero(0.0), one(1.0), neg_one(-1.0)
30+
const_nums<float>::const_nums()
1931
{
32+
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
33+
this->cpu_ctx, this->zero, 1);
34+
base_device::memory::set_memory_op<float, base_device::DEVICE_CPU>()(
35+
this->cpu_ctx, this->zero, 0.0, 1);
36+
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
37+
this->cpu_ctx, this->one, 1);
38+
base_device::memory::set_memory_op<float, base_device::DEVICE_CPU>()(
39+
this->cpu_ctx, this->one, 1.0, 1);
40+
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
41+
this->cpu_ctx, this->neg_one, 1);
42+
base_device::memory::set_memory_op<float, base_device::DEVICE_CPU>()(
43+
this->cpu_ctx, this->neg_one, -1.0, 1);
2044
}
2145

2246
// Specialized templates to support std:: complex<double>types
2347
template <>
2448
const_nums<std::complex<double>>::const_nums()
25-
: zero(std::complex<double>(0.0, 0.0)), one(std::complex<double>(1.0, 0.0)),
26-
neg_one(std::complex<double>(-1.0, 0.0))
2749
{
50+
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
51+
this->cpu_ctx, this->zero, 1);
52+
*this->zero = std::complex<double>(0.0, 0.0);
53+
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
54+
this->cpu_ctx, this->one, 1);
55+
*this->one = std::complex<double>(1.0, 0.0);
56+
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
57+
this->cpu_ctx, this->neg_one, 1);
58+
*this->neg_one = std::complex<double>(-1.0, 0.0);
2859
}
2960

3061
// Specialized templates to support std:: complex<float>types
3162
template <>
3263
const_nums<std::complex<float>>::const_nums()
33-
: zero(std::complex<float>(0.0, 0.0)), one(std::complex<float>(1.0, 0.0)), neg_one(std::complex<float>(-1.0, 0.0))
3464
{
35-
}
65+
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
66+
this->cpu_ctx, this->zero, 1);
67+
*this->zero = std::complex<float>(0.0, 0.0);
68+
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
69+
this->cpu_ctx, this->one, 1);
70+
*this->one = std::complex<float>(1.0, 0.0);
71+
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
72+
this->cpu_ctx, this->neg_one, 1);
73+
*this->neg_one = std::complex<float>(-1.0, 0.0);
74+
}
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
#ifndef DIAG_CONST_NUMS
22
#define DIAG_CONST_NUMS
3+
#include "module_base/module_device/memory_op.h"
34

45
template <typename T>
56
struct const_nums
67
{
78
const_nums();
8-
T zero;
9-
T one;
10-
T neg_one;
9+
base_device::DEVICE_CPU* cpu_ctx = {};
10+
T* zero;
11+
T* one;
12+
T* neg_one;
1113
};
1214

1315
#endif

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
2525
{
2626
this->device = base_device::get_device_type<Device>(this->ctx);
2727

28-
this->one = &this->cs.one;
29-
this->zero = &this->cs.zero;
30-
this->neg_one = &this->cs.neg_one;
28+
this->one = this->cs.one;
29+
this->zero = this->cs.zero;
30+
this->neg_one = this->cs.neg_one;
3131

3232
assert(david_ndim_in > 1);
3333
assert(david_ndim_in * nband_in < nbasis_in * this->diag_comm.nproc);
@@ -534,8 +534,8 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
534534
}
535535
else
536536
{
537-
std::vector<std::vector<T>> h_diag(nbase, std::vector<T>(nbase, cs.zero));
538-
std::vector<std::vector<T>> s_diag(nbase, std::vector<T>(nbase, cs.zero));
537+
std::vector<std::vector<T>> h_diag(nbase, std::vector<T>(nbase, *cs.zero));
538+
std::vector<std::vector<T>> s_diag(nbase, std::vector<T>(nbase, *cs.zero));
539539

540540
for (size_t i = 0; i < nbase; i++)
541541
{
@@ -564,10 +564,10 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
564564

565565
for (size_t j = nbase; j < this->nbase_x; j++)
566566
{
567-
hcc[i * this->nbase_x + j] = cs.zero;
568-
hcc[j * this->nbase_x + i] = cs.zero;
569-
scc[i * this->nbase_x + j] = cs.zero;
570-
scc[j * this->nbase_x + i] = cs.zero;
567+
hcc[i * this->nbase_x + j] = *cs.zero;
568+
hcc[j * this->nbase_x + i] = *cs.zero;
569+
scc[i * this->nbase_x + j] = *cs.zero;
570+
scc[j * this->nbase_x + i] = *cs.zero;
571571
}
572572
}
573573
}

0 commit comments

Comments
 (0)