|
1 | | -# -*- encoding: utf-8 -*- |
2 | | -from typing import List, Optional, Tuple, Union |
| 1 | +from typing import overload |
3 | 2 |
|
4 | 3 | import logging |
5 | 4 |
|
|
16 | 15 |
|
17 | 16 | def convert_if_sparse( |
18 | 17 | y: SUPPORTED_TARGET_TYPES, |
19 | | -) -> Union[np.ndarray, List, pd.DataFrame, pd.Series]: |
| 18 | +) -> np.ndarray | list | pd.DataFrame | pd.Series: |
20 | 19 | """If the labels `y` are sparse, it will convert it to its dense representation |
21 | 20 |
|
22 | 21 | Parameters |
@@ -77,9 +76,9 @@ class InputValidator(BaseEstimator): |
77 | 76 |
|
78 | 77 | def __init__( |
79 | 78 | self, |
80 | | - feat_type: Optional[List[str]] = None, |
| 79 | + feat_type: list[str] | None = None, |
81 | 80 | is_classification: bool = False, |
82 | | - logger_port: Optional[int] = None, |
| 81 | + logger_port: int | None = None, |
83 | 82 | allow_string_features: bool = True, |
84 | 83 | ) -> None: |
85 | 84 | self.feat_type = feat_type |
@@ -108,8 +107,8 @@ def fit( |
108 | 107 | self, |
109 | 108 | X_train: SUPPORTED_FEAT_TYPES, |
110 | 109 | y_train: SUPPORTED_TARGET_TYPES, |
111 | | - X_test: Optional[SUPPORTED_FEAT_TYPES] = None, |
112 | | - y_test: Optional[SUPPORTED_TARGET_TYPES] = None, |
| 110 | + X_test: SUPPORTED_FEAT_TYPES | None = None, |
| 111 | + y_test: SUPPORTED_TARGET_TYPES | None = None, |
113 | 112 | ) -> BaseEstimator: |
114 | 113 | """ |
115 | 114 | Validates and fit a categorical encoder (if needed) to the features, and |
@@ -172,11 +171,27 @@ def fit( |
172 | 171 |
|
173 | 172 | return self |
174 | 173 |
|
| 174 | + @overload |
175 | 175 | def transform( |
176 | 176 | self, |
177 | 177 | X: SUPPORTED_FEAT_TYPES, |
178 | | - y: Optional[Union[List, pd.Series, pd.DataFrame, np.ndarray]] = None, |
179 | | - ) -> Tuple[Union[np.ndarray, pd.DataFrame, spmatrix], Optional[np.ndarray]]: |
| 178 | + y: None = None, |
| 179 | + ) -> tuple[np.ndarray | pd.DataFrame | spmatrix, None]: |
| 180 | + ... |
| 181 | + |
| 182 | + @overload |
| 183 | + def transform( |
| 184 | + self, |
| 185 | + X: SUPPORTED_FEAT_TYPES, |
| 186 | + y: list | pd.Series | pd.DataFrame | np.ndarray, |
| 187 | + ) -> tuple[np.ndarray | pd.DataFrame | spmatrix, np.ndarray]: |
| 188 | + ... |
| 189 | + |
| 190 | + def transform( |
| 191 | + self, |
| 192 | + X: SUPPORTED_FEAT_TYPES, |
| 193 | + y: list | pd.Series | pd.DataFrame | np.ndarray | None = None, |
| 194 | + ) -> tuple[np.ndarray | pd.DataFrame | spmatrix, np.ndarray | None]: |
180 | 195 | """ |
181 | 196 | Transform the given target or features to a numpy array |
182 | 197 |
|
|
0 commit comments