44
55#ifdef __DEEPKS
66#include " deepks_basic.h"
7+
8+ #include " module_base/timer.h"
79#include " module_parameter/parameter.h"
810
911// d(Descriptor) / d(projected density matrix)
@@ -15,6 +17,7 @@ void DeePKS_domain::cal_gevdm(const int nat,
1517 std::vector<torch::Tensor>& gevdm)
1618{
1719 ModuleBase::TITLE (" DeePKS_domain" , " cal_gevdm" );
20+ ModuleBase::timer::tick (" DeePKS_domain" , " cal_gevdm" );
1821 // cal gevdm(d(EigenValue(D))/dD)
1922 int nlmax = inlmax / nat;
2023 for (int nl = 0 ; nl < nlmax; ++nl)
@@ -48,12 +51,14 @@ void DeePKS_domain::cal_gevdm(const int nat,
4851 gevdm.push_back (avmm);
4952 }
5053 assert (gevdm.size () == nlmax);
54+ ModuleBase::timer::tick (" DeePKS_domain" , " cal_gevdm" );
5155 return ;
5256}
5357
5458void DeePKS_domain::load_model (const std::string& model_file, torch::jit::script::Module& model)
5559{
5660 ModuleBase::TITLE (" DeePKS_domain" , " load_model" );
61+ ModuleBase::timer::tick (" DeePKS_domain" , " load_model" );
5762
5863 try
5964 {
@@ -62,8 +67,10 @@ void DeePKS_domain::load_model(const std::string& model_file, torch::jit::script
6267 catch (const c10::Error& e)
6368 {
6469 std::cerr << " error loading the model" << std::endl;
70+ ModuleBase::timer::tick (" DeePKS_domain" , " load_model" );
6571 return ;
6672 }
73+ ModuleBase::timer::tick (" DeePKS_domain" , " load_model" );
6774 return ;
6875}
6976
@@ -122,16 +129,17 @@ inline void generate_py_files(const int lmaxd, const int nmaxd, const std::strin
122129}
123130
124131void DeePKS_domain::cal_edelta_gedm_equiv (const int nat,
125- const int lmaxd,
126- const int nmaxd,
127- const int inlmax,
128- const int des_per_atom,
129- const int * inl_l,
130- const std::vector<torch::Tensor>& descriptor,
131- double ** gedm,
132- double & E_delta)
132+ const int lmaxd,
133+ const int nmaxd,
134+ const int inlmax,
135+ const int des_per_atom,
136+ const int * inl_l,
137+ const std::vector<torch::Tensor>& descriptor,
138+ double ** gedm,
139+ double & E_delta)
133140{
134141 ModuleBase::TITLE (" DeePKS_domain" , " cal_edelta_gedm_equiv" );
142+ ModuleBase::timer::tick (" DeePKS_domain" , " cal_edelta_gedm_equiv" );
135143
136144 LCAO_deepks_io::save_npy_d (nat,
137145 des_per_atom,
@@ -157,28 +165,32 @@ void DeePKS_domain::cal_edelta_gedm_equiv(const int nat,
157165
158166 std::string cmd = " rm -f cal_edelta_gedm.py basis.yaml ec.npy gedm.npy" ;
159167 std::system (cmd.c_str ());
168+
169+ ModuleBase::timer::tick (" DeePKS_domain" , " cal_edelta_gedm_equiv" );
170+ return ;
160171}
161172
162173// obtain from the machine learning model dE_delta/dDescriptor
163174// E_delta is also calculated here
164175void DeePKS_domain::cal_edelta_gedm (const int nat,
165- const int lmaxd,
166- const int nmaxd,
167- const int inlmax,
168- const int des_per_atom,
169- const int * inl_l,
170- const std::vector<torch::Tensor>& descriptor,
171- const std::vector<torch::Tensor>& pdm,
172- torch::jit::script::Module& model_deepks,
173- double ** gedm,
174- double & E_delta)
176+ const int lmaxd,
177+ const int nmaxd,
178+ const int inlmax,
179+ const int des_per_atom,
180+ const int * inl_l,
181+ const std::vector<torch::Tensor>& descriptor,
182+ const std::vector<torch::Tensor>& pdm,
183+ torch::jit::script::Module& model_deepks,
184+ double ** gedm,
185+ double & E_delta)
175186{
176187 if (PARAM.inp .deepks_equiv )
177188 {
178189 DeePKS_domain::cal_edelta_gedm_equiv (nat, lmaxd, nmaxd, inlmax, des_per_atom, inl_l, descriptor, gedm, E_delta);
179190 return ;
180191 }
181192 ModuleBase::TITLE (" DeePKS_domain" , " cal_edelta_gedm" );
193+ ModuleBase::timer::tick (" DeePKS_domain" , " cal_edelta_gedm" );
182194
183195 // forward
184196 std::vector<torch::jit::IValue> inputs;
@@ -213,6 +225,7 @@ void DeePKS_domain::cal_edelta_gedm(const int nat,
213225 }
214226 }
215227 }
228+ ModuleBase::timer::tick (" DeePKS_domain" , " cal_edelta_gedm" );
216229 return ;
217230}
218231
0 commit comments