Skip to content

Commit 905db38

Browse files
kitagryhirosassa
andauthored
fix: Disallow Any in generics (#493)
Co-authored-by: hirosassa <hiro.sassa@gmail.com>
1 parent 0cb86e1 commit 905db38

33 files changed

+130
-107
lines changed

gokart/build.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class WorkerProtocol(Protocol):
6262
This protocol is determined by luigi.worker.Worker.
6363
"""
6464

65-
def add(self, task: TaskOnKart) -> bool: ...
65+
def add(self, task: TaskOnKart[Any]) -> bool: ...
6666

6767
def run(self) -> bool: ...
6868

@@ -124,7 +124,7 @@ class TaskDumpConfig:
124124
output_type: TaskDumpOutputType = TaskDumpOutputType.NONE
125125

126126

127-
def process_task_info(task: TaskOnKart, task_dump_config: TaskDumpConfig = TaskDumpConfig()) -> None:
127+
def process_task_info(task: TaskOnKart[Any], task_dump_config: TaskDumpConfig = TaskDumpConfig()) -> None:
128128
match task_dump_config:
129129
case TaskDumpConfig(mode=TaskDumpMode.NONE, output_type=TaskDumpOutputType.NONE):
130130
pass

gokart/config_params.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
35
import luigi
46

57
import gokart
@@ -18,7 +20,7 @@ def __init__(self, config_class: type[luigi.Config], parameter_alias: dict[str,
1820
self._config_class: type[luigi.Config] = config_class
1921
self._parameter_alias: dict[str, str] = parameter_alias if parameter_alias is not None else {}
2022

21-
def __call__(self, task_class: type[gokart.TaskOnKart]) -> type[gokart.TaskOnKart]:
23+
def __call__(self, task_class: type[gokart.TaskOnKart[Any]]) -> type[gokart.TaskOnKart[Any]]:
2224
# wrap task to prevent task name from being changed
2325
@luigi.task._task_wraps(task_class)
2426
class Wrapped(task_class): # type: ignore

gokart/conflict_prevention_lock/task_lock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class TaskLockException(Exception):
2727

2828

2929
class RedisClient:
30-
_instances: dict = {}
30+
_instances: dict[Any, Any] = {}
3131

3232
def __new__(cls, *args, **kwargs):
3333
key = (args, tuple(sorted(kwargs.items())))

gokart/conflict_prevention_lock/task_lock_wrappers.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33
import functools
44
from collections.abc import Callable
55
from logging import getLogger
6-
from typing import Any
6+
from typing import ParamSpec, TypeVar
77

88
from gokart.conflict_prevention_lock.task_lock import TaskLockParams, set_lock_scheduler, set_task_lock
99

1010
logger = getLogger(__name__)
1111

1212

13-
def wrap_dump_with_lock(func: Callable, task_lock_params: TaskLockParams, exist_check: Callable) -> Callable:
13+
P = ParamSpec('P')
14+
R = TypeVar('R')
15+
16+
17+
def wrap_dump_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams, exist_check: Callable[..., bool]) -> Callable[P, R | None]:
1418
"""Redis lock wrapper function for TargetOnKart.dump().
1519
When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existance check.
1620
https://github.com/m3dev/gokart/issues/265
@@ -19,14 +23,15 @@ def wrap_dump_with_lock(func: Callable, task_lock_params: TaskLockParams, exist_
1923
if not task_lock_params.should_task_lock:
2024
return func
2125

22-
def wrapper(*args, **kwargs):
26+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | None:
2327
task_lock = set_task_lock(task_lock_params=task_lock_params)
2428
scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)
2529

2630
try:
2731
logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} locked.')
2832
if not exist_check():
29-
func(*args, **kwargs)
33+
return func(*args, **kwargs)
34+
return None
3035
finally:
3136
logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} released.')
3237
task_lock.release()
@@ -35,7 +40,7 @@ def wrapper(*args, **kwargs):
3540
return wrapper
3641

3742

