Skip to content

Commit 799b019

Browse files
Fix GPU, but so ugly...
2 parents 1f9c6eb + bf4ce74 commit 799b019

File tree

24 files changed

+216
-166
lines changed

24 files changed

+216
-166
lines changed

.github/workflows/test.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,17 @@ jobs:
1717
volumes:
1818
- /tmp/ccache:/github/home/.ccache
1919
steps:
20-
- name: Checkout
20+
- name: Checkout repository
2121
uses: actions/checkout@v5
2222
with:
23-
submodules: recursive
2423
fetch-depth: 0
24+
# We will handle submodules manually after fixing ownership
25+
submodules: 'false'
26+
27+
- name: Take ownership of the workspace and update submodules
28+
run: |
29+
sudo chown -R $(whoami) .
30+
git submodule update --init --recursive
2531
2632
- name: Install CI tools
2733
run: |

docs/advanced/input_files/input-main.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,15 +1487,20 @@ These variables are used to control the geometry relaxation.
14871487

14881488
### relax_method
14891489

1490-
- **Type**: String
1490+
- **Type**: Vector of strings
14911491
- **Description**: The methods to do geometry optimization.
1492+
The first element:
14921493
- cg: using the conjugate gradient (CG) algorithm. Note that there are two implementations of the conjugate gradient (CG) method, see [relax_new](#relax_new).
1493-
- bfgs: using the Broyden–Fletcher–Goldfarb–Shanno (BFGS) algorithm.
1494-
- bfgs_trad: using the traditional Broyden–Fletcher–Goldfarb–Shanno (BFGS) algorithm.
1494+
- bfgs: using the Broyden–Fletcher–Goldfarb–Shanno (BFGS) algorithm.
1495+
- lbfgs: using the Limited-memory Broyden–Fletcher–Goldfarb–Shanno (LBFGS) algorithm.
14951496
- cg_bfgs: using the CG method for the initial steps, and switching to BFGS method when the force convergence is smaller than [relax_cg_thr](#relax_cg_thr).
14961497
- sd: using the steepest descent (SD) algorithm.
14971498
- fire: the Fast Inertial Relaxation Engine method (FIRE), a kind of molecular-dynamics-based relaxation algorithm, is implemented in the molecular dynamics (MD) module. The algorithm can be used by setting [calculation](#calculation) to `md` and [md_type](#md_type) to `fire`. Also ionic velocities should be set in this case. See [fire](../md.md#fire) for more details.
1498-
- **Default**: cg
1499+
1500+
The second element:
1501+
When the first element is bfgs, the second element selects the BFGS implementation: if it is 1, the new BFGS algorithm is used; if it is not 1, the old BFGS algorithm is used. If the second element is omitted, it is set to 1.
1502+
- **Default**: cg 1
1503+
- **Note**: In the 3.10-LTS version, the type of this parameter is std::string. It can be set to "cg", "bfgs", "cg_bfgs", "bfgs_trad", "lbfgs", "sd", or "fire".
14991504

15001505
### relax_new
15011506

source/source_io/module_parameter/input_parameter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <string>
77
#include <vector>
88

9+
910
// It stores all input parameters both defined in INPUT file and not defined in
1011
// INPUT file
1112
struct Input_para
@@ -150,7 +151,7 @@ struct Input_para
150151
// int bessel_nao_lmax; ///< lmax used in descriptor
151152

152153
// ============== #Parameters (4.Relaxation) ===========================
153-
std::string relax_method = "cg"; ///< methods to move_ion: sd, bfgs, cg...
154+
std::vector<std::string> relax_method = {"cg","1"}; ///< methods to move_ion: sd, bfgs, cg...
154155
bool relax_new = true;
155156
bool relax = false; ///< allow relaxation along the specific direction
156157
double relax_scale_force = 0.5;

source/source_io/read_input_item_relax.cpp

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,53 @@
55

66
namespace ModuleIO
77
{
8+
9+
810
void ReadInput::item_relax()
911
{
1012
{
1113
Input_Item item("relax_method");
1214
item.annotation = "cg; bfgs; sd; cg; cg_bfgs;";
13-
read_sync_string(input.relax_method);
14-
item.check_value = [](const Input_Item& item, const Parameter& para) {
15-
const std::vector<std::string> relax_methods = {"cg", "bfgs", "sd", "cg_bfgs","bfgs_trad","lbfgs"};
16-
if (std::find(relax_methods.begin(),relax_methods.end(), para.input.relax_method)==relax_methods.end())
17-
{
18-
const std::string warningstr = nofound_str(relax_methods, "relax_method");
19-
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
20-
}
15+
item.read_value = [](const Input_Item& item, Parameter& para) {
16+
if(item.get_size()==1)
17+
{
18+
para.input.relax_method[0] = item.str_values[0];
19+
para.input.relax_method[1] = "1";
20+
}
21+
else if(item.get_size()>=2)
22+
{
23+
para.input.relax_method[0] = item.str_values[0];
24+
para.input.relax_method[1] = item.str_values[1];
25+
}
26+
};
27+
item.check_value = [](const Input_Item& item, const Parameter& para) {
28+
const std::vector<std::string> relax_methods = {"cg", "sd", "cg_bfgs","lbfgs","bfgs"};
29+
if (std::find(relax_methods.begin(), relax_methods.end(), para.input.relax_method[0]) == relax_methods.end()) {
30+
const std::string warningstr = nofound_str(relax_methods, "relax_method");
31+
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
32+
}
2133
};
2234
this->add_item(item);
35+
36+
// Input_Item item("relax_method");
37+
// item.annotation = "cg; bfgs; sd; cg; cg_bfgs;";
38+
// read_sync_string(input.relax_method);
39+
// item.check_value = [](const Input_Item& item, const Parameter& para) {
40+
// const std::vector<std::string> relax_methods = {"cg", "bfgs_old", "sd", "cg_bfgs","bfgs","lbfgs"};
41+
// if (std::find(relax_methods.begin(),relax_methods.end(), para.input.relax_method)==relax_methods.end())
42+
// {
43+
// const std::string warningstr = nofound_str(relax_methods, "relax_method");
44+
// ModuleBase::WARNING_QUIT("ReadInput", warningstr);
45+
// }
46+
// };
47+
// this->add_item(item);
2348
}
2449
{
2550
Input_Item item("relax_new");
2651
item.annotation = "whether to use the new relaxation method";
2752
read_sync_bool(input.relax_new);
2853
item.reset_value = [](const Input_Item& item, Parameter& para) {
29-
if (para.input.relax_new && para.input.relax_method != "cg")
54+
if (para.input.relax_new && para.input.relax_method[0] != "cg")
3055
{
3156
para.input.relax_new = false;
3257
}

source/source_io/test/read_input_ptest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ TEST_F(InputParaTest, ParaRead)
114114
EXPECT_EQ(param.inp.fixed_axes, "None");
115115
EXPECT_FALSE(param.inp.fixed_ibrav);
116116
EXPECT_FALSE(param.inp.fixed_atoms);
117-
EXPECT_EQ(param.inp.relax_method, "cg");
117+
EXPECT_EQ(param.inp.relax_method[0], "cg");
118118
EXPECT_DOUBLE_EQ(param.inp.relax_cg_thr, 0.5);
119119
EXPECT_EQ(param.inp.out_level, "ie");
120120
EXPECT_TRUE(param.globalv.out_md_control);

source/source_io/test_serial/read_input_item_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ TEST_F(InputTest, Item_test)
766766
}
767767
{ // relax_method
768768
auto it = find_label("relax_method", readinput.input_lists);
769-
param.input.relax_method = "none";
769+
param.input.relax_method[0] = "none";
770770
testing::internal::CaptureStdout();
771771
EXPECT_EXIT(it->second.check_value(it->second, param), ::testing::ExitedWithCode(1), "");
772772
output = testing::internal::GetCapturedStdout();
@@ -775,12 +775,12 @@ TEST_F(InputTest, Item_test)
775775
{ //relax_new
776776
auto it = find_label("relax_new", readinput.input_lists);
777777
param.input.relax_new = true;
778-
param.input.relax_method = "cg";
778+
param.input.relax_method[0] = "cg";
779779
it->second.reset_value(it->second, param);
780780
EXPECT_EQ(param.input.relax_new, true);
781781

782782
param.input.relax_new = true;
783-
param.input.relax_method = "none";
783+
param.input.relax_method[0] = "none";
784784
it->second.reset_value(it->second, param);
785785
EXPECT_EQ(param.input.relax_new, false);
786786
}

source/source_io/write_mlkedf_descriptors.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,53 @@ void Write_MLKEDF_Descriptors::generateTrainData_KS(
6767
delete ptempRho;
6868
}
6969

70+
void Write_MLKEDF_Descriptors::generateTrainData_KS(
71+
const std::string& out_dir,
72+
psi::Psi<std::complex<float>> *psi,
73+
elecstate::ElecState *pelec,
74+
ModulePW::PW_Basis_K *pw_psi,
75+
ModulePW::PW_Basis *pw_rho,
76+
UnitCell& ucell,
77+
const double* veff
78+
)
79+
{
80+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> psi_double(*psi);
81+
82+
this->generateTrainData_KS(out_dir, &psi_double, pelec, pw_psi, pw_rho, ucell, veff);
83+
}
84+
85+
#if ((defined __CUDA) || (defined __ROCM))
86+
void Write_MLKEDF_Descriptors::generateTrainData_KS(
87+
const std::string& out_dir,
88+
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* psi,
89+
elecstate::ElecState *pelec,
90+
ModulePW::PW_Basis_K *pw_psi,
91+
ModulePW::PW_Basis *pw_rho,
92+
UnitCell& ucell,
93+
const double* veff
94+
)
95+
{
96+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> psi_cpu(*psi);
97+
98+
this->generateTrainData_KS(out_dir, &psi_cpu, pelec, pw_psi, pw_rho, ucell, veff);
99+
}
100+
101+
void Write_MLKEDF_Descriptors::generateTrainData_KS(
102+
const std::string& dir,
103+
psi::Psi<std::complex<float>, base_device::DEVICE_GPU>* psi,
104+
elecstate::ElecState *pelec,
105+
ModulePW::PW_Basis_K *pw_psi,
106+
ModulePW::PW_Basis *pw_rho,
107+
UnitCell& ucell,
108+
const double *veff
109+
)
110+
{
111+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> psi_cpu_double(*psi);
112+
113+
this->generateTrainData_KS(dir, &psi_cpu_double, pelec, pw_psi, pw_rho, ucell, veff);
114+
}
115+
#endif
116+
70117
void Write_MLKEDF_Descriptors::generate_descriptor(
71118
const std::string& out_dir,
72119
const double * const *prho,

source/source_io/write_mlkedf_descriptors.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,29 @@ class Write_MLKEDF_Descriptors
4040
ModulePW::PW_Basis *pw_rho,
4141
UnitCell& ucell,
4242
const double *veff
43-
){} // a mock function
43+
);
44+
45+
#if ((defined __CUDA) || (defined __ROCM))
46+
void generateTrainData_KS(
47+
const std::string& dir,
48+
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* psi,
49+
elecstate::ElecState *pelec,
50+
ModulePW::PW_Basis_K *pw_psi,
51+
ModulePW::PW_Basis *pw_rho,
52+
UnitCell& ucell,
53+
const double *veff
54+
);
55+
void generateTrainData_KS(
56+
const std::string& dir,
57+
psi::Psi<std::complex<float>, base_device::DEVICE_GPU>* psi,
58+
elecstate::ElecState *pelec,
59+
ModulePW::PW_Basis_K *pw_psi,
60+
ModulePW::PW_Basis *pw_rho,
61+
UnitCell& ucell,
62+
const double *veff
63+
);
64+
#endif
65+
4466
void generate_descriptor(
4567
const std::string& out_dir,
4668
const double * const *prho,

source/source_psi/psi.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ Psi<T, Device>::Psi(const Psi& psi_in)
171171
this->psi_current = this->psi + psi_in.get_psi_bias();
172172
}
173173

174-
175174
// Constructor 2-2:
176175
template <typename T, typename Device>
177176
template <typename T_in, typename Device_in>
@@ -545,6 +544,8 @@ template Psi<double, base_device::DEVICE_CPU>::Psi(const Psi<double, base_device
545544
template Psi<double, base_device::DEVICE_GPU>::Psi(const Psi<double, base_device::DEVICE_CPU>&);
546545
template Psi<std::complex<double>, base_device::DEVICE_CPU>::Psi(
547546
const Psi<std::complex<double>, base_device::DEVICE_GPU>&);
547+
template Psi<std::complex<double>, base_device::DEVICE_CPU>::Psi(
548+
const Psi<std::complex<float>, base_device::DEVICE_GPU>&);
548549
template Psi<std::complex<double>, base_device::DEVICE_GPU>::Psi(
549550
const Psi<std::complex<double>, base_device::DEVICE_CPU>&);
550551
template Psi<std::complex<float>, base_device::DEVICE_GPU>::Psi(

source/source_relax/ions_move_basic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ double Ions_Move_Basic::relax_bfgs_init = -1.0; // default is 0.5
2323
double Ions_Move_Basic::best_xxx = 1.0;
2424

2525
int Ions_Move_Basic::out_stru = 0;
26-
std::string Ions_Move_Basic::relax_method = "bfgs";
26+
std::vector<std::string> Ions_Move_Basic::relax_method = {"bfgs","2"};
2727

2828
void Ions_Move_Basic::setup_gradient(const UnitCell &ucell, const ModuleBase::matrix &force, double *pos, double *grad)
2929
{

0 commit comments

Comments
 (0)