Skip to content

Commit 4848cfd

Browse files
sunliang98Fisherd99
authored andcommitted
Refactor: Remove cal_nelec and cal_nbands from unitcell.cpp. (deepmodeling#5694)
* Refactor: Move if (PARAM.inp.nspin == 4) into cal_ux(). * Refactor: Remove cal_nelec and cal_nbands from ucell.cpp. * Test: Update Unit Tests. * Fix: Fix the build error with PAW.
1 parent 0636fa6 commit 4848cfd

File tree

14 files changed

+222
-197
lines changed

14 files changed

+222
-197
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ OBJS_ELECSTAT=elecstate.o\
229229
H_TDDFT_pw.o\
230230
pot_xc.o\
231231
cal_ux.o\
232+
cal_nelec_nband.o\
232233
read_pseudo.o\
233234

234235
OBJS_ELECSTAT_LCAO=elecstate_lcao.o\

source/module_cell/cal_atoms_info.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef CAL_ATOMS_INFO_H
22
#define CAL_ATOMS_INFO_H
33
#include "module_parameter/parameter.h"
4-
#include "unitcell.h"
4+
#include "module_elecstate/cal_nelec_nband.h"
55
class CalAtomsInfo
66
{
77
public:
@@ -58,7 +58,7 @@ class CalAtomsInfo
5858
}
5959

6060
// calculate the total number of electrons
61-
cal_nelec(atoms, ntype, para.input.nelec);
61+
elecstate::cal_nelec(atoms, ntype, para.input.nelec);
6262

6363
// autoset and check GlobalV::NBANDS
6464
std::vector<double> nelec_spin(2, 0.0);
@@ -67,7 +67,7 @@ class CalAtomsInfo
6767
nelec_spin[0] = (para.inp.nelec + para.inp.nupdown ) / 2.0;
6868
nelec_spin[1] = (para.inp.nelec - para.inp.nupdown ) / 2.0;
6969
}
70-
cal_nbands(para.inp.nelec, para.sys.nlocal, nelec_spin, para.input.nbands);
70+
elecstate::cal_nbands(para.inp.nelec, para.sys.nlocal, nelec_spin, para.input.nbands);
7171
return;
7272
}
7373
};

source/module_cell/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ list(APPEND cell_simple_srcs
2626
../read_pp_blps.cpp
2727
../check_atomic_stru.cpp
2828
../../module_elecstate/read_pseudo.cpp
29+
../../module_elecstate/cal_nelec_nband.cpp
2930
)
3031

3132
add_library(cell_info OBJECT ${cell_simple_srcs})

source/module_cell/test/unitcell_test_readpp.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ Magnetism::~Magnetism() { delete[] this->start_magnetization; }
9090
* possible of an element
9191
* - CalNelec: UnitCell::cal_nelec
9292
* - calculate the total number of valence electrons from psp files
93-
* - CalNbands: elecstate::ElecState::cal_nbands()
93+
* - CalNbands: elecstate::cal_nbands()
9494
* - calculate the number of bands
9595
*/
9696

@@ -406,22 +406,22 @@ TEST_F(UcellTest, CalNelec) {
406406
EXPECT_EQ(1, ucell->atoms[0].na);
407407
EXPECT_EQ(2, ucell->atoms[1].na);
408408
double nelec = 0;
409-
cal_nelec(ucell->atoms, ucell->ntype, nelec);
409+
elecstate::cal_nelec(ucell->atoms, ucell->ntype, nelec);
410410
EXPECT_DOUBLE_EQ(6, nelec);
411411
}
412412

413413
TEST_F(UcellTest, CalNbands)
414414
{
415415
std::vector<double> nelec_spin(2, 5.0);
416-
cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
416+
elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
417417
EXPECT_EQ(PARAM.input.nbands, 6);
418418
}
419419

