Skip to content

Commit 95c4d9a

Browse files
authored
Fix: Name conflict in pyabacus. (#6466)
* Fix pyabacus conflict caused by using same name for member variable and member function. * Add support for reading NUMERICAL_DESCRIPTOR from STRU in pyabacus. * Add value check for TwoCenterIntegrator. * Update pyabacus test.
1 parent 9ab7b09 commit 95c4d9a

File tree

5 files changed

+30
-17
lines changed

5 files changed

+30
-17
lines changed

python/pyabacus/src/ModuleNAO/py_m_nao.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ void bind_m_nao(py::module& m)
8282
.def("symbol", &RadialCollection::symbol, "itype"_a)
8383
.def_property_readonly("ntype", &RadialCollection::ntype)
8484
.def("lmax", overload_cast_<const int>()(&RadialCollection::lmax, py::const_), "itype"_a)
85-
.def_property_readonly("lmax", overload_cast_<>()(&RadialCollection::lmax, py::const_))
85+
.def("lmax", overload_cast_<>()(&RadialCollection::lmax, py::const_))
8686
.def("rcut_max", overload_cast_<const int>()(&RadialCollection::rcut_max, py::const_), "itype"_a)
87-
.def_property_readonly("rcut_max", overload_cast_<>()(&RadialCollection::rcut_max, py::const_))
87+
.def("rcut_max", overload_cast_<>()(&RadialCollection::rcut_max, py::const_))
8888
.def("nzeta", &RadialCollection::nzeta, "itype"_a, "l"_a)
8989
.def("nzeta_max", overload_cast_<const int>()(&RadialCollection::nzeta_max, py::const_), "itype"_a)
9090
.def("nzeta_max", overload_cast_<>()(&RadialCollection::nzeta_max, py::const_))

python/pyabacus/src/pyabacus/ModuleNAO/_module_nao.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,21 @@ def symbol(self, itype: int) -> str:
8383
def ntype(self) -> int:
8484
return super().ntype
8585

86-
def lmax(self, itype: int) -> int:
87-
return super().lmax(itype)
88-
89-
@property
90-
def lmax(self) -> int:
91-
return super().lmax
92-
93-
def rcut_max(self, itype: int) -> float:
94-
return super().rcut_max(itype)
86+
@overload
87+
def lmax(self) -> int: ...
88+
@overload
89+
def lmax(self, itype: int) -> int: ...
9590

96-
@property
97-
def rcut_max(self) -> float:
98-
return super().rcut_max
91+
def lmax(self, *args, **kwargs):
92+
return super().lmax(*args, **kwargs)
93+
94+
@overload
95+
def rcut_max(self) -> float: ...
96+
@overload
97+
def rcut_max(self, itype: int) -> float: ...
98+
99+
def rcut_max(self, *args, **kwargs):
100+
return super().rcut_max(*args, **kwargs)
99101

100102
def nzeta(self, itype: int, l: int) -> int:
101103
return super().nzeta(itype, l)

python/pyabacus/src/pyabacus/io/stru.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def read_stru(fpath: str) -> Dict[str, Any]:
107107
"""Read an ABACUS STRU file and return its content as a dictionary."""
108108
block_title = ['ATOMIC_SPECIES',
109109
'NUMERICAL_ORBITAL',
110+
'NUMERICAL_DESCRIPTOR',
110111
'LATTICE_CONSTANT',
111112
'LATTICE_PARAMETER',
112113
'LATTICE_VECTORS',
@@ -144,6 +145,10 @@ def _trim(line):
144145
for i, s in enumerate(stru['species']):
145146
s['orb_file'] = blocks['NUMERICAL_ORBITAL'][i].strip()
146147

148+
#============ NUMERICAL_DESCRIPTOR ============
149+
if 'NUMERICAL_DESCRIPTOR' in blocks:
150+
stru['desc'] = blocks['NUMERICAL_DESCRIPTOR'][0].strip()
151+
147152
#============ ATOMIC_POSITIONS ============
148153
stru['coord_type'] = blocks['ATOMIC_POSITIONS'][0]
149154
index = {s['symbol']: i for i, s in enumerate(stru['species'])}

python/pyabacus/tests/test_m_nao.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def test_rc():
5050
assert orb.symbol(3) == 'Fe'
5151

5252
assert orb.ntype == 4
53-
assert orb.lmax == 3
54-
assert orb.rcut_max == 10.0
53+
assert orb.lmax() == 3
54+
assert orb.rcut_max() == 10.0
5555

5656
assert orb.nzeta(0,0) == 2
5757
assert orb.nzeta(1,0) == 2
@@ -85,7 +85,7 @@ def test_twocenterintegrator():
8585
alpha.build(1, [file_list[0]])
8686

8787
dr = 0.01 # R spacing
88-
rmax = max(orb.rcut_max, alpha.rcut_max)
88+
rmax = max(orb.rcut_max(), alpha.rcut_max())
8989
cutoff = 2.0 * rmax
9090
nr = int(rmax / dr) + 1
9191

source/source_basis/module_nao/two_center_integrator.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ void TwoCenterIntegrator::calculate(const int itype1,
4848
return;
4949
}
5050

51+
if (m1 > l1 || m1 < -l1 || m2 > l2 || m2 < -l2)
52+
{
53+
ModuleBase::WARNING("TwoCenterIntegrator", "m should be in range [-l, l].");
54+
return;
55+
}
56+
5157
// unit vector along R
5258
ModuleBase::Vector3<double> uR = (R == 0.0 ? ModuleBase::Vector3<double>(0., 0., 1.) : vR / R);
5359

0 commit comments

Comments
 (0)