|
10 | 10 | from .fits import least_squares, Fit_result |
11 | 11 | from .roots import find_root |
12 | 12 | from . import linalg |
| 13 | +from .input.json import dump_to_json |
13 | 14 | from numpy import ndarray, ufunc |
14 | | -from typing import Any, Callable, Optional, Union |
| 15 | +from typing import Any, Callable, Optional, Union, Literal |
15 | 16 |
|
16 | 17 |
|
17 | 18 | class Corr: |
@@ -45,7 +46,7 @@ class Corr: |
45 | 46 |
|
46 | 47 | __slots__ = ["content", "N", "T", "tag", "prange"] |
47 | 48 |
|
48 | | - def __init__(self, data_input: list[Obs, CObs], padding: list[int]=[0, 0], prange: Optional[list[int]]=None): |
| 49 | + def __init__(self, data_input: Union[list[Obs, CObs], list[ndarray[ndarray[Obs, CObs]]], ndarray[ndarray[Corr]]], padding: list[int]=[0, 0], prange: Optional[list[int]]=None): |
49 | 50 | """ Initialize a Corr object. |
50 | 51 |
|
51 | 52 | Parameters |
@@ -303,7 +304,7 @@ def matrix_symmetric(self) -> "Corr": |
303 | 304 | transposed = [None if _check_for_none(self, G) else G.T for G in self.content] |
304 | 305 | return 0.5 * (Corr(transposed) + self) |
305 | 306 |
|
306 | | - def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[str]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]: |
| 307 | + def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]: |
307 | 308 | r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors. |
308 | 309 |
|
309 | 310 | The eigenvectors are sorted according to the descending eigenvalues, the zeroth eigenvector(s) correspond to the |
@@ -409,7 +410,7 @@ def _get_mat_at_t(t, vector_obs=vector_obs): |
409 | 410 | else: |
410 | 411 | return reordered_vecs |
411 | 412 |
|
412 | | - def Eigenvalue(self, t0: int, ts: None=None, state: int=0, sort: str="Eigenvalue", **kwargs) -> "Corr": |
| 413 | + def Eigenvalue(self, t0: int, ts: Optional[int]=None, state: int=0, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", **kwargs) -> "Corr": |
413 | 414 | """Determines the eigenvalue of the GEVP by solving and projecting the correlator |
414 | 415 |
|
415 | 416 | Parameters |
@@ -495,7 +496,7 @@ def thin(self, spacing: int=2, offset: int=0) -> "Corr": |
495 | 496 | new_content.append(self.content[t]) |
496 | 497 | return Corr(new_content) |
497 | 498 |
|
498 | | - def correlate(self, partner: Union[Corr, float, Obs]) -> "Corr": |
| 499 | + def correlate(self, partner: Union[Corr, Obs]) -> "Corr": |
499 | 500 | """Correlate the correlator with another correlator or Obs |
500 | 501 |
|
501 | 502 | Parameters |
@@ -577,14 +578,14 @@ def T_symmetry(self, partner: "Corr", parity: int=+1) -> "Corr": |
577 | 578 |
|
578 | 579 | return (self + T_partner) / 2 |
579 | 580 |
|
580 | | - def deriv(self, variant: Optional[str]="symmetric") -> "Corr": |
| 581 | + def deriv(self, variant: Literal["symmetric", "forward", "backward", "improved", "log"]="symmetric") -> "Corr": |
581 | 582 | """Return the first derivative of the correlator with respect to x0. |
582 | 583 |
|
583 | 584 | Parameters |
584 | 585 | ---------- |
585 | 586 | variant : str |
586 | 587 | decides which definition of the finite differences derivative is used. |
587 | | - Available choice: symmetric, forward, backward, improved, log, default: symmetric |
| 588 | + Available choices: symmetric, forward, backward, improved, log, default: symmetric |
588 | 589 | """ |
589 | 590 | if self.N != 1: |
590 | 591 | raise ValueError("deriv only implemented for one-dimensional correlators.") |
@@ -638,7 +639,7 @@ def deriv(self, variant: Optional[str]="symmetric") -> "Corr": |
638 | 639 | else: |
639 | 640 | raise ValueError("Unknown variant.") |
640 | 641 |
|
641 | | - def second_deriv(self, variant: Optional[str]="symmetric") -> "Corr": |
| 642 | + def second_deriv(self, variant: Literal["symmetric", "big_symmetric", "improved", "log"]="symmetric") -> "Corr": |
642 | 643 | r"""Return the second derivative of the correlator with respect to x0. |
643 | 644 |
|
644 | 645 | Parameters |
@@ -698,7 +699,7 @@ def second_deriv(self, variant: Optional[str]="symmetric") -> "Corr": |
698 | 699 | else: |
699 | 700 | raise ValueError("Unknown variant.") |
700 | 701 |
|
701 | | - def m_eff(self, variant: str='log', guess: float=1.0) -> "Corr": |
| 702 | + def m_eff(self, variant: Literal["log", "cosh", "periodic", "sinh", "arccosh", "logsym"]='log', guess: float=1.0) -> "Corr": |
702 | 703 | """Returns the effective mass of the correlator as correlator object |
703 | 704 |
|
704 | 705 | Parameters |
@@ -813,7 +814,7 @@ def fit(self, function: Callable, fitrange: Optional[list[int]]=None, silent: bo |
813 | 814 | result = least_squares(xs, ys, function, silent=silent, **kwargs) |
814 | 815 | return result |
815 | 816 |
|
816 | | - def plateau(self, plateau_range: Optional[list[int]]=None, method: str="fit", auto_gamma: bool=False) -> Obs: |
| 817 | + def plateau(self, plateau_range: Optional[list[int]]=None, method: Literal['fit', 'avg']="fit", auto_gamma: bool=False) -> Obs: |
817 | 818 | """ Extract a plateau value from a Corr object |
818 | 819 |
|
819 | 820 | Parameters |
@@ -862,7 +863,7 @@ def set_prange(self, prange: list[int]): |
862 | 863 | self.prange = prange |
863 | 864 | return |
864 | 865 |
|
865 | | - def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Optional[Obs, float, int]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Optional[int, float]=None, references: Optional[list[float]]=None, title: Optional[str]=None): |
| 866 | + def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Union[Obs, float, int, None]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Union[int, float, None]=None, references: Optional[list[float]]=None, title: Optional[str]=None): |
866 | 867 | """Plots the correlator using the tag of the correlator as label if available. |
867 | 868 |
|
868 | 869 | Parameters |
@@ -1029,11 +1030,8 @@ def dump(self, filename: str, datatype: str="json.gz", **kwargs): |
1029 | 1030 | specifies a custom path for the file (default '.') |
1030 | 1031 | """ |
1031 | 1032 | if datatype == "json.gz": |
1032 | | - from .input.json import dump_to_json |
1033 | | - if 'path' in kwargs: |
1034 | | - file_name = kwargs.get('path') + '/' + filename |
1035 | | - else: |
1036 | | - file_name = filename |
| 1033 | + path = kwargs.get("path", ".") |
| 1034 | + file_name = path + '/' + filename |
1037 | 1035 | dump_to_json(self, file_name) |
1038 | 1036 | elif datatype == "pickle": |
1039 | 1037 | dump_object(self, filename, **kwargs) |
@@ -1078,7 +1076,7 @@ def __str__(self) -> str: |
1078 | 1076 |
|
1079 | 1077 | __array_priority__ = 10000 |
1080 | 1078 |
|
1081 | | - def __eq__(self, y: Any) -> ndarray: |
| 1079 | + def __eq__(self, y: Any) -> ndarray[bool, None]: |
1082 | 1080 | if isinstance(y, Corr): |
1083 | 1081 | comp = np.asarray(y.content, dtype=object) |
1084 | 1082 | else: |
|
0 commit comments