Skip to content

Commit 37517e6

Browse files
committed
Combine some checking functions in DeePKS.
1 parent 12acec2 commit 37517e6

File tree

97 files changed

+29538
-6454
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+29538
-6454
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ OBJS_CELL=atom_pseudo.o\
201201

202202
OBJS_DEEPKS=LCAO_deepks.o\
203203
deepks_basic.o\
204+
deepks_check.o\
204205
deepks_descriptor.o\
205206
deepks_force.o\
206207
deepks_fpre.o\

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ if(ENABLE_DEEPKS)
22
list(APPEND objects
33
LCAO_deepks.cpp
44
deepks_basic.cpp
5+
deepks_check.cpp
56
deepks_descriptor.cpp
67
deepks_force.cpp
78
deepks_fpre.cpp

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#ifdef __DEEPKS
55

66
#include "deepks_basic.h"
7+
#include "deepks_check.h"
78
#include "deepks_descriptor.h"
89
#include "deepks_force.h"
910
#include "deepks_fpre.h"

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
175175

176176
if (PARAM.inp.deepks_out_unittest)
177177
{
178-
DeePKS_domain::check_gdmx(gdmx);
179-
DeePKS_domain::check_gvx(gvx, rank);
178+
DeePKS_domain::check_tensor<double>(gdmx, "gdmx.dat", rank);
179+
DeePKS_domain::check_tensor<double>(gvx, "gvx.dat", rank);
180180
}
181181
}
182182
}
@@ -198,8 +198,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
198198

199199
if (PARAM.inp.deepks_out_unittest)
200200
{
201-
DeePKS_domain::check_gdmepsl(gdmepsl);
202-
DeePKS_domain::check_gvepsl(gvepsl, rank);
201+
DeePKS_domain::check_tensor<double>(gdmepsl, "gdmepsl.dat", rank);
202+
DeePKS_domain::check_tensor<double>(gvepsl, "gvepsl.dat", rank);
203203
}
204204
}
205205
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifdef __DEEPKS
2+
3+
#include "deepks_check.h"
4+
5+
template <typename T>
6+
void DeePKS_domain::check_tensor(const torch::Tensor& tensor, const std::string& filename, const int rank)
7+
{
8+
if (rank != 0)
9+
{
10+
return;
11+
}
12+
using T_tensor = typename std::conditional<std::is_same<T, std::complex<double>>::value, c10::complex<double>, T>::type;
13+
14+
std::ofstream ofs(filename.c_str());
15+
ofs << std::setprecision(10);
16+
17+
auto sizes = tensor.sizes();
18+
int ndim = sizes.size();
19+
auto data_ptr = tensor.data_ptr<T_tensor>();
20+
int64_t numel = tensor.numel();
21+
22+
// stride for each dimension
23+
std::vector<int64_t> strides(ndim, 1);
24+
for (int i = ndim - 2; i >= 0; --i) {
25+
strides[i] = strides[i + 1] * sizes[i + 1];
26+
}
27+
28+
for (int64_t idx = 0; idx < numel; ++idx) {
29+
// index to multi-dimensional indices
30+
std::vector<int64_t> indices(ndim);
31+
int64_t tmp = idx;
32+
for (int d = 0; d < ndim; ++d) {
33+
indices[d] = tmp / strides[d];
34+
tmp = tmp % strides[d];
35+
}
36+
37+
T_tensor tmp_val = data_ptr[idx];
38+
T* tmp_ptr = reinterpret_cast<T*>(&tmp_val);
39+
ofs << *tmp_ptr;
40+
41+
// print space or newline
42+
if ( ( (idx+1) % sizes[ndim-1] ) == 0 ) {
43+
ofs << std::endl;
44+
} else {
45+
ofs << " ";
46+
}
47+
}
48+
49+
ofs.close();
50+
}
51+
52+
template void DeePKS_domain::check_tensor<double>(const torch::Tensor& tensor, const std::string& filename, const int rank);
53+
template void DeePKS_domain::check_tensor<std::complex<double>>(const torch::Tensor& tensor, const std::string& filename, const int rank);
54+
55+
#endif
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#ifndef DEEPKS_CHECK_H
2+
#define DEEPKS_CHECK_H
3+
4+
#ifdef __DEEPKS
5+
6+
#include <string>
7+
#include <torch/script.h>
8+
#include <torch/torch.h>
9+
10+
namespace DeePKS_domain
11+
{
12+
//------------------------
13+
// deepks_check.cpp
14+
//------------------------
15+
16+
// This file contains subroutines for checking files
17+
18+
// There are 1 subroutines in this file:
19+
// 1. check_tensor, which is used for tensor data checking
20+
21+
template <typename T>
22+
void check_tensor(const torch::Tensor& tensor, const std::string& filename, const int rank);
23+
24+
} // namespace DeePKS_domain
25+
26+
#endif
27+
#endif

