Skip to content

Commit 9e4b889

Browse files
committed
Fix a bug where CopyFrom caused shared data between tensors, using =(assignment operator overload) instead
1 parent e67b42f commit 9e4b889

File tree

5 files changed

+136
-43
lines changed

5 files changed

+136
-43
lines changed

source/module_hamilt_lcao/module_tddft/evolve_elec.cpp

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,29 +62,32 @@ void Evolve_elec::solve_psi(const int& istep,
6262
}
6363
else if (htype == 1)
6464
{
65-
evolve_psi(nband,
66-
nlocal,
67-
&(para_orb),
68-
phm,
69-
psi[0].get_pointer(),
70-
psi_laststep[0].get_pointer(),
71-
Hk_laststep[ik],
72-
Sk_laststep[ik],
73-
&(ekb(ik, 0)),
74-
htype,
75-
propagator);
76-
65+
// const bool use_tensor = true;
7766
const bool use_tensor = false;
78-
if (use_tensor)
67+
if (!use_tensor)
7968
{
80-
std::cout << "Print ekb: " << std::endl;
81-
ekb.print(std::cout);
82-
std::cout << "nband = " << nband << std::endl;
83-
std::cout << "psi->get_nbands() = " << psi->get_nbands() << std::endl;
84-
std::cout << "nlocal = " << nlocal << std::endl;
85-
std::cout << "psi->get_nbasis() = " << psi->get_nbasis() << std::endl;
86-
std::cout << "ekb.nr = " << ekb.nr << std::endl;
87-
std::cout << "ekb.nc = " << ekb.nc << std::endl;
69+
evolve_psi(nband,
70+
nlocal,
71+
&(para_orb),
72+
phm,
73+
psi[0].get_pointer(),
74+
psi_laststep[0].get_pointer(),
75+
Hk_laststep[ik],
76+
Sk_laststep[ik],
77+
&(ekb(ik, 0)),
78+
htype,
79+
propagator);
80+
// std::cout << "Print ekb: " << std::endl;
81+
// ekb.print(std::cout);
82+
}
83+
else
84+
{
85+
// std::cout << "nband = " << nband << std::endl;
86+
// std::cout << "psi->get_nbands() = " << psi->get_nbands() << std::endl;
87+
// std::cout << "nlocal = " << nlocal << std::endl;
88+
// std::cout << "psi->get_nbasis() = " << psi->get_nbasis() << std::endl;
89+
// std::cout << "ekb.nr = " << ekb.nr << std::endl;
90+
// std::cout << "ekb.nc = " << ekb.nc << std::endl;
8891

8992
// Create TensorMap for psi_k, psi_k_laststep, H_laststep, S_laststep, ekb
9093
container::TensorMap psi_k_tensor(psi[0].get_pointer(),
@@ -131,8 +134,20 @@ void Evolve_elec::solve_psi(const int& istep,
131134
// &(ekb(ik, 0)),
132135
// htype,
133136
// propagator);
134-
std::cout << "Print ekb tensor: " << std::endl;
135-
ekb.print(std::cout);
137+
// std::cout << "Print ekb tensor: " << std::endl;
138+
// ekb.print(std::cout);
139+
140+
// std::cout << "Print psi_k (after evolve): " << std::endl;
141+
// for (int i = 0; i < psi->get_nbands(); ++i)
142+
// {
143+
// for (int j = 0; j < psi->get_nbasis(); ++j)
144+
// {
145+
// std::cout << "psi[" << i << "][" << j << "] = " << psi[0](i, j) << std::endl;
146+
// }
147+
// }
148+
149+
// std::cout << "Print psi_k_tensor (after evolve): " << std::endl;
150+
// print_tensor_data<std::complex<double>>(psi_k_tensor, "psi_k");
136151
}
137152
}
138153
else

source/module_hamilt_lcao/module_tddft/evolve_elec.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,75 @@
1717
// k is the index for the points in the first Brillouin zone
1818
//-----------------------------------------------------------
1919

20+
//------------------------ Debugging utility function ------------------------//
21+
22+
// Print the shape of a Tensor
23+
inline void print_tensor_shape(const container::Tensor& tensor, const std::string& name)
24+
{
25+
std::cout << "Shape of " << name << ": [";
26+
for (int i = 0; i < tensor.shape().ndim(); ++i)
27+
{
28+
std::cout << tensor.shape().dim_size(i);
29+
if (i < tensor.shape().ndim() - 1)
30+
{
31+
std::cout << ", ";
32+
}
33+
}
34+
std::cout << "]" << std::endl;
35+
}
36+
37+
// Recursive print function
38+
template <typename T>
39+
inline void print_tensor_data_recursive(const T* data,
40+
const std::vector<int64_t>& shape,
41+
const std::vector<int64_t>& strides,
42+
int dim,
43+
std::vector<int64_t>& indices,
44+
const std::string& name)
45+
{
46+
if (dim == shape.size())
47+
{
48+
// Recursion base case: print data when reaching the innermost dimension
49+
std::cout << name;
50+
for (size_t i = 0; i < indices.size(); ++i)
51+
{
52+
std::cout << "[" << indices[i] << "]";
53+
}
54+
std::cout << " = " << *data << std::endl;
55+
return;
56+
}
57+
// Recursively process the current dimension
58+
for (int64_t i = 0; i < shape[dim]; ++i)
59+
{
60+
indices[dim] = i;
61+
print_tensor_data_recursive(data + i * strides[dim], shape, strides, dim + 1, indices, name);
62+
}
63+
}
64+
65+
// Generic print function
66+
template <typename T>
67+
inline void print_tensor_data(const container::Tensor& tensor, const std::string& name)
68+
{
69+
const std::vector<int64_t>& shape = tensor.shape().dims();
70+
const std::vector<int64_t>& strides = tensor.shape().strides();
71+
const T* data = tensor.data<T>();
72+
std::vector<int64_t> indices(shape.size(), 0);
73+
print_tensor_data_recursive(data, shape, strides, 0, indices, name);
74+
}
75+
76+
// Specialization for std::complex<double>
77+
template <>
78+
inline void print_tensor_data<std::complex<double>>(const container::Tensor& tensor, const std::string& name)
79+
{
80+
const std::vector<int64_t>& shape = tensor.shape().dims();
81+
const std::vector<int64_t>& strides = tensor.shape().strides();
82+
const std::complex<double>* data = tensor.data<std::complex<double>>();
83+
std::vector<int64_t> indices(shape.size(), 0);
84+
print_tensor_data_recursive(data, shape, strides, 0, indices, name);
85+
}
86+
87+
//------------------------ Debugging utility function ------------------------//
88+
2089
namespace module_tddft
2190
{
2291
class Evolve_elec

source/module_hamilt_lcao/module_tddft/evolve_psi.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ void evolve_psi_tensor(const int nband,
128128

129129
#ifdef __MPI
130130

131-
int print_matrix = 1;
131+
int print_matrix = 0;
132132
hamilt::MatrixBlock<std::complex<double>> h_mat, s_mat;
133133
p_hamilt->matrix(h_mat, s_mat);
134134

@@ -223,7 +223,7 @@ void evolve_psi_tensor(const int nband,
223223

224224
#ifdef __MPI
225225

226-
int print_matrix = 1;
226+
int print_matrix = 0;
227227
hamilt::MatrixBlock<std::complex<double>> h_mat, s_mat;
228228
p_hamilt->matrix(h_mat, s_mat);
229229

source/module_hamilt_lcao/module_tddft/norm_psi.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,9 @@ void norm_psi_tensor(const Parallel_Orbitals* pv,
342342
} // loop ipcol
343343
} // loop iprow
344344

345-
// Copy psi_k to tmp1
346-
tmp1.CopyFrom(psi_k);
345+
// Copy psi_k to tmp1 (using deep copy)
346+
// tmp1.CopyFrom(psi_k); // Does not work because this will cause tmp1 and psi_k to share the same data
347+
tmp1 = psi_k; // operator= overload for Tensor class
347348

348349
// Perform matrix multiplication: psi_k = tmp1 * Cij
349350
ScalapackConnector::gemm('N',

source/module_hamilt_lcao/module_tddft/upsi.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,14 @@ void upsi(const Parallel_Orbitals* pv,
4949
double aa, bb;
5050
aa = psi_k[i * pv->ncol + j].real();
5151
bb = psi_k[i * pv->ncol + j].imag();
52-
if (std::abs(aa) < 1e-8) {
52+
if (std::abs(aa) < 1e-8)
53+
{
5354
aa = 0.0;
54-
}
55-
if (std::abs(bb) < 1e-8) {
55+
}
56+
if (std::abs(bb) < 1e-8)
57+
{
5658
bb = 0.0;
57-
}
59+
}
5860
GlobalV::ofs_running << aa << "+" << bb << "i ";
5961
}
6062
GlobalV::ofs_running << std::endl;
@@ -68,12 +70,14 @@ void upsi(const Parallel_Orbitals* pv,
6870
double aa, bb;
6971
aa = psi_k_laststep[i * pv->ncol + j].real();
7072
bb = psi_k_laststep[i * pv->ncol + j].imag();
71-
if (std::abs(aa) < 1e-8) {
73+
if (std::abs(aa) < 1e-8)
74+
{
7275
aa = 0.0;
73-
}
74-
if (std::abs(bb) < 1e-8) {
76+
}
77+
if (std::abs(bb) < 1e-8)
78+
{
7579
bb = 0.0;
76-
}
80+
}
7781
GlobalV::ofs_running << aa << "+" << bb << "i ";
7882
}
7983
GlobalV::ofs_running << std::endl;
@@ -122,12 +126,14 @@ void upsi_tensor(const Parallel_Orbitals* pv,
122126
double aa, bb;
123127
aa = psi_k.data<std::complex<double>>()[i * pv->ncol + j].real();
124128
bb = psi_k.data<std::complex<double>>()[i * pv->ncol + j].imag();
125-
if (std::abs(aa) < 1e-8) {
129+
if (std::abs(aa) < 1e-8)
130+
{
126131
aa = 0.0;
127-
}
128-
if (std::abs(bb) < 1e-8) {
132+
}
133+
if (std::abs(bb) < 1e-8)
134+
{
129135
bb = 0.0;
130-
}
136+
}
131137
GlobalV::ofs_running << aa << "+" << bb << "i ";
132138
}
133139
GlobalV::ofs_running << std::endl;
@@ -141,12 +147,14 @@ void upsi_tensor(const Parallel_Orbitals* pv,
141147
double aa, bb;
142148
aa = psi_k_laststep.data<std::complex<double>>()[i * pv->ncol + j].real();
143149
bb = psi_k_laststep.data<std::complex<double>>()[i * pv->ncol + j].imag();
144-
if (std::abs(aa) < 1e-8) {
150+
if (std::abs(aa) < 1e-8)
151+
{
145152
aa = 0.0;
146-
}
147-
if (std::abs(bb) < 1e-8) {
153+
}
154+
if (std::abs(bb) < 1e-8)
155+
{
148156
bb = 0.0;
149-
}
157+
}
150158
GlobalV::ofs_running << aa << "+" << bb << "i ";
151159
}
152160
GlobalV::ofs_running << std::endl;

0 commit comments

Comments
 (0)