|
11 | 11 | import pickle |
12 | 12 | import sys |
13 | 13 | from typing import Literal |
| 14 | +from typing import Optional |
| 15 | +from typing import Union |
14 | 16 |
|
15 | 17 | import numpy as np |
16 | 18 | from numpy.random import RandomState |
@@ -335,8 +337,8 @@ def train_test_validate( |
335 | 337 | 'mixed-set', 'drug-blind', 'cancer-blind' |
336 | 338 | ]='mixed-set', |
337 | 339 | ratio: tuple[int, int, int]=(8,1,1), |
338 | | - stratify_by: (str | None)=None, |
339 | | - random_state: (int | RandomState | None)=None, |
| 340 | + stratify_by: Optional[str]=None, |
| 341 | + random_state: Optional[Union[int,RandomState]]=None, |
340 | 342 | **kwargs: dict, |
341 | 343 | ) -> Split: |
342 | 344 |
|
@@ -386,7 +388,7 @@ def save(self, path: Path) -> None: |
386 | 388 |
|
387 | 389 | def load( |
388 | 390 | name: str, |
389 | | - local_path: str|Path=Path.cwd(), |
| 391 | + local_path: Union[str,Path]=Path.cwd(), |
390 | 392 | from_pickle:bool=False |
391 | 393 | ) -> Dataset: |
392 | 394 | """ |
@@ -669,8 +671,8 @@ def train_test_validate( |
669 | 671 | 'mixed-set', 'drug-blind', 'cancer-blind' |
670 | 672 | ]='mixed-set', |
671 | 673 | ratio: tuple[int, int, int]=(8,1,1), |
672 | | - stratify_by: (str | None)=None, |
673 | | - random_state: (int | RandomState | None)=None, |
| 674 | + stratify_by: Optional[str]=None, |
| 675 | + random_state: Optional[Union[int,RandomState]]=None, |
674 | 676 | **kwargs: dict, |
675 | 677 | ) -> Split: |
676 | 678 | """ |
@@ -1015,7 +1017,7 @@ def _load_file(file_path: Path) -> pd.DataFrame: |
1015 | 1017 | ) |
1016 | 1018 |
|
1017 | 1019 |
|
1018 | | -def _determine_delimiter(file_path): |
| 1020 | +def _determine_delimiter(file_path: Path) -> str: |
1019 | 1021 | if '.tsv' in file_path.suffixes: |
1020 | 1022 | return '\t' |
1021 | 1023 | else: |
|
0 commit comments