Skip to content

Commit b2b0303

Browse files
committed
add cell for pyabacus
1 parent 5329628 commit b2b0303

File tree

9 files changed

+368
-3
lines changed

9 files changed

+368
-3
lines changed

python/pyabacus/pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ wheel.expand-macos-universal-tags = true
3333
cmake.verbose = true
3434
logging.level = "INFO"
3535

36+
[tool.scikit-build.cmake.define]
37+
CMAKE_INSTALL_RPATH = "$ORIGIN"
38+
39+
[tool.setuptools]
40+
package-dir = {"pyabacus" = "src/pyabacus"}
41+
42+
[tool.setuptools.packages.find]
43+
where = ["src"]
44+
include = ["pyabacus*"]
3645

3746
[tool.pytest.ini_options]
3847
minversion = "6.0"

python/pyabacus/src/pyabacus/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

3-
__submodules__ = ["ModuleBase", "ModuleNAO", "hsolver"]
4-
3+
__submodules__ = ["ModuleBase", "ModuleNAO", "hsolver", "Cell", "IntegralCalculator"]
54
__all__ = list(__submodules__)
65

76
def __getattr__(attr):
@@ -13,4 +12,9 @@ def __getattr__(attr):
1312
return ModuleNAO
1413
elif attr == "hsolver":
1514
import pyabacus.hsolver as hsolver
16-
return hsolver
15+
return hsolver
16+
elif attr == "Cell":
17+
from .cell import Cell
18+
return Cell
19+
else:
20+
raise AttributeError(f"module {__name__} has no attribute {attr}")
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import numpy as np
2+
import os
3+
4+
class Cell:
5+
def __init__(self):
6+
self.atom = None
7+
self.a = None # Lattice vectors
8+
self.unit = 'Angstrom' # Default unit
9+
self.spin = 0 # Default spin
10+
self.charge = 0 # Default charge
11+
self.lattice_constant = 6.1416 #
12+
self.basis = None
13+
self.pseudo = None
14+
self.orbitals = []
15+
self.pseudo_potentials = {}
16+
self.pseudo_dir = ''
17+
self.orbital_dir = ''
18+
self.basis_type = ''
19+
self.built = False
20+
self._kspace = None
21+
self.precision = 1e-8 # Default precision
22+
self._mesh = None
23+
self.ke_cutoff = None
24+
self.rcut = None
25+
26+
@classmethod
27+
def from_file(cls, stru_file):
28+
cell = cls()
29+
cell._parse_stru(stru_file)
30+
cell._built = True
31+
return cell
32+
33+
def build(self):
34+
if self.atom is None:
35+
raise ValueError("Atom information must be set before building.")
36+
37+
if isinstance(self.atom, str):
38+
if self.atom.endswith('.xyz'):
39+
self._parse_xyz(self.atom)
40+
else:
41+
raise ValueError("Unsupported file format. Use .xyz files or provide atom list directly.")
42+
elif isinstance(self.atom, list):
43+
self.atoms = [[atom[0], np.array(atom[1])] for atom in self.atom]
44+
else:
45+
raise ValueError("Unsupported atom format.")
46+
47+
# Automatically set parameters based on precision
48+
self._set_auto_parameters()
49+
50+
self._built = True
51+
52+
def _set_auto_parameters(self):
53+
if self.a is not None:
54+
self.mesh = [int(np.ceil(np.linalg.norm(v) / self.precision)) for v in self.a] # TODO: Check the formula!
55+
else:
56+
self.mesh = [10, 10, 10] # Default mesh if lattice vectors are not set
57+
58+
self.ke_cutoff = -np.log(self.precision) * 10 # TODO: Check the formula!
59+
self.rcut = -np.log(self.precision) * 2 # TODO: Check the formula!
60+
61+
def _parse_stru(self, stru_file):
62+
self.atoms = []
63+
with open(stru_file, 'r') as f:
64+
lines = f.readlines()
65+
i = 0
66+
while i < len(lines):
67+
line = lines[i].strip()
68+
if 'ATOMIC_SPECIES' in line:
69+
i += 1
70+
while i < len(lines) and lines[i].strip():
71+
parts = lines[i].split()
72+
if len(parts) == 3:
73+
species, mass, pp_file = parts
74+
pp_file = pp_file.lstrip('./')
75+
self.pseudo_potentials[species] = {
76+
'mass': float(mass),
77+
'pseudo_file': pp_file
78+
}
79+
i += 1
80+
elif 'NUMERICAL_ORBITAL' in line:
81+
i += 1
82+
while i < len(lines) and lines[i].strip():
83+
orbital = lines[i].split()
84+
self.orbitals.append(orbital)
85+
i += 1
86+
elif 'LATTICE_CONSTANT' in line:
87+
i += 1
88+
self.lattice_constant = float(lines[i].strip())
89+
i += 1
90+
elif 'LATTICE_VECTORS' in line:
91+
self.a = np.array([
92+
list(map(float, lines[i+1].split())),
93+
list(map(float, lines[i+2].split())),
94+
list(map(float, lines[i+3].split()))
95+
])
96+
i += 4
97+
elif 'ATOMIC_POSITIONS' in line:
98+
i += 3
99+
while i < len(lines) and lines[i].strip():
100+
species = lines[i].strip()
101+
i += 2
102+
num_atoms = int(lines[i].strip())
103+
i += 1
104+
for _ in range(num_atoms):
105+
pos = np.array(list(map(float, lines[i].split()[:3])))
106+
self.atoms.append([species, pos])
107+
i += 2
108+
else:
109+
i += 1
110+
111+
def _parse_xyz(self, xyz_file):
112+
self.atoms = []
113+
with open(xyz_file, 'r') as f:
114+
lines = f.readlines()
115+
num_atoms = int(lines[0])
116+
# Skip the comment line
117+
for line in lines[2:2+num_atoms]:
118+
parts = line.split()
119+
species = parts[0]
120+
coords = np.array(list(map(float, parts[1:4])))
121+
self.atoms.append([species, coords])
122+
123+
def get_atom_positions(self):
124+
if not self._built:
125+
raise RuntimeError("Cell has not been built. Call build() first.")
126+
return np.array([atom[1] for atom in self.atoms])
127+
128+
def get_atom_species(self):
129+
if not self._built:
130+
raise RuntimeError("Cell has not been built. Call build() first.")
131+
return [atom[0] for atom in self.atoms]
132+
133+
@property
134+
def unit(self):
135+
return self._unit
136+
137+
@unit.setter
138+
def unit(self, value):
139+
if value.lower() in ['angstrom', 'a']:
140+
self._unit = 'Angstrom'
141+
elif value.lower() in ['bohr', 'b', 'au']:
142+
self._unit = 'Bohr'
143+
else:
144+
raise ValueError("Unit must be 'Angstrom' or 'Bohr'")
145+
146+
@property
147+
def lattice_constant(self):
148+
return self._lattice_constant
149+
150+
@lattice_constant.setter
151+
def lattice_constant(self, value):
152+
self._lattice_constant = value
153+
154+
@property
155+
def precision(self):
156+
return self._precision
157+
158+
@precision.setter
159+
def precision(self, value):
160+
if value <= 0:
161+
raise ValueError("Precision must be a positive number")
162+
self._precision = value
163+
164+
@property
165+
def kspace(self):
166+
return self._kspace
167+
168+
@kspace.setter
169+
def kspace(self, value):
170+
if value <= 0:
171+
raise ValueError("k-space must be a positive number")
172+
self._kspace = value
173+
174+
def make_kpts(self, mesh, with_gamma_point=True):
175+
if self.a is None:
176+
raise ValueError("Lattice vectors (self.a) must be set before generating k-points.")
177+
178+
kpts = []
179+
for i in range(mesh[0]):
180+
for j in range(mesh[1]):
181+
for k in range(mesh[2]):
182+
if with_gamma_point:
183+
kpt = np.array([i/mesh[0], j/mesh[1], k/mesh[2]])
184+
else:
185+
kpt = np.array([(i+0.5)/mesh[0], (j+0.5)/mesh[1], (k+0.5)/mesh[2]])
186+
kpts.append(kpt)
187+
188+
# Convert to cartesian coordinates
189+
recip_lattice = 2 * np.pi * np.linalg.inv(self.a.T)
190+
kpts = np.dot(kpts, recip_lattice)
191+
192+
return np.array(kpts)