38-
def wrap_load_with_lock(func: Callable, task_lock_params: TaskLockParams) -> Callable:
43+
def wrap_load_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams) -> Callable[P, R]:
3944
"""Redis lock wrapper function for TargetOnKart.load().
4045
When TargetOnKart.load() is called, redis lock will be locked and released before load().
4146
https://github.com/m3dev/gokart/issues/265
@@ -44,7 +49,7 @@ def wrap_load_with_lock(func: Callable, task_lock_params: TaskLockParams) -> Cal
4449
if not task_lock_params.should_task_lock:
4550
return func
4651

47-
def wrapper(*args, **kwargs):
52+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
4853
task_lock = set_task_lock(task_lock_params=task_lock_params)
4954
scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)
5055

@@ -58,15 +63,15 @@ def wrapper(*args, **kwargs):
5863
return wrapper
5964

6065

61-
def wrap_remove_with_lock(func: Callable, task_lock_params: TaskLockParams) -> Callable:
66+
def wrap_remove_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams) -> Callable[P, R]:
6267
"""Redis lock wrapper function for TargetOnKart.remove().
6368
When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock.
6469
https://github.com/m3dev/gokart/issues/265
6570
"""
6671
if not task_lock_params.should_task_lock:
6772
return func
6873

69-
def wrapper(*args, **kwargs):
74+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
7075
task_lock = set_task_lock(task_lock_params=task_lock_params)
7176
scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)
7277

@@ -86,7 +91,7 @@ def wrapper(*args, **kwargs):
8691
return wrapper
8792

8893

89-
def wrap_run_with_lock(run_func: Callable[[], Any], task_lock_params: TaskLockParams) -> Callable[[], Any]:
94+
def wrap_run_with_lock(run_func: Callable[[], R], task_lock_params: TaskLockParams) -> Callable[[], R]:
9095
@functools.wraps(run_func)
9196
def wrapped():
9297
task_lock = set_task_lock(task_lock_params=task_lock_params)

gokart/gcs_obj_metadata_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _get_patched_obj_metadata(
9595
task_params: dict[str, str] | None = None,
9696
custom_labels: dict[str, str] | None = None,
9797
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
98-
) -> dict | Any:
98+
) -> dict[str, Any] | Any:
9999
# If metadata from response when getting bucket and object information is not dictionary,
100100
# something wrong might be happened, so return original metadata, no patched.
101101
if not isinstance(metadata, dict):

gokart/info.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from logging import getLogger
4+
from typing import Any
45

56
import luigi
67

@@ -11,7 +12,7 @@
1112

1213

1314
def make_tree_info(
14-
task: TaskOnKart,
15+
task: TaskOnKart[Any],
1516
indent: str = '',
1617
last: bool = True,
1718
details: bool = False,
@@ -43,7 +44,7 @@ def make_tree_info(
4344
return make_task_info_as_tree_str(task=task, details=details, abbr=abbr, ignore_task_names=ignore_task_names)
4445

4546

46-
class tree_info(TaskOnKart):
47+
class tree_info(TaskOnKart[Any]):
4748
mode: str = luigi.Parameter(default='', description='This must be in ["simple", "all"].')
4849
output_path: str = luigi.Parameter(default='tree.txt', description='Output file path.')
4950

gokart/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66
from logging import getLogger
7+
from typing import Any
78

89
import luigi
910
import luigi.cmdline
@@ -49,7 +50,7 @@ def _try_tree_info(cmdline_args):
4950

5051
def _try_to_delete_unnecessary_output_file(cmdline_args: list[str]) -> None:
5152
with CmdlineParser.global_instance(cmdline_args) as cp:
52-
task = cp.get_task_obj() # type: gokart.TaskOnKart
53+
task: gokart.TaskOnKart[Any] = cp.get_task_obj()
5354
if task.delete_unnecessary_output_files:
5455
if ObjectStorage.if_object_storage_path(task.workspace_directory):
5556
logger.info('delete-unnecessary-output-files is not support s3/gcs.')

gokart/task.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,11 @@ def input(self) -> FlattenableItems[TargetOnKart]:
143143
def output(self) -> FlattenableItems[TargetOnKart]:
144144
return self.make_target()
145145

146-
def requires(self) -> FlattenableItems[TaskOnKart]:
146+
def requires(self) -> FlattenableItems[TaskOnKart[Any]]:
147147
tasks = self.make_task_instance_dictionary()
148148
return tasks or [] # when tasks is empty dict, then this returns empty list.
149149

150-
def make_task_instance_dictionary(self) -> dict[str, TaskOnKart]:
150+
def make_task_instance_dictionary(self) -> dict[str, TaskOnKart[Any]]:
151151
return {key: var for key, var in vars(self).items() if self.is_task_on_kart(var)}
152152

153153
@staticmethod
@@ -395,7 +395,7 @@ def _to_str_params(task):
395395
dependencies.append(self.get_own_code())
396396
return hashlib.md5(str(dependencies).encode()).hexdigest()
397397

398-
def _get_input_targets(self, target: None | str | TargetOnKart | TaskOnKart | list[TaskOnKart]) -> FlattenableItems[TargetOnKart]:
398+
def _get_input_targets(self, target: None | str | TargetOnKart | TaskOnKart[Any] | list[TaskOnKart[Any]]) -> FlattenableItems[TargetOnKart]:
399399
if target is None:
400400
return self.input()
401401
if isinstance(target, str):
@@ -438,7 +438,7 @@ def get_info(self, only_significant=False):
438438
def _get_task_log_target(self):
439439
return self.make_target(f'log/task_log/{type(self).__name__}.pkl')
440440

441-
def get_task_log(self) -> dict:
441+
def get_task_log(self) -> dict[str, Any]:
442442
target = self._get_task_log_target()
443443
if self.task_log:
444444
return self.task_log
@@ -455,7 +455,7 @@ def _dump_task_log(self):
455455
def _get_task_params_target(self):
456456
return self.make_target(f'log/task_params/{type(self).__name__}.pkl')
457457

458-
def get_task_params(self) -> dict:
458+
def get_task_params(self) -> dict[str, Any]:
459459
target = self._get_task_log_target()
460460
if target.exists():
461461
return cast(dict[Any, Any], self.load(target))

gokart/task_complete_check.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import functools
44
from collections.abc import Callable
55
from logging import getLogger
6+
from typing import Any
67

78
logger = getLogger(__name__)
89

910

10-
def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable) -> Callable:
11+
def task_complete_check_wrapper(run_func: Callable[..., Any], complete_check_func: Callable[..., Any]) -> Callable[..., Any]:
1112
@functools.wraps(run_func)
1213
def wrapper(*args, **kwargs):
1314
if complete_check_func():

gokart/testing/check_if_run_with_empty_data_frame.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
import sys
5+
from typing import Any
56

67
import luigi
78
from luigi.cmdline_parser import CmdlineParser
@@ -14,15 +15,15 @@
1415
test_logger.setLevel(logging.INFO)
1516

1617

17-
class test_run(gokart.TaskOnKart):
18+
class test_run(gokart.TaskOnKart[Any]):
1819
pandas: bool = luigi.BoolParameter()
1920
namespace: str | None = luigi.OptionalParameter(
2021
default=None, description='When task namespace is not defined explicitly, please use "__not_user_specified".'
2122
)
2223

2324

2425
class _TestStatus:
25-
def __init__(self, task: gokart.TaskOnKart) -> None:
26+
def __init__(self, task: gokart.TaskOnKart[Any]) -> None:
2627
self.namespace = task.task_namespace
2728
self.name = type(task).__name__
2829
self.task_id = task.make_unique_id()
@@ -39,14 +40,14 @@ def fail(self) -> bool:
3940
return self.status != 'OK'
4041

4142

42-
def _get_all_tasks(task: gokart.TaskOnKart) -> list[gokart.TaskOnKart]:
43+
def _get_all_tasks(task: gokart.TaskOnKart[Any]) -> list[gokart.TaskOnKart[Any]]:
4344
result = [task]
4445
for o in flatten(task.requires()):
4546
result.extend(_get_all_tasks(o))
4647
return result
4748

4849

49-
def _run_with_test_status(task: gokart.TaskOnKart) -> _TestStatus:
50+
def _run_with_test_status(task: gokart.TaskOnKart[Any]) -> _TestStatus:
5051
test_message = _TestStatus(task)
5152
try:
5253
task.run()

0 commit comments

Comments
 (0)