Skip to content

Commit 7e7756e

Browse files
authored
Avoid using full equality (==) to compare float, avoid assert_array_equal compare float array (#4159)
* replace some float equality check * explicit encoding * charge is also float * enhance types * access gcd via math namespace as math is already imported * put dunder method to top * fix typo * tweak _proj implementation * support array like * add arg and return type * tweak type * avoid more == for float comparison * replace some == in test, more left to do * replace more in core test * replace more in test * replace even more * replace last batch * clean up assert approx * replace pytest.approx with approx * also fix membership check * replace some equality check of list * replace some sequences * fix test * replace float comparison as dict * fix test * replace more float compare, mostly for VASP * fix test * fix approx in condition block * replace sci notation * suppress buggy ruff sim300 * number_of_permutations to int * revert change for formula_double_format, in favor of another PR * c_indices seems to be int * use sci notation for crazily large int * simplify numpy.testing usage * set tol as pos arg * avoid array equal for list of str * assert_array_equal should not be used on float array * fix module level var name * more assert_array_equal on complex number * simplify approx on dict value * avoid module level var when it's used only 3 times * pytext.approx to approx * fix approx on nested dict * avoid unnecessary convert to np.array * array_equal to all close for float array * assert all close for float array * capital class attrib is treated as constant
1 parent 5f744f2 commit 7e7756e

File tree

90 files changed

+1222
-1139
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+1222
-1139
lines changed

src/pymatgen/alchemy/filters.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import abc
6+
import math
67
from collections import defaultdict
78
from typing import TYPE_CHECKING
89

@@ -285,7 +286,7 @@ def __init__(self):
285286

286287
def test(self, structure: Structure):
287288
"""True if structure is neutral."""
288-
return structure.charge == 0.0
289+
return math.isclose(structure.charge, 0.0)
289290

290291

291292
class SpeciesMaxDistFilter(AbstractStructureFilter):

src/pymatgen/core/bonds.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,10 @@ def obtain_all_bond_lengths(
134134
If None, a ValueError will be thrown.
135135
136136
Returns:
137-
A dict mapping bond order to bond length in angstrom
137+
dict[float, float]: mapping bond order to bond length in Angstrom.
138+
139+
Todo:
140+
it's better to avoid using float as dict keys.
138141
"""
139142
if isinstance(sp1, Element):
140143
sp1 = sp1.symbol

src/pymatgen/electronic_structure/plotter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,7 @@ def get_elt_projected_plots_color(
12461246
proj[b][str(spin)][band_idx][j][str(el)][o]
12471247
for o in proj[b][str(spin)][band_idx][j][str(el)]
12481248
)
1249-
if sum_e == 0.0:
1249+
if math.isclose(sum_e, 0.0):
12501250
color = [0.0] * len(elt_ordered)
12511251
else:
12521252
color = [

src/pymatgen/io/aims/inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def get_content(
566566
magmom = structure.site_properties.get("magmom", spins)
567567
if (
568568
parameters.get("spin", "") == "collinear"
569-
and np.all(magmom == 0.0)
569+
and np.allclose(magmom, 0.0)
570570
and ("default_initial_moment" not in parameters)
571571
):
572572
warn(

src/pymatgen/io/cp2k/outputs.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,10 +1320,16 @@ def parse_bandstructure(self, bandstructure_filename=None) -> None:
13201320
else:
13211321
eigenvals = {Spin.up: bands_data.reshape((nbands, nkpts))}
13221322

1323-
occ = bands_data[:, 1][bands_data[:, -1] != 0.0]
1323+
# Filter out occupied and unoccupied states
1324+
occupied_mask = ~np.isclose(bands_data[:, -1], 0.0)
1325+
unoccupied_mask = np.isclose(bands_data[:, -1], 0.0)
1326+
1327+
occ = bands_data[:, 1][occupied_mask]
13241328
homo = np.max(occ)
1325-
unocc = bands_data[:, 1][bands_data[:, -1] == 0.0]
1329+
1330+
unocc = bands_data[:, 1][unoccupied_mask]
13261331
lumo = np.min(unocc)
1332+
13271333
efermi = (lumo + homo) / 2
13281334
self.efermi = efermi
13291335

src/pymatgen/io/vasp/outputs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -782,13 +782,13 @@ def run_type(self) -> str:
782782
4: "dDsC",
783783
}
784784

785-
if self.parameters.get("AEXX", 1.00) == 1.00:
785+
if math.isclose(self.parameters.get("AEXX", 1.00), 1.00):
786786
run_type = "HF"
787-
elif self.parameters.get("HFSCREEN", 0.30) == 0.30:
787+
elif math.isclose(self.parameters.get("HFSCREEN", 0.30), 0.30):
788788
run_type = "HSE03"
789-
elif self.parameters.get("HFSCREEN", 0.20) == 0.20:
789+
elif math.isclose(self.parameters.get("HFSCREEN", 0.20), 0.20):
790790
run_type = "HSE06"
791-
elif self.parameters.get("AEXX", 0.20) == 0.20:
791+
elif math.isclose(self.parameters.get("AEXX", 0.20), 0.20):
792792
run_type = "B3LYP"
793793
elif self.parameters.get("LHFCALC", True):
794794
run_type = "PBEO or other Hybrid Functional"
@@ -1031,7 +1031,7 @@ def get_band_structure(
10311031
if (hybrid_band or force_hybrid_mode) and not use_kpoints_opt:
10321032
start_bs_index = 0
10331033
for i in range(len(self.actual_kpoints)):
1034-
if self.actual_kpoints_weights[i] == 0.0:
1034+
if math.isclose(self.actual_kpoints_weights[i], 0.0):
10351035
start_bs_index = i
10361036
break
10371037
for i in range(start_bs_index, len(kpoint_file.kpts)):
@@ -5386,7 +5386,7 @@ def get_parchg(
53865386
Returns:
53875387
A Chgcar object.
53885388
"""
5389-
if phase and not np.all(self.kpoints[kpoint] == 0.0):
5389+
if phase and not np.allclose(self.kpoints[kpoint], 0.0):
53905390
warnings.warn(
53915391
"phase is True should only be used for the Gamma kpoint! I hope you know what you're doing!",
53925392
stacklevel=2,

src/pymatgen/transformations/advanced_transformations.py

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
from collections.abc import Callable, Iterable, Sequence
4949
from typing import Any, Literal
5050

51+
from numpy.typing import NDArray
52+
5153

5254
__author__ = "Shyue Ping Ong, Stephen Dacek, Anubhav Jain, Matthew Horton, Alex Ganose"
5355

@@ -67,6 +69,9 @@ def __init__(self, charge_balance_sp):
6769
"""
6870
self.charge_balance_sp = str(charge_balance_sp)
6971

72+
def __repr__(self):
73+
return f"Charge Balance Transformation : Species to remove = {self.charge_balance_sp}"
74+
7075
def apply_transformation(self, structure: Structure):
7176
"""Apply the transformation.
7277
@@ -86,9 +91,6 @@ def apply_transformation(self, structure: Structure):
8691
trans = SubstitutionTransformation({self.charge_balance_sp: {self.charge_balance_sp: 1 - removal_fraction}})
8792
return trans.apply_transformation(structure)
8893

89-
def __repr__(self):
90-
return f"Charge Balance Transformation : Species to remove = {self.charge_balance_sp}"
91-
9294

9395
class SuperTransformation(AbstractTransformation):
9496
"""This is a transformation that is inherently one-to-many. It is constructed
@@ -110,6 +112,9 @@ def __init__(self, transformations, nstructures_per_trans=1):
110112
self._transformations = transformations
111113
self.nstructures_per_trans = nstructures_per_trans
112114

115+
def __repr__(self):
116+
return f"Super Transformation : Transformations = {' '.join(map(str, self._transformations))}"
117+
113118
def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
114119
"""Apply the transformation.
115120
@@ -139,11 +144,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
139144
)
140145
return structures
141146

142-
def __repr__(self):
143-
return f"Super Transformation : Transformations = {' '.join(map(str, self._transformations))}"
144-
145147
@property
146-
def is_one_to_many(self) -> bool:
148+
def is_one_to_many(self) -> Literal[True]:
147149
"""Transform one structure to many."""
148150
return True
149151

@@ -191,6 +193,9 @@ def __init__(
191193
self.charge_balance_species = charge_balance_species
192194
self.order = order
193195

196+
def __repr__(self):
197+
return f"Multiple Substitution Transformation : Substitution on {self.sp_to_replace}"
198+
194199
def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
195200
"""Apply the transformation.
196201
@@ -233,11 +238,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
233238
outputs.append({"structure": new_structure})
234239
return outputs
235240

236-
def __repr__(self):
237-
return f"Multiple Substitution Transformation : Substitution on {self.sp_to_replace}"
238-
239241
@property
240-
def is_one_to_many(self) -> bool:
242+
def is_one_to_many(self) -> Literal[True]:
241243
"""Transform one structure to many."""
242244
return True
243245

@@ -322,6 +324,9 @@ def __init__(
322324
if max_cell_size and max_disordered_sites:
323325
raise ValueError("Cannot set both max_cell_size and max_disordered_sites!")
324326

327+
def __repr__(self):
328+
return "EnumerateStructureTransformation"
329+
325330
def apply_transformation(
326331
self, structure: Structure, return_ranked_list: bool | int = False
327332
) -> Structure | list[dict]:
@@ -468,11 +473,8 @@ def sort_func(struct):
468473
return self._all_structures[:num_to_return]
469474
return self._all_structures[0]["structure"]
470475

471-
def __repr__(self):
472-
return "EnumerateStructureTransformation"
473-
474476
@property
475-
def is_one_to_many(self) -> bool:
477+
def is_one_to_many(self) -> Literal[True]:
476478
"""Transform one structure to many."""
477479
return True
478480

@@ -494,6 +496,9 @@ def __init__(self, threshold=1e-2, scale_volumes=True, **kwargs):
494496
self.scale_volumes = scale_volumes
495497
self._substitutor = SubstitutionPredictor(threshold=threshold, **kwargs)
496498

499+
def __repr__(self):
500+
return "SubstitutionPredictorTransformation"
501+
497502
def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
498503
"""Apply the transformation.
499504
@@ -528,11 +533,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
528533
outputs.append(output)
529534
return outputs
530535

531-
def __repr__(self):
532-
return "SubstitutionPredictorTransformation"
533-
534536
@property
535-
def is_one_to_many(self) -> bool:
537+
def is_one_to_many(self) -> Literal[True]:
536538
"""Transform one structure to many."""
537539
return True
538540

@@ -895,7 +897,7 @@ def key(struct: Structure) -> int:
895897
return self._all_structures[:num_to_return] # type: ignore[return-value]
896898

897899
@property
898-
def is_one_to_many(self) -> bool:
900+
def is_one_to_many(self) -> Literal[True]:
899901
"""Transform one structure to many."""
900902
return True
901903

@@ -984,15 +986,19 @@ def __init__(
984986
self.allowed_doping_species = allowed_doping_species
985987
self.kwargs = kwargs
986988

987-
def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
989+
def apply_transformation(
990+
self,
991+
structure: Structure,
992+
return_ranked_list: bool | int = False,
993+
) -> list[dict[Literal["structure", "energy"], Structure | float]] | Structure:
988994
"""
989995
Args:
990-
structure (Structure): Input structure to dope
991-
return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures.
992-
is returned. If False, only the single lowest energy structure is returned. Defaults to False.
996+
structure (Structure): Input structure to dope.
997+
return_ranked_list (bool | int, optional): If is int, that number of structures is returned.
998+
If False, only the single lowest energy structure is returned. Defaults to False.
993999
9941000
Returns:
995-
list[dict] | Structure: each dict has shape {"structure": Structure, "energy": float}.
1001+
list[dict] | Structure: each dict as {"structure": Structure, "energy": float}.
9961002
"""
9971003
comp = structure.composition
9981004
logger.info(f"Composition: {comp}")
@@ -1125,7 +1131,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
11251131
return all_structures[0]["structure"]
11261132

11271133
@property
1128-
def is_one_to_many(self) -> bool:
1134+
def is_one_to_many(self) -> Literal[True]:
11291135
"""Transform one structure to many."""
11301136
return True
11311137

@@ -1253,7 +1259,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
12531259
return disordered_structures
12541260

12551261
@property
1256-
def is_one_to_many(self) -> bool:
1262+
def is_one_to_many(self) -> Literal[True]:
12571263
"""Transform one structure to many."""
12581264
return True
12591265

@@ -1714,7 +1720,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
17141720
return [{"structure": structure} for structure in structures[:return_ranked_list]]
17151721

17161722
@property
1717-
def is_one_to_many(self) -> bool:
1723+
def is_one_to_many(self) -> Literal[True]:
17181724
"""Transform one structure to many."""
17191725
return True
17201726

@@ -1868,16 +1874,25 @@ def apply_transformation(
18681874
return [{"structure": structure} for structure in structures[:return_ranked_list]]
18691875

18701876
@property
1871-
def is_one_to_many(self) -> bool:
1877+
def is_one_to_many(self) -> Literal[True]:
18721878
"""Transform one structure to many."""
18731879
return True
18741880

18751881

1876-
def _proj(b, a):
1877-
"""Get vector projection (np.ndarray) of vector b (np.ndarray)
1878-
onto vector a (np.ndarray).
1882+
def _proj(b: NDArray, a: NDArray) -> NDArray:
1883+
"""Get vector projection of vector b onto vector a.
1884+
1885+
Args:
1886+
b (NDArray): Vector to be projected.
1887+
a (NDArray): Vector onto which `b` is projected.
1888+
1889+
Returns:
1890+
NDArray: Projection of `b` onto `a`.
18791891
"""
1880-
return (b.T @ (a / np.linalg.norm(a))) * (a / np.linalg.norm(a))
1892+
a = np.asarray(a)
1893+
b = np.asarray(b)
1894+
1895+
return (np.dot(b, a) / np.dot(a, a)) * a
18811896

18821897

18831898
class SQSTransformation(AbstractTransformation):
@@ -2146,7 +2161,7 @@ def _get_unique_best_sqs_structs(sqs, best_only, return_ranked_list, remove_dupl
21462161
return to_return
21472162

21482163
@property
2149-
def is_one_to_many(self) -> bool:
2164+
def is_one_to_many(self) -> Literal[True]:
21502165
"""Transform one structure to many."""
21512166
return True
21522167

@@ -2195,6 +2210,9 @@ def __init__(self, rattle_std: float, min_distance: float, seed: int | None = No
21952210
self.random_state = np.random.RandomState(seed)
21962211
self.kwargs = kwargs
21972212

2213+
def __repr__(self):
2214+
return f"{__name__} : rattle_std = {self.rattle_std}"
2215+
21982216
def apply_transformation(self, structure: Structure) -> Structure:
21992217
"""Apply the transformation.
22002218
@@ -2216,6 +2234,3 @@ def apply_transformation(self, structure: Structure) -> Structure:
22162234
structure.cart_coords + displacements,
22172235
coords_are_cartesian=True,
22182236
)
2219-
2220-
def __repr__(self):
2221-
return f"{__name__} : rattle_std = {self.rattle_std}"

src/pymatgen/transformations/transformation_abc.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from __future__ import annotations
44

55
import abc
6-
from typing import TYPE_CHECKING, Any
6+
from typing import TYPE_CHECKING
77

88
from monty.json import MSONable
99

1010
if TYPE_CHECKING:
11+
from typing import Any, Literal
12+
1113
from pymatgen.core import Structure
1214

1315
__author__ = "Shyue Ping Ong"
@@ -55,7 +57,7 @@ def inverse(self) -> AbstractTransformation | None:
5557
"""
5658

5759
@property
58-
def is_one_to_many(self) -> bool:
60+
def is_one_to_many(self) -> Literal[False]:
5961
"""Determine if a Transformation is a one-to-many transformation. In that case, the
6062
apply_transformation method should have a keyword arg "return_ranked_list" which
6163
allows for the transformed structures to be returned as a ranked list.
@@ -64,7 +66,7 @@ def is_one_to_many(self) -> bool:
6466
return False
6567

6668
@property
67-
def use_multiprocessing(self) -> bool:
69+
def use_multiprocessing(self) -> Literal[False]:
6870
"""Indicates whether the transformation can be applied by a
6971
subprocessing pool. This should be overridden to return True for
7072
transformations that the transmuter can parallelize.

tests/analysis/chemenv/coordination_environments/test_coordination_geometries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_coordination_geometry(self):
8989
assert cg_oct.IUCr_symbol_str == "[6o]"
9090

9191
cg_oct.permutations_safe_override = True
92-
assert cg_oct.number_of_permutations == 720.0
92+
assert cg_oct.number_of_permutations == 720
9393
assert cg_oct.ref_permutation([0, 3, 2, 4, 5, 1]) == (0, 3, 1, 5, 2, 4)
9494

9595
sites = [FakeSite(coords=pp) for pp in cg_oct.points]

0 commit comments

Comments
 (0)