420420
TEST_F(UcellTest, CalNbandsFractionElec)
421421
{
422422
PARAM.input.nelec = 9.5;
423423
std::vector<double> nelec_spin(2, 5.0);
424-
cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
424+
elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
425425
EXPECT_EQ(PARAM.input.nbands, 6);
426426
}
427427

@@ -430,22 +430,22 @@ TEST_F(UcellTest, CalNbandsSOC)
430430
PARAM.input.lspinorb = true;
431431
PARAM.input.nbands = 0;
432432
std::vector<double> nelec_spin(2, 5.0);
433-
cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
433+
elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
434434
EXPECT_EQ(PARAM.input.nbands, 20);
435435
}
436436

437437
TEST_F(UcellTest, CalNbandsSDFT)
438438
{
439439
PARAM.input.esolver_type = "sdft";
440440
std::vector<double> nelec_spin(2, 5.0);
441-
EXPECT_NO_THROW(cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands));
441+
EXPECT_NO_THROW(elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands));
442442
}
443443

444444
TEST_F(UcellTest, CalNbandsLCAO)
445445
{
446446
PARAM.input.basis_type = "lcao";
447447
std::vector<double> nelec_spin(2, 5.0);
448-
EXPECT_NO_THROW(cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands));
448+
EXPECT_NO_THROW(elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands));
449449
}
450450

451451
TEST_F(UcellTest, CalNbandsLCAOINPW)
@@ -454,7 +454,7 @@ TEST_F(UcellTest, CalNbandsLCAOINPW)
454454
PARAM.sys.nlocal = PARAM.input.nbands - 1;
455455
std::vector<double> nelec_spin(2, 5.0);
456456
testing::internal::CaptureStdout();
457-
EXPECT_EXIT(cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
457+
EXPECT_EXIT(elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
458458
output = testing::internal::GetCapturedStdout();
459459
EXPECT_THAT(output, testing::HasSubstr("NLOCAL < NBANDS"));
460460
}
@@ -464,7 +464,7 @@ TEST_F(UcellTest, CalNbandsWarning1)
464464
PARAM.input.nbands = PARAM.input.nelec / 2 - 1;
465465
std::vector<double> nelec_spin(2, 5.0);
466466
testing::internal::CaptureStdout();
467-
EXPECT_EXIT(cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
467+
EXPECT_EXIT(elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
468468
output = testing::internal::GetCapturedStdout();
469469
EXPECT_THAT(output, testing::HasSubstr("Too few bands!"));
470470
}
@@ -477,7 +477,7 @@ TEST_F(UcellTest, CalNbandsWarning2)
477477
nelec_spin[0] = (PARAM.input.nelec + PARAM.input.nupdown ) / 2.0;
478478
nelec_spin[1] = (PARAM.input.nelec - PARAM.input.nupdown ) / 2.0;
479479
testing::internal::CaptureStdout();
480-
EXPECT_EXIT(cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
480+
EXPECT_EXIT(elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
481481
output = testing::internal::GetCapturedStdout();
482482
EXPECT_THAT(output, testing::HasSubstr("Too few spin up bands!"));
483483
}
@@ -490,7 +490,7 @@ TEST_F(UcellTest, CalNbandsWarning3)
490490
nelec_spin[0] = (PARAM.input.nelec + PARAM.input.nupdown ) / 2.0;
491491
nelec_spin[1] = (PARAM.input.nelec - PARAM.input.nupdown ) / 2.0;
492492
testing::internal::CaptureStdout();
493-
EXPECT_EXIT(cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
493+
EXPECT_EXIT(elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
494494
output = testing::internal::GetCapturedStdout();
495495
EXPECT_THAT(output, testing::HasSubstr("Too few spin down bands!"));
496496
}
@@ -500,7 +500,7 @@ TEST_F(UcellTest, CalNbandsSpin1)
500500
PARAM.input.nspin = 1;
501501
PARAM.input.nbands = 0;
502502
std::vector<double> nelec_spin(2, 5.0);
503-
cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
503+
elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
504504
EXPECT_EQ(PARAM.input.nbands, 15);
505505
}
506506

@@ -510,7 +510,7 @@ TEST_F(UcellTest, CalNbandsSpin1LCAO)
510510
PARAM.input.nbands = 0;
511511
PARAM.input.basis_type = "lcao";
512512
std::vector<double> nelec_spin(2, 5.0);
513-
cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
513+
elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
514514
EXPECT_EQ(PARAM.input.nbands, 6);
515515
}
516516

@@ -519,7 +519,7 @@ TEST_F(UcellTest, CalNbandsSpin4)
519519
PARAM.input.nspin = 4;
520520
PARAM.input.nbands = 0;
521521
std::vector<double> nelec_spin(2, 5.0);
522-
cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
522+
elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
523523
EXPECT_EQ(PARAM.input.nbands, 30);
524524
}
525525

@@ -529,7 +529,7 @@ TEST_F(UcellTest, CalNbandsSpin4LCAO)
529529
PARAM.input.nbands = 0;
530530
PARAM.input.basis_type = "lcao";
531531
std::vector<double> nelec_spin(2, 5.0);
532-
cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
532+
elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
533533
EXPECT_EQ(PARAM.input.nbands, 6);
534534
}
535535

