Skip to content

Commit b6e34e1

Browse files
authored
Merge pull request #793 from deepmodeling/deepks
Deepks
2 parents 363c2fd + 3b13ffd commit b6e34e1

File tree

151 files changed

+135917
-552
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

151 files changed

+135917
-552
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
env:
2222
GTEST_COLOR: 'yes'
2323
run: |
24-
cmake -B build -DBUILD_TESTING=ON
24+
cmake -B build -DBUILD_TESTING=ON -DENABLE_DEEPKS=ON
2525
cmake --build build -j8
2626
cmake --install build
2727
cmake --build build --target test ARGS="-V"

CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,14 @@ else()
167167
message(WARNING "Cannot find the correct library for Fortran.")
168168
endif()
169169
endif()
170-
target_link_libraries(${ABACUS_BIN_NAME} ${math_libs})
171170

172171
if(ENABLE_DEEPKS)
173172
set(CMAKE_CXX_STANDARD 14)
174173
find_package(Torch REQUIRED)
175174
include_directories(${TORCH_INCLUDE_DIRS})
176-
target_link_libraries(${ABACUS_BIN_NAME} ${TORCH_LIBRARIES})
175+
target_link_libraries(${ABACUS_BIN_NAME} deepks)
176+
list(APPEND math_libs ${TORCH_LIBRARIES})
177+
177178
add_compile_options(${TORCH_CXX_FLAGS})
178179

