Skip to content

Commit d3df2a9

Browse files
committed
Fix regression.
1 parent 885603c commit d3df2a9

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/pymatgen/io/common.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pymatgen.electronic_structure.core import Spin
2424

2525
if TYPE_CHECKING:
26+
from numpy.typing import NDArray
2627
from typing_extensions import Any, Self
2728

2829
from pymatgen.core.structure import IStructure
@@ -62,8 +63,8 @@ def __init__(
6263
self,
6364
structure: Structure | IStructure,
6465
data: dict[str, np.ndarray],
65-
distance_matrix: np.ndarray | None = None,
66-
data_aug: np.ndarray | None = None,
66+
distance_matrix: dict | None = None,
67+
data_aug: dict[str, NDArray] | None = None,
6768
) -> None:
6869
"""
6970
Typically, this constructor is not used directly and the static
@@ -85,11 +86,11 @@ def __init__(
8586
# convert data to numpy arrays in case they were jsanitized as lists
8687
self.data = {k: np.array(v) for k, v in data.items()}
8788
self.dim = self.data["total"].shape
88-
self.data_aug = data_aug
89+
self.data_aug = data_aug or {}
8990
self.ngridpts = self.dim[0] * self.dim[1] * self.dim[2]
9091
# lazy init the spin data since this is not always needed.
9192
self._spin_data: dict[Spin, float] = {}
92-
self._distance_matrix = distance_matrix
93+
self._distance_matrix = distance_matrix if distance_matrix is not None else {}
9394
self.xpoints = np.linspace(0.0, 1.0, num=self.dim[0])
9495
self.ypoints = np.linspace(0.0, 1.0, num=self.dim[1])
9596
self.zpoints = np.linspace(0.0, 1.0, num=self.dim[2])
@@ -168,7 +169,7 @@ def linear_add(self, other, scale_factor=1.0) -> VolumetricData:
168169

169170
new = deepcopy(self)
170171
new.data = data
171-
new.data_aug = None
172+
new.data_aug = {}
172173
return new
173174

174175
def scale(self, factor):
@@ -247,6 +248,7 @@ def get_integrated_diff(self, ind, radius, nbins=1):
247248

248249
struct = self.structure
249250
a = self.dim
251+
self._distance_matrix = {} if self._distance_matrix is None else self._distance_matrix
250252
if ind not in self._distance_matrix or self._distance_matrix[ind]["max_radius"] < radius:
251253
coords = []
252254
for x, y, z in itertools.product(*(list(range(i)) for i in a)):

0 commit comments

Comments
 (0)