23
23
from pymatgen .electronic_structure .core import Spin
24
24
25
25
if TYPE_CHECKING :
26
+ from numpy .typing import NDArray
26
27
from typing_extensions import Any , Self
27
28
28
29
from pymatgen .core .structure import IStructure
@@ -62,8 +63,8 @@ def __init__(
62
63
self ,
63
64
structure : Structure | IStructure ,
64
65
data : dict [str , np .ndarray ],
65
- distance_matrix : np . ndarray | None = None ,
66
- data_aug : np . ndarray | None = None ,
66
+ distance_matrix : dict | None = None ,
67
+ data_aug : dict [ str , NDArray ] | None = None ,
67
68
) -> None :
68
69
"""
69
70
Typically, this constructor is not used directly and the static
@@ -85,11 +86,11 @@ def __init__(
85
86
# convert data to numpy arrays in case they were jsanitized as lists
86
87
self .data = {k : np .array (v ) for k , v in data .items ()}
87
88
self .dim = self .data ["total" ].shape
88
- self .data_aug = data_aug
89
+ self .data_aug = data_aug or {}
89
90
self .ngridpts = self .dim [0 ] * self .dim [1 ] * self .dim [2 ]
90
91
# lazy init the spin data since this is not always needed.
91
92
self ._spin_data : dict [Spin , float ] = {}
92
- self ._distance_matrix = distance_matrix
93
+ self ._distance_matrix = distance_matrix if distance_matrix is not None else {}
93
94
self .xpoints = np .linspace (0.0 , 1.0 , num = self .dim [0 ])
94
95
self .ypoints = np .linspace (0.0 , 1.0 , num = self .dim [1 ])
95
96
self .zpoints = np .linspace (0.0 , 1.0 , num = self .dim [2 ])
@@ -168,7 +169,7 @@ def linear_add(self, other, scale_factor=1.0) -> VolumetricData:
168
169
169
170
new = deepcopy (self )
170
171
new .data = data
171
- new .data_aug = None
172
+ new .data_aug = {}
172
173
return new
173
174
174
175
def scale (self , factor ):
@@ -247,6 +248,7 @@ def get_integrated_diff(self, ind, radius, nbins=1):
247
248
248
249
struct = self .structure
249
250
a = self .dim
251
+ self ._distance_matrix = {} if self ._distance_matrix is None else self ._distance_matrix
250
252
if ind not in self ._distance_matrix or self ._distance_matrix [ind ]["max_radius" ] < radius :
251
253
coords = []
252
254
for x , y , z in itertools .product (* (list (range (i )) for i in a )):
0 commit comments