Skip to content

Commit dcf6a1f

Browse files
committed
being a bit more concrete with literals
1 parent 38692d2 commit dcf6a1f

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

pyerrors/correlators.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from .fits import least_squares, Fit_result
1111
from .roots import find_root
1212
from . import linalg
13+
from .input.json import dump_to_json
1314
from numpy import ndarray, ufunc
14-
from typing import Any, Callable, Optional, Union
15+
from typing import Any, Callable, Optional, Union, Literal
1516

1617

1718
class Corr:
@@ -45,7 +46,7 @@ class Corr:
4546

4647
__slots__ = ["content", "N", "T", "tag", "prange"]
4748

48-
def __init__(self, data_input: list[Obs, CObs], padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
49+
def __init__(self, data_input: Union[list[Obs, CObs], list[ndarray[ndarray[Obs, CObs]]], ndarray[ndarray[Corr]]], padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
4950
""" Initialize a Corr object.
5051
5152
Parameters
@@ -303,7 +304,7 @@ def matrix_symmetric(self) -> "Corr":
303304
transposed = [None if _check_for_none(self, G) else G.T for G in self.content]
304305
return 0.5 * (Corr(transposed) + self)
305306

306-
def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[str]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]:
307+
def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]:
307308
r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors.
308309
309310
The eigenvectors are sorted according to the descending eigenvalues, the zeroth eigenvector(s) correspond to the
@@ -409,7 +410,7 @@ def _get_mat_at_t(t, vector_obs=vector_obs):
409410
else:
410411
return reordered_vecs
411412

412-
def Eigenvalue(self, t0: int, ts: None=None, state: int=0, sort: str="Eigenvalue", **kwargs) -> "Corr":
413+
def Eigenvalue(self, t0: int, ts: Optional[int]=None, state: int=0, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", **kwargs) -> "Corr":
413414
"""Determines the eigenvalue of the GEVP by solving and projecting the correlator
414415
415416
Parameters
@@ -495,7 +496,7 @@ def thin(self, spacing: int=2, offset: int=0) -> "Corr":
495496
new_content.append(self.content[t])
496497
return Corr(new_content)
497498

498-
def correlate(self, partner: Union[Corr, float, Obs]) -> "Corr":
499+
def correlate(self, partner: Union[Corr, Obs]) -> "Corr":
499500
"""Correlate the correlator with another correlator or Obs
500501
501502
Parameters
@@ -577,14 +578,14 @@ def T_symmetry(self, partner: "Corr", parity: int=+1) -> "Corr":
577578

578579
return (self + T_partner) / 2
579580

580-
def deriv(self, variant: Optional[str]="symmetric") -> "Corr":
581+
def deriv(self, variant: Literal["symmetric", "forward", "backward", "improved", "log"]="symmetric") -> "Corr":
581582
"""Return the first derivative of the correlator with respect to x0.
582583
583584
Parameters
584585
----------
585586
variant : str
586587
decides which definition of the finite differences derivative is used.
587-
Available choice: symmetric, forward, backward, improved, log, default: symmetric
588+
Available choices: symmetric, forward, backward, improved, log, default: symmetric
588589
"""
589590
if self.N != 1:
590591
raise ValueError("deriv only implemented for one-dimensional correlators.")
@@ -638,7 +639,7 @@ def deriv(self, variant: Optional[str]="symmetric") -> "Corr":
638639
else:
639640
raise ValueError("Unknown variant.")
640641

641-
def second_deriv(self, variant: Optional[str]="symmetric") -> "Corr":
642+
def second_deriv(self, variant: Literal["symmetric", "big_symmetric", "improved", "log"]="symmetric") -> "Corr":
642643
r"""Return the second derivative of the correlator with respect to x0.
643644
644645
Parameters
@@ -698,7 +699,7 @@ def second_deriv(self, variant: Optional[str]="symmetric") -> "Corr":
698699
else:
699700
raise ValueError("Unknown variant.")
700701

701-
def m_eff(self, variant: str='log', guess: float=1.0) -> "Corr":
702+
def m_eff(self, variant: Literal["log", "cosh", "periodic", "sinh", "arccosh", "logsym"]='log', guess: float=1.0) -> "Corr":
702703
"""Returns the effective mass of the correlator as correlator object
703704
704705
Parameters
@@ -813,7 +814,7 @@ def fit(self, function: Callable, fitrange: Optional[list[int]]=None, silent: bo
813814
result = least_squares(xs, ys, function, silent=silent, **kwargs)
814815
return result
815816

816-
def plateau(self, plateau_range: Optional[list[int]]=None, method: str="fit", auto_gamma: bool=False) -> Obs:
817+
def plateau(self, plateau_range: Optional[list[int]]=None, method: Literal['fit', 'avg']="fit", auto_gamma: bool=False) -> Obs:
817818
""" Extract a plateau value from a Corr object
818819
819820
Parameters
@@ -862,7 +863,7 @@ def set_prange(self, prange: list[int]):
862863
self.prange = prange
863864
return
864865

865-
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Optional[Obs, float, int]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Optional[int, float]=None, references: Optional[list[float]]=None, title: Optional[str]=None):
866+
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Union[Obs, float, int, None]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Union[int, float, None]=None, references: Optional[list[float]]=None, title: Optional[str]=None):
866867
"""Plots the correlator using the tag of the correlator as label if available.
867868
868869
Parameters
@@ -1029,11 +1030,8 @@ def dump(self, filename: str, datatype: str="json.gz", **kwargs):
10291030
specifies a custom path for the file (default '.')
10301031
"""
10311032
if datatype == "json.gz":
1032-
from .input.json import dump_to_json
1033-
if 'path' in kwargs:
1034-
file_name = kwargs.get('path') + '/' + filename
1035-
else:
1036-
file_name = filename
1033+
path = kwargs.get("path", ".")
1034+
file_name = path + '/' + filename
10371035
dump_to_json(self, file_name)
10381036
elif datatype == "pickle":
10391037
dump_object(self, filename, **kwargs)
@@ -1078,7 +1076,7 @@ def __str__(self) -> str:
10781076

10791077
__array_priority__ = 10000
10801078

1081-
def __eq__(self, y: Any) -> ndarray:
1079+
def __eq__(self, y: Any) -> ndarray[bool, None]:
10821080
if isinstance(y, Corr):
10831081
comp = np.asarray(y.content, dtype=object)
10841082
else:

0 commit comments

Comments
 (0)