source/module_hamilt_lcao/module_deepks/deepks_fpre.cpp

Lines changed: 0 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -142,53 +142,6 @@ void DeePKS_domain::cal_gdmx(const int lmaxd,
142142
return;
143143
}
144144

145-
void DeePKS_domain::check_gdmx(const torch::Tensor& gdmx)
146-
{
147-
std::stringstream ss;
148-
std::ofstream ofs_x;
149-
std::ofstream ofs_y;
150-
std::ofstream ofs_z;
151-
152-
ofs_x << std::setprecision(10);
153-
ofs_y << std::setprecision(10);
154-
ofs_z << std::setprecision(10);
155-
156-
// size: [3][natom][inlmax][nm][nm]
157-
auto size = gdmx.sizes();
158-
auto accessor = gdmx.accessor<double, 5>();
159-
for (int ia = 0; ia < size[1]; ia++)
160-
{
161-
ss.str("");
162-
ss << "gdmx_" << ia << ".dat";
163-
ofs_x.open(ss.str().c_str());
164-
ss.str("");
165-
ss << "gdmy_" << ia << ".dat";
166-
ofs_y.open(ss.str().c_str());
167-
ss.str("");
168-
ss << "gdmz_" << ia << ".dat";
169-
ofs_z.open(ss.str().c_str());
170-
171-
for (int inl = 0; inl < size[2]; inl++)
172-
{
173-
for (int m1 = 0; m1 < size[3]; m1++)
174-
{
175-
for (int m2 = 0; m2 < size[4]; m2++)
176-
{
177-
ofs_x << accessor[0][ia][inl][m1][m2] << " ";
178-
ofs_y << accessor[1][ia][inl][m1][m2] << " ";
179-
ofs_z << accessor[2][ia][inl][m1][m2] << " ";
180-
}
181-
}
182-
ofs_x << std::endl;
183-
ofs_y << std::endl;
184-
ofs_z << std::endl;
185-
}
186-
ofs_x.close();
187-
ofs_y.close();
188-
ofs_z.close();
189-
}
190-
}
191-
192145
// calculates gradient of descriptors from gradient of projected density matrices
193146
void DeePKS_domain::cal_gvx(const int nat,
194147
const int inlmax,
@@ -243,55 +196,6 @@ void DeePKS_domain::cal_gvx(const int nat,
243196
return;
244197
}
245198

246-
void DeePKS_domain::check_gvx(const torch::Tensor& gvx, const int rank)
247-
{
248-
std::stringstream ss;
249-
std::ofstream ofs_x;
250-
std::ofstream ofs_y;
251-
std::ofstream ofs_z;
252-
253-
if (rank != 0)
254-
{
255-
return;
256-
}
257-
258-
auto size = gvx.sizes();
259-
auto accessor = gvx.accessor<double, 4>();
260-
261-
for (int ia = 0; ia < size[0]; ia++)
262-
{
263-
ss.str("");
264-
ss << "gvx_" << ia << ".dat";
265-
ofs_x.open(ss.str().c_str());
266-
ss.str("");
267-
ss << "gvy_" << ia << ".dat";
268-
ofs_y.open(ss.str().c_str());
269-
ss.str("");
270-
ss << "gvz_" << ia << ".dat";
271-
ofs_z.open(ss.str().c_str());
272-
273-
ofs_x << std::setprecision(10);
274-
ofs_y << std::setprecision(10);
275-
ofs_z << std::setprecision(10);
276-
277-
for (int ib = 0; ib < size[2]; ib++)
278-
{
279-
for (int nlm = 0; nlm < size[3]; nlm++)
280-
{
281-
ofs_x << accessor[ia][0][ib][nlm] << " ";
282-
ofs_y << accessor[ia][1][ib][nlm] << " ";
283-
ofs_z << accessor[ia][2][ib][nlm] << " ";
284-
}
285-
ofs_x << std::endl;
286-
ofs_y << std::endl;
287-
ofs_z << std::endl;
288-
}
289-
ofs_x.close();
290-
ofs_y.close();
291-
ofs_z.close();
292-
}
293-
}
294-
295199
template void DeePKS_domain::cal_gdmx<double>(const int lmaxd,
296200
const int inlmax,
297201
const int nks,

source/module_hamilt_lcao/module_deepks/deepks_fpre.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,11 @@ namespace DeePKS_domain
2121
// deepks_fpre.cpp
2222
//------------------------
2323

24-
// This file contains 4 subroutines for calculating,
24+
// This file contains 2 subroutines for calculating,
2525
// 1. cal_gdmx, calculating gdmx
26-
// 2. check_gdmx, which prints gdmx to a series of .dat files
27-
// 3. cal_gvx : gvx is used for training with force label, which is gradient of descriptors,
26+
// 2. cal_gvx : gvx is used for training with force label, which is gradient of descriptors,
2827
// calculated by d(des)/dX = d(pdm)/dX * d(des)/d(pdm) = gdmx * gvdm
2928
// using einsum
30-
// 4. check_gvx : prints gvx into gvx.dat for checking
3129

3230
// calculate the gradient of pdm with regard to atomic positions
3331
// d/dX D_{Inl,mm'}
@@ -45,8 +43,6 @@ void cal_gdmx(const int lmaxd,
4543
const Grid_Driver& GridD,
4644
torch::Tensor& gdmx);
4745

48-
void check_gdmx(const torch::Tensor& gdmx);
49-
5046
/// calculates gradient of descriptors w.r.t atomic positions
5147
///----------------------------------------------------
5248
/// m, n: 2*l+1
@@ -64,7 +60,6 @@ void cal_gvx(const int nat,
6460
const torch::Tensor& gdmx,
6561
torch::Tensor& gvx,
6662
const int rank);
67-
void check_gvx(const torch::Tensor& gvx, const int rank);
6863

6964
} // namespace DeePKS_domain
7065
#endif

source/module_hamilt_lcao/module_deepks/deepks_orbpre.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -295,25 +295,6 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector<TH>& dm_hl,
295295
return;
296296
}
297297

