48
48
from collections .abc import Callable , Iterable , Sequence
49
49
from typing import Any , Literal
50
50
51
+ from numpy .typing import NDArray
52
+
51
53
52
54
__author__ = "Shyue Ping Ong, Stephen Dacek, Anubhav Jain, Matthew Horton, Alex Ganose"
53
55
@@ -67,6 +69,9 @@ def __init__(self, charge_balance_sp):
67
69
"""
68
70
self .charge_balance_sp = str (charge_balance_sp )
69
71
72
+ def __repr__ (self ):
73
+ return f"Charge Balance Transformation : Species to remove = { self .charge_balance_sp } "
74
+
70
75
def apply_transformation (self , structure : Structure ):
71
76
"""Apply the transformation.
72
77
@@ -86,9 +91,6 @@ def apply_transformation(self, structure: Structure):
86
91
trans = SubstitutionTransformation ({self .charge_balance_sp : {self .charge_balance_sp : 1 - removal_fraction }})
87
92
return trans .apply_transformation (structure )
88
93
89
- def __repr__ (self ):
90
- return f"Charge Balance Transformation : Species to remove = { self .charge_balance_sp } "
91
-
92
94
93
95
class SuperTransformation (AbstractTransformation ):
94
96
"""This is a transformation that is inherently one-to-many. It is constructed
@@ -110,6 +112,9 @@ def __init__(self, transformations, nstructures_per_trans=1):
110
112
self ._transformations = transformations
111
113
self .nstructures_per_trans = nstructures_per_trans
112
114
115
+ def __repr__ (self ):
116
+ return f"Super Transformation : Transformations = { ' ' .join (map (str , self ._transformations ))} "
117
+
113
118
def apply_transformation (self , structure : Structure , return_ranked_list : bool | int = False ):
114
119
"""Apply the transformation.
115
120
@@ -139,11 +144,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
139
144
)
140
145
return structures
141
146
142
- def __repr__ (self ):
143
- return f"Super Transformation : Transformations = { ' ' .join (map (str , self ._transformations ))} "
144
-
145
147
@property
146
- def is_one_to_many (self ) -> bool :
148
+ def is_one_to_many (self ) -> Literal [ True ] :
147
149
"""Transform one structure to many."""
148
150
return True
149
151
@@ -191,6 +193,9 @@ def __init__(
191
193
self .charge_balance_species = charge_balance_species
192
194
self .order = order
193
195
196
+ def __repr__ (self ):
197
+ return f"Multiple Substitution Transformation : Substitution on { self .sp_to_replace } "
198
+
194
199
def apply_transformation (self , structure : Structure , return_ranked_list : bool | int = False ):
195
200
"""Apply the transformation.
196
201
@@ -233,11 +238,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
233
238
outputs .append ({"structure" : new_structure })
234
239
return outputs
235
240
236
- def __repr__ (self ):
237
- return f"Multiple Substitution Transformation : Substitution on { self .sp_to_replace } "
238
-
239
241
@property
240
- def is_one_to_many (self ) -> bool :
242
+ def is_one_to_many (self ) -> Literal [ True ] :
241
243
"""Transform one structure to many."""
242
244
return True
243
245
@@ -322,6 +324,9 @@ def __init__(
322
324
if max_cell_size and max_disordered_sites :
323
325
raise ValueError ("Cannot set both max_cell_size and max_disordered_sites!" )
324
326
327
+ def __repr__ (self ):
328
+ return "EnumerateStructureTransformation"
329
+
325
330
def apply_transformation (
326
331
self , structure : Structure , return_ranked_list : bool | int = False
327
332
) -> Structure | list [dict ]:
@@ -468,11 +473,8 @@ def sort_func(struct):
468
473
return self ._all_structures [:num_to_return ]
469
474
return self ._all_structures [0 ]["structure" ]
470
475
471
- def __repr__ (self ):
472
- return "EnumerateStructureTransformation"
473
-
474
476
@property
475
- def is_one_to_many (self ) -> bool :
477
+ def is_one_to_many (self ) -> Literal [ True ] :
476
478
"""Transform one structure to many."""
477
479
return True
478
480
@@ -494,6 +496,9 @@ def __init__(self, threshold=1e-2, scale_volumes=True, **kwargs):
494
496
self .scale_volumes = scale_volumes
495
497
self ._substitutor = SubstitutionPredictor (threshold = threshold , ** kwargs )
496
498
499
+ def __repr__ (self ):
500
+ return "SubstitutionPredictorTransformation"
501
+
497
502
def apply_transformation (self , structure : Structure , return_ranked_list : bool | int = False ):
498
503
"""Apply the transformation.
499
504
@@ -528,11 +533,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
528
533
outputs .append (output )
529
534
return outputs
530
535
531
- def __repr__ (self ):
532
- return "SubstitutionPredictorTransformation"
533
-
534
536
@property
535
- def is_one_to_many (self ) -> bool :
537
+ def is_one_to_many (self ) -> Literal [ True ] :
536
538
"""Transform one structure to many."""
537
539
return True
538
540
@@ -895,7 +897,7 @@ def key(struct: Structure) -> int:
895
897
return self ._all_structures [:num_to_return ] # type: ignore[return-value]
896
898
897
899
@property
898
- def is_one_to_many (self ) -> bool :
900
+ def is_one_to_many (self ) -> Literal [ True ] :
899
901
"""Transform one structure to many."""
900
902
return True
901
903
@@ -984,15 +986,19 @@ def __init__(
984
986
self .allowed_doping_species = allowed_doping_species
985
987
self .kwargs = kwargs
986
988
987
- def apply_transformation (self , structure : Structure , return_ranked_list : bool | int = False ):
989
+ def apply_transformation (
990
+ self ,
991
+ structure : Structure ,
992
+ return_ranked_list : bool | int = False ,
993
+ ) -> list [dict [Literal ["structure" , "energy" ], Structure | float ]] | Structure :
988
994
"""
989
995
Args:
990
- structure (Structure): Input structure to dope
991
- return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures.
992
- is returned. If False, only the single lowest energy structure is returned. Defaults to False.
996
+ structure (Structure): Input structure to dope.
997
+ return_ranked_list (bool | int, optional): If is int, that number of structures is returned .
998
+ If False, only the single lowest energy structure is returned. Defaults to False.
993
999
994
1000
Returns:
995
- list[dict] | Structure: each dict has shape {"structure": Structure, "energy": float}.
1001
+ list[dict] | Structure: each dict as {"structure": Structure, "energy": float}.
996
1002
"""
997
1003
comp = structure .composition
998
1004
logger .info (f"Composition: { comp } " )
@@ -1125,7 +1131,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
1125
1131
return all_structures [0 ]["structure" ]
1126
1132
1127
1133
@property
1128
- def is_one_to_many (self ) -> bool :
1134
+ def is_one_to_many (self ) -> Literal [ True ] :
1129
1135
"""Transform one structure to many."""
1130
1136
return True
1131
1137
@@ -1253,7 +1259,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
1253
1259
return disordered_structures
1254
1260
1255
1261
@property
1256
- def is_one_to_many (self ) -> bool :
1262
+ def is_one_to_many (self ) -> Literal [ True ] :
1257
1263
"""Transform one structure to many."""
1258
1264
return True
1259
1265
@@ -1714,7 +1720,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
1714
1720
return [{"structure" : structure } for structure in structures [:return_ranked_list ]]
1715
1721
1716
1722
@property
1717
- def is_one_to_many (self ) -> bool :
1723
+ def is_one_to_many (self ) -> Literal [ True ] :
1718
1724
"""Transform one structure to many."""
1719
1725
return True
1720
1726
@@ -1868,16 +1874,25 @@ def apply_transformation(
1868
1874
return [{"structure" : structure } for structure in structures [:return_ranked_list ]]
1869
1875
1870
1876
@property
1871
- def is_one_to_many (self ) -> bool :
1877
+ def is_one_to_many (self ) -> Literal [ True ] :
1872
1878
"""Transform one structure to many."""
1873
1879
return True
1874
1880
1875
1881
1876
- def _proj (b , a ):
1877
- """Get vector projection (np.ndarray) of vector b (np.ndarray)
1878
- onto vector a (np.ndarray).
1882
+ def _proj (b : NDArray , a : NDArray ) -> NDArray :
1883
+ """Get vector projection of vector b onto vector a.
1884
+
1885
+ Args:
1886
+ b (NDArray): Vector to be projected.
1887
+ a (NDArray): Vector onto which `b` is projected.
1888
+
1889
+ Returns:
1890
+ NDArray: Projection of `b` onto `a`.
1879
1891
"""
1880
- return (b .T @ (a / np .linalg .norm (a ))) * (a / np .linalg .norm (a ))
1892
+ a = np .asarray (a )
1893
+ b = np .asarray (b )
1894
+
1895
+ return (np .dot (b , a ) / np .dot (a , a )) * a
1881
1896
1882
1897
1883
1898
class SQSTransformation (AbstractTransformation ):
@@ -2146,7 +2161,7 @@ def _get_unique_best_sqs_structs(sqs, best_only, return_ranked_list, remove_dupl
2146
2161
return to_return
2147
2162
2148
2163
@property
2149
- def is_one_to_many (self ) -> bool :
2164
+ def is_one_to_many (self ) -> Literal [ True ] :
2150
2165
"""Transform one structure to many."""
2151
2166
return True
2152
2167
@@ -2195,6 +2210,9 @@ def __init__(self, rattle_std: float, min_distance: float, seed: int | None = No
2195
2210
self .random_state = np .random .RandomState (seed )
2196
2211
self .kwargs = kwargs
2197
2212
2213
+ def __repr__ (self ):
2214
+ return f"{ __name__ } : rattle_std = { self .rattle_std } "
2215
+
2198
2216
def apply_transformation (self , structure : Structure ) -> Structure :
2199
2217
"""Apply the transformation.
2200
2218
@@ -2216,6 +2234,3 @@ def apply_transformation(self, structure: Structure) -> Structure:
2216
2234
structure .cart_coords + displacements ,
2217
2235
coords_are_cartesian = True ,
2218
2236
)
2219
-
2220
- def __repr__ (self ):
2221
- return f"{ __name__ } : rattle_std = { self .rattle_std } "
0 commit comments