Skip to content

Commit 545d2ec

Browse files
Refactor: Use memory_op to set diag_const_nums (#5246)
* Use memory_op to set consts * Fix segfault * Pyabacus test * Revert changes * Modify pointer usage * I will win * Malloc test * Initialize with nullptr * Remove useless code --------- Co-authored-by: Haozhi Han <[email protected]>
1 parent 08ea40c commit 545d2ec

File tree

3 files changed

+53
-18
lines changed

3 files changed

+53
-18
lines changed

source/module_hsolver/diag_const_nums.cpp

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,60 @@ template class const_nums<std::complex<float>>;
99

1010
// Specialize templates to support double types
1111
template <>
12-
const_nums<double>::const_nums() : zero(0.0), one(1.0), neg_one(-1.0)
12+
const_nums<double>::const_nums()
1313
{
14+
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
15+
this->cpu_ctx, this->zero, 1);
16+
this->zero[0] = 0.0;
17+
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
18+
this->cpu_ctx, this->one, 1);
19+
this->one[0] = 1.0;
20+
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
21+
this->cpu_ctx, this->neg_one, 1);
22+
this->neg_one[0] = -1.0;
1423
}
1524

1625
// Specialize templates to support double types
1726
template <>
18-
const_nums<float>::const_nums() : zero(0.0), one(1.0), neg_one(-1.0)
27+
const_nums<float>::const_nums()
1928
{
29+
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
30+
this->cpu_ctx, this->zero, 1);
31+
this->zero[0] = 0.0;
32+
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
33+
this->cpu_ctx, this->one, 1);
34+
this->one[0] = 1.0;
35+
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
36+
this->cpu_ctx, this->neg_one, 1);
37+
this->neg_one[0] = -1.0;
2038
}
2139

2240
// Specialized templates to support std:: complex<double>types
2341
template <>
2442
const_nums<std::complex<double>>::const_nums()
25-
: zero(std::complex<double>(0.0, 0.0)), one(std::complex<double>(1.0, 0.0)),
26-
neg_one(std::complex<double>(-1.0, 0.0))
2743
{
44+
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
45+
this->cpu_ctx, this->zero, 1);
46+
this->zero[0] = std::complex<double>(0.0, 0.0);
47+
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
48+
this->cpu_ctx, this->one, 1);
49+
this->one[0] = std::complex<double>(1.0, 0.0);
50+
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
51+
this->cpu_ctx, this->neg_one, 1);
52+
this->neg_one[0] = std::complex<double>(-1.0, 0.0);
2853
}
2954

3055
// Specialized templates to support std:: complex<float>types
3156
template <>
3257
const_nums<std::complex<float>>::const_nums()
33-
: zero(std::complex<float>(0.0, 0.0)), one(std::complex<float>(1.0, 0.0)), neg_one(std::complex<float>(-1.0, 0.0))
3458
{
35-
}
59+
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
60+
this->cpu_ctx, this->zero, 1);
61+
this->zero[0] = std::complex<float>(0.0, 0.0);
62+
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
63+
this->cpu_ctx, this->one, 1);
64+
this->one[0] = std::complex<float>(1.0, 0.0);
65+
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
66+
this->cpu_ctx, this->neg_one, 1);
67+
this->neg_one[0] = std::complex<float>(-1.0, 0.0);
68+
}
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
#ifndef DIAG_CONST_NUMS
22
#define DIAG_CONST_NUMS
3+
#include "module_base/module_device/memory_op.h"
34

45
template <typename T>
56
struct const_nums
67
{
78
const_nums();
8-
T zero;
9-
T one;
10-
T neg_one;
9+
base_device::DEVICE_CPU* cpu_ctx = {};
10+
T* zero = nullptr;
11+
T* one = nullptr;
12+
T* neg_one = nullptr;
1113
};
1214

1315
#endif

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
2525
{
2626
this->device = base_device::get_device_type<Device>(this->ctx);
2727

28-
this->one = &this->cs.one;
29-
this->zero = &this->cs.zero;
30-
this->neg_one = &this->cs.neg_one;
28+
this->one = this->cs.one;
29+
this->zero = this->cs.zero;
30+
this->neg_one = this->cs.neg_one;
3131

3232
assert(david_ndim_in > 1);
3333
assert(david_ndim_in * nband_in < nbasis_in * this->diag_comm.nproc);
@@ -534,8 +534,8 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
534534
}
535535
else
536536
{
537-
std::vector<std::vector<T>> h_diag(nbase, std::vector<T>(nbase, cs.zero));
538-
std::vector<std::vector<T>> s_diag(nbase, std::vector<T>(nbase, cs.zero));
537+
std::vector<std::vector<T>> h_diag(nbase, std::vector<T>(nbase, cs.zero[0]));
538+
std::vector<std::vector<T>> s_diag(nbase, std::vector<T>(nbase, cs.zero[0]));
539539

540540
for (size_t i = 0; i < nbase; i++)
541541
{
@@ -564,10 +564,10 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
564564

565565
for (size_t j = nbase; j < this->nbase_x; j++)
566566
{
567-
hcc[i * this->nbase_x + j] = cs.zero;
568-
hcc[j * this->nbase_x + i] = cs.zero;
569-
scc[i * this->nbase_x + j] = cs.zero;
570-
scc[j * this->nbase_x + i] = cs.zero;
567+
hcc[i * this->nbase_x + j] = cs.zero[0];
568+
hcc[j * this->nbase_x + i] = cs.zero[0];
569+
scc[i * this->nbase_x + j] = cs.zero[0];
570+
scc[j * this->nbase_x + i] = cs.zero[0];
571571
}
572572
}
573573
}

0 commit comments

Comments
 (0)