298-
void DeePKS_domain::check_orbpre(const torch::Tensor& orbpre)
299-
{
300-
auto sizes = orbpre.sizes();
301-
auto accessor = orbpre.accessor<double, 3>();
302-
std::ofstream ofs("orbital_precalc.dat");
303-
for (int iknb = 0; iknb < sizes[0]; iknb++)
304-
{
305-
for (int iat = 0; iat < sizes[1]; iat++)
306-
{
307-
for (int m = 0; m < sizes[2]; m++)
308-
{
309-
ofs << accessor[iknb][iat][m] << " ";
310-
}
311-
ofs << std::endl;
312-
}
313-
ofs << std::endl;
314-
}
315-
}
316-
317298
template void DeePKS_domain::cal_orbital_precalc<double, ModuleBase::matrix>(
318299
const std::vector<ModuleBase::matrix>& dm_hl,
319300
const int lmaxd,

source/module_hamilt_lcao/module_deepks/deepks_orbpre.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ void cal_orbital_precalc(const std::vector<TH>& dm_hl,
4141
const Parallel_Orbitals& pv,
4242
const Grid_Driver& GridD,
4343
torch::Tensor& orbital_precalc);
44-
45-
void check_orbpre(const torch::Tensor& orbpre);
4644
} // namespace DeePKS_domain
4745
#endif
4846
#endif

0 commit comments

Comments
 (0)