python/pyabacus/tests/test_cell.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import unittest
2+
import numpy as np
3+
import tempfile
4+
import os
5+
from pyabacus import Cell
6+
7+
class TestCell(unittest.TestCase):
8+
def setUp(self):
9+
# Use the existing STRU file
10+
self.test_dir = os.path.dirname(os.path.abspath(__file__))
11+
self.stru_file = os.path.join(self.test_dir, 'test_cell', 'lcao_ZnO', 'STRU')
12+
13+
# Path for the XYZ file (assuming it exists in the test_cell folder)
14+
self.xyz_file = os.path.join(self.test_dir, 'test_cell', 'h2o.xyz')
15+
16+
def test_from_file(self):
17+
cell = Cell.from_file(self.stru_file)
18+
self.assertEqual(len(cell.atoms), 2)
19+
self.assertEqual(cell.get_atom_species(), ['Zn', 'O'])
20+
expected_lattice = np.array([
21+
[1.00, 0.00, 0.00],
22+
[-0.5, 0.866, 0.00],
23+
[0.00, 0.00, 1.6]
24+
])
25+
np.testing.assert_array_almost_equal(cell.a, expected_lattice)
26+
27+
def test_from_xyz_file(self):
28+
cell = Cell()
29+
cell.atom = self.xyz_file
30+
cell.build()
31+
self.assertEqual(len(cell.atoms), 3)
32+
self.assertEqual(cell.get_atom_species(), ['O', 'H', 'H'])
33+
34+
def test_pseudo_potentials(self):
35+
cell = Cell.from_file(self.stru_file)
36+
self.assertIn('Zn', cell.pseudo_potentials)
37+
self.assertIn('O', cell.pseudo_potentials)
38+
self.assertEqual(cell.pseudo_potentials['Zn']['pseudo_file'], 'Zn.LDA.UPF')
39+
self.assertEqual(cell.pseudo_potentials['O']['pseudo_file'], 'O.LDA.100.UPF')
40+
41+
42+
def test_atomic_positions(self):
43+
cell = Cell.from_file(self.stru_file)
44+
expected_positions = np.array([
45+
[0.00, 0.00, 0.00],
46+
[0.33333, 0.66667, 0.50]
47+
])
48+
np.testing.assert_array_almost_equal(cell.get_atom_positions(), expected_positions)
49+
50+
def test_build(self):
51+
cell = Cell()
52+
cell.atom = [['H', [0, 0, 0]], ['O', [0, 0, 1]], ['H', [0, 1, 0]]]
53+
cell.a = np.eye(3) * 3.0
54+
cell.build()
55+
self.assertTrue(cell._built)
56+
self.assertIsNotNone(cell.mesh)
57+
self.assertIsNotNone(cell.ke_cutoff)
58+
self.assertIsNotNone(cell.rcut)
59+
60+
def test_make_kpts(self):
61+
cell = Cell()
62+
cell.atom = [['H', [0, 0, 0]], ['O', [0, 0, 1]], ['H', [0, 1, 0]]]
63+
cell.a = np.eye(3) * 3.0
64+
cell.build()
65+
kpts = cell.make_kpts([2, 2, 2])
66+
self.assertEqual(kpts.shape, (8, 3))
67+
68+
def test_precision(self):
69+
cell = Cell()
70+
cell.precision = 1e-10
71+
self.assertEqual(cell.precision, 1e-10)
72+
with self.assertRaises(ValueError):
73+
cell.precision = -1
74+
75+
def test_unit(self):
76+
cell = Cell()
77+
cell.unit = 'Angstrom'
78+
self.assertEqual(cell.unit, 'Angstrom')
79+
cell.unit = 'Bohr'
80+
self.assertEqual(cell.unit, 'Bohr')
81+
with self.assertRaises(ValueError):
82+
cell.unit = 'invalid_unit'
83+
84+
85+
if __name__ == '__main__':
86+
unittest.main()
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
3
2+
H2O molecule
3+
O 0.000000 0.000000 0.000000
4+
H 0.758602 0.000000 0.504284
5+
H 0.758602 0.000000 -0.504284
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
INPUT_PARAMETERS
2+
pseudo_dir ../../../tests/PP_ORB
3+
orbital_dir ../../../tests/PP_ORB
4+
nbands 24
5+
6+
calculation scf
7+
ecutwfc 120
8+
scf_thr 1.0e-8
9+
scf_nmax 100
10+
11+
smearing_method gaussian
12+
smearing_sigma 0.02
13+
14+
mixing_type broyden
15+
mixing_beta 0.4
16+
17+
basis_type lcao
18+
gamma_only 0
19+
symmetry 1
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
K_POINTS
2+
0
3+
Gamma
4+
4 4 4 0 0 0
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
ATOMIC_SPECIES
2+
Zn 1.000 ./Zn.LDA.UPF
3+
O 1.000 ./O.LDA.100.UPF
4+
5+
NUMERICAL_ORBITAL
6+
Zn_lda_8.0au_120Ry_2s2p2d
7+
O_lda_7.0au_50Ry_2s2p1d
8+
9+
LATTICE_CONSTANT
10+
6.1416
11+
12+
LATTICE_VECTORS
13+
1.00 0.00 0.00
14+
-0.5 0.866 0.00
15+
0.00 0.00 1.6
16+
17+
ATOMIC_POSITIONS
18+
Direct
19+
20+
Zn
21+
0.0
22+
1
23+
0.00 0.00 0.00 1 1 1
24+
25+
O
26+
0.0
27+
1
28+
0.33333 0.66667 0.50 1 1 1

0 commit comments

Comments
 (0)