46
46
47
47
from __future__ import annotations
48
48
49
+ from typing import TYPE_CHECKING
50
+
49
51
import numpy as np
50
52
from scipy .interpolate import UnivariateSpline
51
53
52
54
from pymatgen .core .lattice import Lattice
53
55
from pymatgen .core .structure import Structure
54
56
57
+ if TYPE_CHECKING :
58
+ from collections .abc import Sequence
59
+
60
+ from pymatgen .core .sites import PeriodicSite
61
+
62
+
55
63
__author__ = "Tess Smidt"
56
64
__copyright__ = "Copyright 2017, The Materials Project"
57
65
__version__ = "1.0"
@@ -73,7 +81,7 @@ def zval_dict_from_potcar(potcar):
73
81
return zval_dict
74
82
75
83
76
- def calc_ionic (site , structure : Structure , zval ) :
84
+ def calc_ionic (site : PeriodicSite , structure : Structure , zval : float ) -> np . ndarray :
77
85
"""
78
86
Calculate the ionic dipole moment using ZVAL from pseudopotential.
79
87
@@ -103,31 +111,27 @@ def get_total_ionic_dipole(structure, zval_dict):
103
111
return np .sum (tot_ionic , axis = 0 )
104
112
105
113
106
- class PolarizationLattice (Structure ):
107
- """TODO Why is a Lattice inheriting a structure? This is ridiculous."""
108
-
109
- def get_nearest_site (self , coords , site , r = None ):
110
- """
111
- Given coords and a site, find closet site to coords.
114
+ def get_nearest_site (struct : Structure , coords : Sequence [float ], site : PeriodicSite , r : float | None = None ):
115
+ """
116
+ Given coords and a site, find closet site to coords.
112
117
113
- Args:
114
- coords (3x1 array): Cartesian coords of center of sphere
115
- site: site to find closest to coords
116
- r : radius of sphere. Defaults to diagonal of unit cell
118
+ Args:
119
+ coords (3x1 array): Cartesian coords of center of sphere
120
+ site: site to find closest to coords
121
+ r (float) : radius of sphere. Defaults to diagonal of unit cell
117
122
118
- Returns:
119
- Closest site and distance.
120
- """
121
- index = self .index (site )
122
- if r is None :
123
- r = np .linalg .norm (np .sum (self .lattice .matrix , axis = 0 ))
124
- ns = self .get_sites_in_sphere (coords , r , include_index = True )
125
- # Get sites with identical index to site
126
- ns = [n for n in ns if n [2 ] == index ]
127
- # Sort by distance to coords
128
- ns .sort (key = lambda x : x [1 ])
129
- # Return PeriodicSite and distance of closest image
130
- return ns [0 ][0 :2 ]
123
+ Returns:
124
+ Closest site and distance.
125
+ """
126
+ index = struct .index (site )
127
+ r = r or np .linalg .norm (np .sum (struct .lattice .matrix , axis = 0 ))
128
+ ns = struct .get_sites_in_sphere (coords , r , include_index = True )
129
+ # Get sites with identical index to site
130
+ ns = [n for n in ns if n [2 ] == index ]
131
+ # Sort by distance to coords
132
+ ns .sort (key = lambda x : x [1 ])
133
+ # Return PeriodicSite and distance of closest image
134
+ return ns [0 ][0 :2 ]
131
135
132
136
133
137
class Polarization :
@@ -298,18 +302,18 @@ def get_same_branch_polarization_data(self, convert_to_muC_per_cm2=True, all_in_
298
302
for idx in range (n_elecs ):
299
303
lattice = lattices [idx ]
300
304
frac_coord = np .divide (np .array ([p_tot [idx ]]), np .array (lattice .lengths ))
301
- d = PolarizationLattice (lattice , ["C" ], [np .array (frac_coord ).ravel ()])
302
- d_structs .append (d )
303
- site = d [0 ]
305
+ struct = Structure (lattice , ["C" ], [np .array (frac_coord ).ravel ()])
306
+ d_structs .append (struct )
307
+ site = struct [0 ]
304
308
# Adjust nonpolar polarization to be closest to zero.
305
309
# This is compatible with both a polarization of zero or a half quantum.
306
310
prev_site = [0 , 0 , 0 ] if idx == 0 else sites [- 1 ].coords
307
- new_site = d . get_nearest_site (prev_site , site )
311
+ new_site = get_nearest_site (struct , prev_site , site )
308
312
sites .append (new_site [0 ])
309
313
310
314
adjust_pol = []
311
- for site , d in zip (sites , d_structs ):
312
- adjust_pol .append (np .multiply (site .frac_coords , np .array (d .lattice .lengths )).ravel ())
315
+ for site , struct in zip (sites , d_structs ):
316
+ adjust_pol .append (np .multiply (site .frac_coords , np .array (struct .lattice .lengths )).ravel ())
313
317
return np .array (adjust_pol )
314
318
315
319
def get_lattice_quanta (self , convert_to_muC_per_cm2 = True , all_in_polar = True ):
0 commit comments