Skip to content

Commit e2bf3f7

Browse files
committed
add unittest for convert cell
1 parent 53db9e2 commit e2bf3f7

File tree

2 files changed

+99
-9
lines changed

2 files changed

+99
-9
lines changed

dpdata/cp2k/output.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,59 @@
1717

1818
avail_patterns.append(re.compile(r'^ INITIAL POTENTIAL ENERGY'))
1919
avail_patterns.append(re.compile(r'^ ENSEMBLE TYPE'))
20+
21+
def cell_to_low_triangle(A,B,C,alpha,beta,gamma):
22+
"""
23+
Convert cell to low triangle matrix.
24+
25+
Parameters
26+
----------
27+
A : float
28+
cell length A
29+
B : float
30+
cell length B
31+
C : float
32+
cell length C
33+
alpha : float
34+
radian. The angle between vector B and vector C.
35+
beta : float
36+
radian. The angle between vector A and vector C.
37+
gamma : float
38+
radian. The angle between vector B and vector C.
39+
40+
Returns
41+
-------
42+
cell : list
43+
The cell matrix used by dpdata in low triangle form.
44+
"""
45+
if not np.pi*5/180<alpha< np.pi*175/180:
46+
raise RuntimeError("alpha=={}: must be a radian, and \
47+
must be in np.pi*5/180 < alpha < np.pi*175/180".format(alpha))
48+
if not np.pi*5/180<beta< np.pi*175/180:
49+
raise RuntimeError("beta=={}: must be a radian, and \
50+
must be in np.pi*5/180 < beta < np.pi*175/180".format(beta))
51+
if not np.pi*5/180<gamma< np.pi*175/180:
52+
raise RuntimeError("gamma=={}: must be a radian, and \
53+
must be in np.pi*5/180 < gamma < np.pi*175/180".format(gamma))
54+
if not A > 0.2:
55+
raise RuntimeError("A=={}, must be greater than 0.2".format(A))
56+
if not B > 0.2:
57+
raise RuntimeError("B=={}, must be greater than 0.2".format(B))
58+
if not C > 0.2:
59+
raise RuntimeError("C=={}, must be greater than 0.2".format(C))
60+
61+
lx = A
62+
xy = B * np.cos(gamma)
63+
xz = C * np.cos(beta)
64+
ly = B* np.sin(gamma)
65+
if not ly > 0.1:
66+
raise RuntimeError("ly:=B* np.sin(gamma)=={}, must be greater than 0.1",format(ly))
67+
yz = (B*C*np.cos(alpha)-xy*xz)/ly
68+
lz = np.sqrt(C**2-xz**2-yz**2)
69+
cell = np.asarray([[lx, 0 , 0],
70+
[xy, ly, 0 ],
71+
[xz, yz, lz]]).astype('float32')
72+
return cell
2073
class Cp2kSystems(object):
2174
"""
2275
deal with cp2k outputfile
@@ -123,15 +176,17 @@ def handle_single_log_frame(self, lines):
123176
cell_gamma = np.deg2rad(float(cell_angle_pattern.match(line).groupdict()['gamma']))
124177
cell_flag+=1
125178
if cell_flag == 2:
126-
lx = cell_A
127-
xy = cell_B * np.cos(cell_gamma)
128-
xz = cell_C * np.cos(cell_beta)
129-
ly = cell_B* np.sin(cell_gamma)
130-
yz = (cell_B*cell_C*np.cos(cell_alpha)-xy*xz)/ly
131-
lz = np.sqrt(cell_C**2-xz**2-yz**2)
132-
self.cell = [[lx, 0 , 0],
133-
[xy, ly, 0 ],
134-
[xz, yz, lz]]
179+
self.cell = cell_to_low_triangle(cell_A,cell_B,cell_C,
180+
cell_alpha,cell_beta,cell_gamma)
181+
# lx = cell_A
182+
# xy = cell_B * np.cos(cell_gamma)
183+
# xz = cell_C * np.cos(cell_beta)
184+
# ly = cell_B* np.sin(cell_gamma)
185+
# yz = (cell_B*cell_C*np.cos(cell_alpha)-xy*xz)/ly
186+
# lz = np.sqrt(cell_C**2-xz**2-yz**2)
187+
# self.cell = [[lx, 0 , 0],
188+
# [xy, ly, 0 ],
189+
# [xz, yz, lz]]
135190

136191
element_index = -1
137192
element_dict = OrderedDict()

tests/test_cell_to_low_triangle.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
import numpy as np
3+
import unittest
4+
from dpdata.cp2k.output import cell_to_low_triangle
5+
6+
class CellToLowTriangle(unittest.TestCase):
7+
def test_func1(self):
8+
cell_1 = cell_to_low_triangle(6,6,6,np.pi*1/2, np.pi*1/2, np.pi*1/2)
9+
cell_2 = np.asarray([[6,0,0],[0,6,0],[0,0,6]])
10+
for ii in range(3):
11+
for jj in range(3):
12+
self.assertAlmostEqual(cell_1[ii,jj], cell_2[ii,jj], places=6)
13+
14+
def test_func2(self):
15+
cell_1 = cell_to_low_triangle(6,6,6,np.pi*1/3, np.pi*1/3, np.pi*1/3)
16+
cell_2 = np.asarray([
17+
[6,0,0],
18+
[3,3*np.sqrt(3),0],
19+
[3,np.sqrt(3),2*np.sqrt(6)]])
20+
for ii in range(3):
21+
for jj in range(3):
22+
self.assertAlmostEqual(cell_1[ii,jj], cell_2[ii,jj], places=6)
23+
24+
def test_func3(self):
25+
with self.assertRaises(Exception) as c:
26+
cell_to_low_triangle(0.1,6,6,np.pi*1/2,np.pi*1/2,np.pi*1/2)
27+
self.assertTrue("A==0.1" in str(c.exception))
28+
29+
def test_func4(self):
30+
with self.assertRaises(Exception) as c:
31+
cell_to_low_triangle(6,6,6,np.pi*3/180,np.pi*1/2,np.pi*1/2)
32+
self.assertTrue("alpha" in str(c.exception))
33+
34+
if __name__ == '__main__':
35+
unittest.main()

0 commit comments

Comments
 (0)