179180
find_path(libnpy_SOURCE_DIR
@@ -192,6 +193,8 @@ if(ENABLE_DEEPKS)
192193
add_compile_definitions(__DEEPKS)
193194
endif()
194195

196+
target_link_libraries(${ABACUS_BIN_NAME} ${math_libs})
197+
195198
if(DEFINED Libxc_DIR)
196199
set(ENABLE_LIBXC ON)
197200
endif()

doc/input-main.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
[nurse](#nurse) | [t_in_h](#t-in-h) | [vl_in_h](#vl-in-h) | [vnl_in_h](#vnl-in-h) | [test_force](#test-force) | [test_stress](#test-stress) | [colour](#colour) | [new_dm](#new-dm) | [test_just_neighbor](#test-just-neighbor)
6363
- [DeePKS](#deepks)
6464

65-
[out_descriptor](#out-descriptor) | [lmax_descriptor](#lmax-descriptor) | [deepks_scf](#deepks-scf) | [model_file](#model-file)
65+
[deepks_out_labels](#out-descriptor) | [deepks_descriptor_lmax](#lmax-descriptor) | [deepks_scf](#deepks-scf) | [deepks_model](#model-file)
6666

6767
[back to main page](../README.md)
6868

@@ -913,7 +913,7 @@ This part of variables are used to control the calculation of DOS.
913913
This part of variables are used to control the usage of DeePKS method (a comprehensive data-driven approach to improve accuracy of DFT).
914914
Warning: this function is not robust enough for version 2.2.0. Please try these variables in https://github.com/deepmodeling/abacus-develop/tree/deepks .
915915
916-
- out_descriptor<a id="out-descriptor"></a>
916+
- deepks_out_labels<a id="out-descriptor"></a>
917917
- *Type*: Boolean
918918
- *Description*: when set to 1, ABACUS will calculate and output descriptor for DeePKS training. In `LCAO` calculation, a path of *.orb file is needed to be specified under `NUMERICAL_DESCRIPTOR`in `STRU`file. For example:
919919
```
@@ -927,7 +927,7 @@ Warning: this function is not robust enough for version 2.2.0. Please try these
927927
- *Default*: 0
928928
929929
[back to top](#input-file)
930-
- lmax_descriptor<a id="lmax-descriptor"></a>
930+
- deepks_descriptor_lmax<a id="lmax-descriptor"></a>
931931
- *Type*: Integer
932932
- *Description*: control the max angular momentum of descriptor basis.
933933
- *Default*: 0
@@ -939,7 +939,7 @@ Warning: this function is not robust enough for version 2.2.0. Please try these
939939
- *Default*: 0
940940
941941
[back to top](#input-file)
942-
- model_file<a id="model-file"></a>
942+
- deepks_model<a id="model-file"></a>
943943
- *Type*: String
944944
- *Description*: the path of the trained, traced NN model file (generated by deepks-kit). used when deepks_scf is set to 1.
945945
- *Default*: None

examples/H2O-deepks-lcao/INPUT

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ mixing_beta 0.4
2626
#Parameters (6.Deepks)
2727
force 1
2828
test_force 1
29-
out_descriptor 1
30-
lmax_descriptor 2
29+
deepks_out_labels 1
30+
deepks_descriptor_lmax 2
3131
newdm 1
3232
deepks_scf 1
33-
model_file model.ptg
33+
deepks_model model.ptg

examples/H2O-deepks-pw/INPUT

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ mixing_beta 0.4
2525

2626
#Parameters (6.File)
2727
out_band 0
28-
out_descriptor 1
29-
lmax_descriptor 2
28+
deepks_out_labels 1
29+
deepks_descriptor_lmax 2

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ LCAO_diago.o\
159159
LCAO_evolve.o\
160160
LCAO_deepks.o\
161161
LCAO_deepks_fdelta.o\
162+
LCAO_deepks_odelta.o\
162163
LCAO_deepks_io.o\
163164
LCAO_deepks_mpi.o\
164165
LCAO_deepks_pdm.o\

source/input.cpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,11 @@ void Input::Default(void)
240240
out_charge = 0;
241241
out_dm = 0;
242242

243-
out_descriptor = 0; // caoyu added 2020-11-24, mohan added 2021-01-03
244-
lmax_descriptor = 2; // mohan added 2021-01-03
243+
deepks_out_labels = 0; // caoyu added 2020-11-24, mohan added 2021-01-03
244+
deepks_scf = 0;
245+
deepks_bandgap = 0;
246+
deepks_out_unittest = 0;
247+
deepks_descriptor_lmax = 2; // mohan added 2021-01-03
245248

246249
out_potential = 0;
247250
out_wf = 0;
@@ -913,22 +916,30 @@ bool Input::Read(const std::string &fn)
913916
{
914917
read_value(ifs, out_dm);
915918
}
916-
else if (strcmp("out_descriptor", word) == 0) // caoyu added 2020-11-24, mohan modified 2021-01-03
919+
else if (strcmp("deepks_out_labels", word) == 0) // caoyu added 2020-11-24, mohan modified 2021-01-03
917920
{
918-
read_value(ifs, out_descriptor);
921+
read_value(ifs, deepks_out_labels);
919922
}
920-
else if (strcmp("lmax_descriptor", word) == 0)// mohan added 2021-01-03
921-
{
922-
read_value(ifs, lmax_descriptor);
923-
}
924-
else if (strcmp("deepks_scf", word) == 0) // caoyu added 2021-06-02
923+
else if (strcmp("deepks_scf", word) == 0) // caoyu added 2020-11-24, mohan modified 2021-01-03
925924
{
926925
read_value(ifs, deepks_scf);
926+
}
927+
else if (strcmp("deepks_bandgap", word) == 0) // caoyu added 2020-11-24, mohan modified 2021-01-03
928+
{
929+
read_value(ifs, deepks_bandgap);
930+
}
931+
else if (strcmp("deepks_out_unittest", word) == 0)// mohan added 2021-01-03
932+
{
933+
read_value(ifs, deepks_out_unittest);
927934
}
928-
else if (strcmp("model_file", word) == 0) // caoyu added 2021-06-03
935+
else if (strcmp("deepks_model", word) == 0) // caoyu added 2021-06-03
929936
{
930-
read_value(ifs, model_file);
937+
read_value(ifs, deepks_model);
931938
}
939+
else if (strcmp("deepks_descriptor_lmax", word) == 0) // QO added 2021-12-15
940+
{
941+
read_value(ifs, deepks_descriptor_lmax);
942+
}
932943
else if (strcmp("out_potential", word) == 0)
933944
{
934945
read_value(ifs, out_potential);
@@ -1845,16 +1856,19 @@ void Input::Bcast()
18451856
Parallel_Common::bcast_string( charge_extrap );//xiaohui modify 2015-02-01
18461857
Parallel_Common::bcast_int( out_charge );
18471858
Parallel_Common::bcast_int( out_dm );
1848-
Parallel_Common::bcast_int( out_descriptor ); // caoyu added 2020-11-24, mohan modified 2021-01-03
1849-
Parallel_Common::bcast_int( lmax_descriptor ); // mohan modified 2021-01-03
1850-
Parallel_Common::bcast_int( deepks_scf ); // caoyu add 2021-06-02
1851-
Parallel_Common::bcast_string( model_file ); // caoyu add 2021-06-03
1859+
1860+
Parallel_Common::bcast_bool( deepks_out_labels ); // caoyu added 2020-11-24, mohan modified 2021-01-03
1861+
Parallel_Common::bcast_bool( deepks_scf );
1862+
Parallel_Common::bcast_bool( deepks_bandgap );
1863+
Parallel_Common::bcast_bool( deepks_out_unittest );
1864+
Parallel_Common::bcast_string( deepks_model );
1865+
Parallel_Common::bcast_int( deepks_descriptor_lmax );
18521866

18531867
Parallel_Common::bcast_int(out_potential);
18541868
Parallel_Common::bcast_int( out_wf );
18551869
Parallel_Common::bcast_int( out_wf_r );
18561870
Parallel_Common::bcast_int( out_dos );
1857-
Parallel_Common::bcast_int( out_band );
1871+
Parallel_Common::bcast_int( out_band );
18581872
Parallel_Common::bcast_int( out_hs );
18591873
Parallel_Common::bcast_int( out_hs2 ); //LiuXh add 2019-07-15
18601874
Parallel_Common::bcast_int( out_r_matrix ); // jingan add 2019-8-14

source/input.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,10 +383,18 @@ class Input
383383
//==========================================================
384384
// DeepKS -- added by caoyu and mohan
385385
//==========================================================
386-
int out_descriptor; // (need libnpy) output descritpor for deepks. caoyu added 2020-11-24, mohan modified 2021-01-03
387-
int lmax_descriptor; //lmax used in descriptor, mohan added 2021-01-03
388-
int deepks_scf; //(need libnpy and libtorch) if set 1, a trained model would be needed to cal V_delta and F_delta
389-
string model_file; //needed when deepks_scf=1
386+
bool deepks_out_labels; // (need libnpy) prints energy and force labels and descriptors for training, wenfei 2022-1-12
387+
bool deepks_scf; //(need libnpy and libtorch) if set 1, a trained model would be needed to cal V_delta and F_delta
388+
bool deepks_bandgap; //for bandgap label. QO added 2021-12-15
389+
390+
bool deepks_out_unittest; //if set 1, prints intermediate quantities that shall be used for making unit test
391+
392+
string deepks_model; //needed when deepks_scf=1
393+
394+
//the following 3 are used when generating jle.orb
395+
int deepks_descriptor_lmax; //lmax used in descriptor, mohan added 2021-01-03
396+
double deepks_descriptor_rcut;
397+
double deepks_descriptor_ecut;
390398

391399
//==========================================================
392400
// variables for test only

source/input_conv.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -459,14 +459,25 @@ void Input_Conv::Convert(void)
459459
// mohan add 2021-02-16
460460
berryphase::berry_phase_flag = INPUT.berry_phase;
461461

462-
ModuleBase::timer::tick("Input_Conv", "Convert");
462+
463463
//-----------------------------------------------
464464
// caoyu add for DeePKS
465465
//-----------------------------------------------
466466
#ifdef __DEEPKS
467-
GlobalV::out_descriptor = INPUT.out_descriptor;
468467
GlobalV::deepks_scf = INPUT.deepks_scf;
468+
GlobalV::deepks_bandgap = INPUT.deepks_bandgap; //QO added for bandgap label 2021-12-15
469+
GlobalV::deepks_out_unittest = INPUT.deepks_out_unittest;
470+
GlobalV::deepks_out_labels = INPUT.deepks_out_labels;
471+
if(GlobalV::deepks_out_unittest)
472+
{
473+
GlobalV::deepks_out_labels = 1;
474+
GlobalV::deepks_scf = 1;
475+
if (GlobalV::NPROC>1) ModuleBase::WARNING_QUIT("Input_conv","generate deepks unittest with only 1 processor");
476+
if (GlobalV::FORCE!=1) ModuleBase::WARNING_QUIT("Input_conv","force is required in generating deepks unittest");
477+
if (GlobalV::STRESS!=1) ModuleBase::WARNING_QUIT("Input_conv","stress is required in generating deepks unittest");
478+
}
479+
if(GlobalV::deepks_scf || GlobalV::deepks_out_labels) GlobalV::deepks_setorb = 1;
469480
#endif
470-
471-
return;
481+
ModuleBase::timer::tick("Input_Conv","Convert");
482+
return;
472483
}

source/module_base/global_variable.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,12 @@ double soc_lambda = 1.0;
179179

180180
bool FINAL_SCF = false; //LiuXh add 20180619
181181

182-
bool out_descriptor = false; //caoyu add 2021-10-16 for DeePKS
183-
bool deepks_scf = false; //caoyu add 2021-10-16 for DeePKS
182+
bool deepks_out_labels = false; //caoyu add 2021-10-16 for DeePKS, wenfei 2022-1-16
183+
bool deepks_scf = false; //caoyu add 2021-10-16 for DeePKS, wenfei 2022-1-16
184+
bool deepks_bandgap = false; //for bandgap label. QO added 2021-12-15
185+
bool deepks_out_unittest = false;
186+
187+
bool deepks_setorb = false;
184188

185189
int vnl_method = 1; //set defauld vnl method as old, added by zhengdy 2021-10-11
186190

0 commit comments

Comments
 (0)