@@ -538,7 +538,7 @@ TEST_F(UcellTest, CalNbandsSpin2)
538538
PARAM.input.nspin = 2;
539539
PARAM.input.nbands = 0;
540540
std::vector<double> nelec_spin(2, 5.0);
541-
cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
541+
elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
542542
EXPECT_EQ(PARAM.input.nbands, 16);
543543
}
544544

@@ -548,7 +548,7 @@ TEST_F(UcellTest, CalNbandsSpin2LCAO)
548548
PARAM.input.nbands = 0;
549549
PARAM.input.basis_type = "lcao";
550550
std::vector<double> nelec_spin(2, 5.0);
551-
cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
551+
elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands);
552552
EXPECT_EQ(PARAM.input.nbands, 6);
553553
}
554554

@@ -558,7 +558,7 @@ TEST_F(UcellTest, CalNbandsGaussWarning)
558558
std::vector<double> nelec_spin(2, 5.0);
559559
PARAM.input.smearing_method = "gaussian";
560560
testing::internal::CaptureStdout();
561-
EXPECT_EXIT(cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
561+
EXPECT_EXIT(elecstate::cal_nbands(PARAM.input.nelec, PARAM.sys.nlocal, nelec_spin, PARAM.input.nbands), ::testing::ExitedWithCode(1), "");
562562
output = testing::internal::GetCapturedStdout();
563563
EXPECT_THAT(output, testing::HasSubstr("for smearing, num. of bands > num. of occupied bands"));
564564
}

source/module_cell/test_pw/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ AddTest(
1313
LIBS parameter ${math_libs} base device
1414
SOURCES unitcell_test_pw.cpp ../unitcell.cpp ../read_atoms.cpp ../atom_spec.cpp
1515
../atom_pseudo.cpp ../pseudo.cpp ../read_pp.cpp ../read_pp_complete.cpp ../read_pp_upf201.cpp ../read_pp_upf100.cpp
16-
../read_pp_vwr.cpp ../read_pp_blps.cpp ../../module_io/output.cpp ../../module_elecstate/read_pseudo.cpp
16+
../read_pp_vwr.cpp ../read_pp_blps.cpp ../../module_io/output.cpp ../../module_elecstate/read_pseudo.cpp ../../module_elecstate/cal_nelec_nband.cpp
1717
)
1818

1919
find_program(BASH bash)

source/module_cell/unitcell.cpp

Lines changed: 0 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,160 +1328,6 @@ void UnitCell::remake_cell() {
13281328
}
13291329
}
13301330

1331-
void cal_nelec(const Atom* atoms, const int& ntype, double& nelec)
1332-
{
1333-
ModuleBase::TITLE("UnitCell", "cal_nelec");
1334-
GlobalV::ofs_running << "\n SETUP THE ELECTRONS NUMBER" << std::endl;
1335-
1336-
if (nelec == 0)
1337-
{
1338-
if (PARAM.inp.use_paw)
1339-
{
1340-
#ifdef USE_PAW
1341-
for (int it = 0; it < ntype; it++)
1342-
{
1343-
std::stringstream ss1, ss2;
1344-
ss1 << " electron number of element " << GlobalC::paw_cell.get_zat(it) << std::endl;
1345-
const int nelec_it = GlobalC::paw_cell.get_val(it) * atoms[it].na;
1346-
nelec += nelec_it;
1347-
ss2 << "total electron number of element " << GlobalC::paw_cell.get_zat(it);
1348-
1349-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, ss1.str(), GlobalC::paw_cell.get_val(it));
1350-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, ss2.str(), nelec_it);
1351-
}
1352-
#endif
1353-
}
1354-
else
1355-
{
1356-
for (int it = 0; it < ntype; it++)
1357-
{
1358-
std::stringstream ss1, ss2;
1359-
ss1 << "electron number of element " << atoms[it].label;
1360-
const double nelec_it = atoms[it].ncpp.zv * atoms[it].na;
1361-
nelec += nelec_it;
1362-
ss2 << "total electron number of element " << atoms[it].label;
1363-
1364-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, ss1.str(), atoms[it].ncpp.zv);
1365-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, ss2.str(), nelec_it);
1366-
}
1367-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "AUTOSET number of electrons: ", nelec);
1368-
}
1369-
}
1370-
if (PARAM.inp.nelec_delta != 0)
1371-
{
1372-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,
1373-
"nelec_delta is NOT zero, please make sure you know what you are "
1374-
"doing! nelec_delta: ",
1375-
PARAM.inp.nelec_delta);
1376-
nelec += PARAM.inp.nelec_delta;
1377-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "nelec now: ", nelec);
1378-
}
1379-
return;
1380-
}
1381-
1382-
void cal_nbands(const int& nelec, const int& nlocal, const std::vector<double>& nelec_spin, int& nbands)
1383-
{
1384-
if (PARAM.inp.esolver_type == "sdft") // qianrui 2021-2-20
1385-
{
1386-
return;
1387-
}
1388-
//=======================================
1389-
// calculate number of bands (setup.f90)
1390-
//=======================================
1391-
double occupied_bands = static_cast<double>(nelec / ModuleBase::DEGSPIN);
1392-
if (PARAM.inp.lspinorb == 1) {
1393-
occupied_bands = static_cast<double>(nelec);
1394-
}
1395-
1396-
if ((occupied_bands - std::floor(occupied_bands)) > 0.0)
1397-
{
1398-
occupied_bands = std::floor(occupied_bands) + 1.0; // mohan fix 2012-04-16
1399-
}
1400-
1401-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "occupied bands", occupied_bands);
1402-
1403-
if (nbands == 0)
1404-
{
1405-
if (PARAM.inp.nspin == 1)
1406-
{
1407-
const int nbands1 = static_cast<int>(occupied_bands) + 10;
1408-
const int nbands2 = static_cast<int>(1.2 * occupied_bands) + 1;
1409-
nbands = std::max(nbands1, nbands2);
1410-
if (PARAM.inp.basis_type != "pw") {
1411-
nbands = std::min(nbands, nlocal);
1412-
}
1413-
}
1414-
else if (PARAM.inp.nspin == 4)
1415-
{
1416-
const int nbands3 = nelec + 20;
1417-
const int nbands4 = static_cast<int>(1.2 * nelec) + 1;
1418-
nbands = std::max(nbands3, nbands4);
1419-
if (PARAM.inp.basis_type != "pw") {
1420-
nbands = std::min(nbands, nlocal);
1421-
}
1422-
}
1423-
else if (PARAM.inp.nspin == 2)
1424-
{
1425-
const double max_occ = std::max(nelec_spin[0], nelec_spin[1]);
1426-
const int nbands3 = static_cast<int>(max_occ) + 11;
1427-
const int nbands4 = static_cast<int>(1.2 * max_occ) + 1;
1428-
nbands = std::max(nbands3, nbands4);
1429-
if (PARAM.inp.basis_type != "pw") {
1430-
nbands = std::min(nbands, nlocal);
1431-
}
1432-
}
1433-
ModuleBase::GlobalFunc::AUTO_SET("NBANDS", nbands);
1434-
}
1435-
// else if ( PARAM.inp.calculation=="scf" || PARAM.inp.calculation=="md" || PARAM.inp.calculation=="relax") //pengfei
1436-
// 2014-10-13
1437-
else
1438-
{
1439-
if (nbands < occupied_bands) {
1440-
ModuleBase::WARNING_QUIT("unitcell", "Too few bands!");
1441-
}
1442-
if (PARAM.inp.nspin == 2)
1443-
{
1444-
if (nbands < nelec_spin[0])
1445-
{
1446-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "nelec_up", nelec_spin[0]);
1447-
ModuleBase::WARNING_QUIT("ElecState::cal_nbands", "Too few spin up bands!");
1448-
}
1449-
if (nbands < nelec_spin[1])
1450-
{
1451-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "nelec_down", nelec_spin[1]);
1452-
ModuleBase::WARNING_QUIT("ElecState::cal_nbands", "Too few spin down bands!");
1453-
}
1454-
}
1455-
}
1456-
1457-
// mohan add 2010-09-04
1458-
// std::cout << "nbands(this-> = " <<nbands <<std::endl;
1459-
if (nbands == occupied_bands)
1460-
{
1461-
if (PARAM.inp.smearing_method != "fixed")
1462-
{
1463-
ModuleBase::WARNING_QUIT("ElecState::cal_nbands", "for smearing, num. of bands > num. of occupied bands");
1464-
}
1465-
}
1466-
1467-
// mohan update 2021-02-19
1468-
// mohan add 2011-01-5
1469-
if (PARAM.inp.basis_type == "lcao" || PARAM.inp.basis_type == "lcao_in_pw")
1470-
{
1471-
if (nbands > nlocal)
1472-
{
1473-
ModuleBase::WARNING_QUIT("ElecState::cal_nbandsc", "NLOCAL < NBANDS");
1474-
}
1475-
else
1476-
{
1477-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NLOCAL", nlocal);
1478-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NBANDS", nbands);
1479-
}
1480-
}
1481-
1482-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NBANDS", nbands);
1483-
}
1484-
14851331
void UnitCell::compare_atom_labels(std::string label1, std::string label2) {
14861332
if (label1
14871333
!= label2) //'!( "Ag" == "Ag" || "47" == "47" || "Silver" == Silver" )'

source/module_cell/unitcell.h

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -329,23 +329,4 @@ class UnitCell {
329329
std::vector<ModuleBase::Vector3<int>> get_constrain() const;
330330
};
331331

332-
/**
333-
* @brief calculate the total number of electrons in system
334-
*
335-
* @param atoms [in] atom pointer
336-
* @param ntype [in] number of atom types
337-
* @param nelec [out] total number of electrons
338-
*/
339-
void cal_nelec(const Atom* atoms, const int& ntype, double& nelec);
340-
341-
/**
342-
* @brief Calculate the number of bands.
343-
*
344-
* @param nelec [in] total number of electrons
345-
* @param nlocal [in] total number of local basis
346-
* @param nelec_spin [in] number of electrons for each spin
347-
* @param nbands [out] number of bands
348-
*/
349-
void cal_nbands(const int& nelec, const int& nlocal, const std::vector<double>& nelec_spin, int& nbands);
350-
351332
#endif // unitcell class

0 commit comments

Comments
 (0)