diff --git a/docs/conf.py b/docs/conf.py index eda84116..d0ae7060 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - # https://github.com/sphinx-doc/sphinx/issues/6211 import luigi diff --git a/gokart/build.py b/gokart/build.py index 001f2954..ff43f4e3 100644 --- a/gokart/build.py +++ b/gokart/build.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import enum import logging import sys from dataclasses import dataclass from functools import partial from logging import getLogger -from typing import Literal, Optional, Protocol, TypeVar, cast, overload +from typing import Literal, Protocol, TypeVar, cast, overload import backoff import luigi @@ -62,7 +64,7 @@ def add(self, task: TaskOnKart) -> bool: ... def run(self) -> bool: ... - def __enter__(self) -> 'WorkerProtocol': ... + def __enter__(self) -> WorkerProtocol: ... def __exit__(self, type, value, traceback) -> Literal[False]: ... @@ -162,7 +164,7 @@ def build( task_lock_exception_max_wait_seconds: int = 600, task_dump_config: TaskDumpConfig = TaskDumpConfig(), **env_params, -) -> Optional[T]: +) -> T | None: """ Run gokart task for local interpreter. Sharing the most of its parameters with luigi.build (see https://luigi.readthedocs.io/en/stable/api/luigi.html?highlight=build#luigi.build) diff --git a/gokart/build_process_task_info.py b/gokart/build_process_task_info.py index 62cd318d..7b4cd004 100644 --- a/gokart/build_process_task_info.py +++ b/gokart/build_process_task_info.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io import gokart diff --git a/gokart/config_params.py b/gokart/config_params.py index e5e1bf95..74deb9d2 100644 --- a/gokart/config_params.py +++ b/gokart/config_params.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Type +from __future__ import annotations import luigi @@ -6,7 +6,7 @@ class inherits_config_params: - def __init__(self, config_class: Type[luigi.Config], parameter_alias: Optional[Dict[str, str]] = None): + def __init__(self, config_class: type[luigi.Config], parameter_alias: dict[str, str] | None = None): """ Decorates task to inherit parameter value of `config_class`. @@ -15,10 +15,10 @@ def __init__(self, config_class: Type[luigi.Config], parameter_alias: Optional[D key: config_class's parameter name. value: decorated task's parameter name. """ - self._config_class: Type[luigi.Config] = config_class - self._parameter_alias: Dict[str, str] = parameter_alias if parameter_alias is not None else {} + self._config_class: type[luigi.Config] = config_class + self._parameter_alias: dict[str, str] = parameter_alias if parameter_alias is not None else {} - def __call__(self, task_class: Type[gokart.TaskOnKart]): + def __call__(self, task_class: type[gokart.TaskOnKart]): # wrap task to prevent task name from being changed @luigi.task._task_wraps(task_class) class Wrapped(task_class): # type: ignore @@ -29,6 +29,6 @@ def get_param_values(cls, params, args, kwargs): if hasattr(cls, task_param_key) and task_param_key not in kwargs: kwargs[task_param_key] = param_value - return super(Wrapped, cls).get_param_values(params, args, kwargs) + return super().get_param_values(params, args, kwargs) return Wrapped diff --git a/gokart/conflict_prevention_lock/task_lock.py b/gokart/conflict_prevention_lock/task_lock.py index e67bf535..038f4139 100644 --- a/gokart/conflict_prevention_lock/task_lock.py +++ b/gokart/conflict_prevention_lock/task_lock.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import functools import os from logging import getLogger -from typing import NamedTuple, Optional +from typing import NamedTuple import redis from apscheduler.schedulers.background import BackgroundScheduler @@ -10,9 +12,9 @@ class TaskLockParams(NamedTuple): - redis_host: Optional[str] - redis_port: Optional[int] - redis_timeout: Optional[int] + redis_host: str | None + redis_port: int | None + redis_timeout: int | None redis_key: str should_task_lock: bool raise_task_lock_exception_on_collision: bool @@ -31,10 +33,10 @@ def __new__(cls, *args, **kwargs): if cls not in cls._instances: cls._instances[cls] = {} if key not in cls._instances[cls]: - cls._instances[cls][key] = super(RedisClient, cls).__new__(cls) + cls._instances[cls][key] = super().__new__(cls) return cls._instances[cls][key] - def __init__(self, host: Optional[str], port: Optional[int]) -> None: + def __init__(self, host: str | None, port: int | None) -> None: if not hasattr(self, '_redis_client'): host = host or 'localhost' port = port or 6379 @@ -72,17 +74,17 @@ def set_lock_scheduler(task_lock: redis.lock.Lock, task_lock_params: TaskLockPar return scheduler -def make_task_lock_key(file_path: str, unique_id: Optional[str]): +def make_task_lock_key(file_path: str, unique_id: str | None): basename_without_ext = os.path.splitext(os.path.basename(file_path))[0] return f'{basename_without_ext}_{unique_id}' def make_task_lock_params( file_path: str, - unique_id: Optional[str], - redis_host: Optional[str] = None, - redis_port: Optional[int] = None, - redis_timeout: Optional[int] = None, + unique_id: str | None, + redis_host: str | None = None, + redis_port: int | None = None, + redis_timeout: int | None = None, raise_task_lock_exception_on_collision: bool = False, lock_extend_seconds: int = 10, ) -> TaskLockParams: diff --git a/gokart/conflict_prevention_lock/task_lock_wrappers.py b/gokart/conflict_prevention_lock/task_lock_wrappers.py index cb7c5d1e..b7afa00e 100644 --- a/gokart/conflict_prevention_lock/task_lock_wrappers.py +++ b/gokart/conflict_prevention_lock/task_lock_wrappers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools from logging import getLogger from typing import Any, Callable diff --git a/gokart/file_processor.py b/gokart/file_processor.py index b7b7d2d7..ba468000 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import os import xml.etree.ElementTree as ET from abc import abstractmethod from io import BytesIO from logging import getLogger -from typing import Optional import dill import luigi @@ -20,7 +21,7 @@ logger = getLogger(__name__) -class FileProcessor(object): +class FileProcessor: @abstractmethod def format(self): pass @@ -56,7 +57,7 @@ def dump(self, obj, file): file.write(obj) -class _ChunkedLargeFileReader(object): +class _ChunkedLargeFileReader: def __init__(self, file) -> None: self._file = file @@ -125,7 +126,7 @@ class CsvFileProcessor(FileProcessor): def __init__(self, sep=',', encoding: str = 'utf-8'): self._sep = sep self._encoding = encoding - super(CsvFileProcessor, self).__init__() + super().__init__() def format(self): return TextFormat(encoding=self._encoding) @@ -157,7 +158,7 @@ def dump(self, obj, file): class JsonFileProcessor(FileProcessor): - def __init__(self, orient: Optional[str] = None): + def __init__(self, orient: str | None = None): self._orient = orient def format(self): @@ -209,7 +210,7 @@ class ParquetFileProcessor(FileProcessor): def __init__(self, engine='pyarrow', compression=None): self._engine = engine self._compression = compression - super(ParquetFileProcessor, self).__init__() + super().__init__() def format(self): return luigi.format.Nop @@ -232,7 +233,7 @@ def dump(self, obj, file): class FeatherFileProcessor(FileProcessor): def __init__(self, store_index_in_feather: bool): - super(FeatherFileProcessor, self).__init__() + super().__init__() self._store_index_in_feather = store_index_in_feather self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__' diff --git a/gokart/gcs_config.py b/gokart/gcs_config.py index 4dbb887a..224ca68d 100644 --- a/gokart/gcs_config.py +++ b/gokart/gcs_config.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import json import os -from typing import Optional import luigi import luigi.contrib.gcs @@ -19,7 +20,7 @@ def get_gcs_client(self) -> luigi.contrib.gcs.GCSClient: def _get_gcs_client(self) -> luigi.contrib.gcs.GCSClient: return luigi.contrib.gcs.GCSClient(oauth_credentials=self._load_oauth_credentials()) - def _load_oauth_credentials(self) -> Optional[Credentials]: + def _load_oauth_credentials(self) -> Credentials | None: json_str = os.environ.get(self.gcs_credential_name) if not json_str: return None diff --git a/gokart/gcs_obj_metadata_client.py b/gokart/gcs_obj_metadata_client.py index 9c789cac..5b488693 100644 --- a/gokart/gcs_obj_metadata_client.py +++ b/gokart/gcs_obj_metadata_client.py @@ -3,7 +3,7 @@ import copy import re from logging import getLogger -from typing import Any, Union +from typing import Any from urllib.parse import urlsplit from googleapiclient.model import makepatch @@ -84,7 +84,7 @@ def _get_patched_obj_metadata( metadata: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None, - ) -> Union[dict, Any]: + ) -> dict | Any: # If metadata from response when getting bucket and object information is not dictionary, # something wrong might be happened, so return original metadata, no patched. if not isinstance(metadata, dict): diff --git a/gokart/gcs_zip_client.py b/gokart/gcs_zip_client.py index c777c828..167aec2f 100644 --- a/gokart/gcs_zip_client.py +++ b/gokart/gcs_zip_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil diff --git a/gokart/in_memory/data.py b/gokart/in_memory/data.py index 4430af44c..6deb2b42 100644 --- a/gokart/in_memory/data.py +++ b/gokart/in_memory/data.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from datetime import datetime from typing import Any @@ -9,5 +11,5 @@ class InMemoryData: last_modification_time: datetime @classmethod - def create_data(self, value: Any) -> 'InMemoryData': + def create_data(self, value: Any) -> InMemoryData: return InMemoryData(value=value, last_modification_time=datetime.now()) diff --git a/gokart/in_memory/repository.py b/gokart/in_memory/repository.py index a90f0178..6b66cd8b 100644 --- a/gokart/in_memory/repository.py +++ b/gokart/in_memory/repository.py @@ -1,4 +1,7 @@ -from typing import Any, Iterator +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any from .data import InMemoryData diff --git a/gokart/info.py b/gokart/info.py index 1ed093b0..1f5bd19c 100644 --- a/gokart/info.py +++ b/gokart/info.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from logging import getLogger -from typing import List, Optional, Set import luigi @@ -15,8 +16,8 @@ def make_tree_info( last: bool = True, details: bool = False, abbr: bool = True, - visited_tasks: Optional[Set[str]] = None, - ignore_task_names: Optional[List[str]] = None, + visited_tasks: set[str] | None = None, + ignore_task_names: list[str] | None = None, ) -> str: """ Return a string representation of the tasks, their statuses/parameters in a dependency tree format @@ -32,7 +33,7 @@ def make_tree_info( Whether or not to output details. - abbr: bool Whether or not to simplify tasks information that has already appeared. - - ignore_task_names: Optional[List[str]] + - ignore_task_names: list[str] | None List of task names to ignore. Returns ------- diff --git a/gokart/mypy.py b/gokart/mypy.py index 83a53b55..abe108a6 100644 --- a/gokart/mypy.py +++ b/gokart/mypy.py @@ -7,7 +7,8 @@ from __future__ import annotations import re -from typing import Callable, Final, Iterator, Literal, Optional +from collections.abc import Iterator +from typing import Callable, Final, Literal import luigi from mypy.expandtype import expand_type @@ -233,7 +234,7 @@ def _get_assignment_statements_from_block(self, block: Block) -> Iterator[Assign elif isinstance(stmt, IfStmt): yield from self._get_assignment_statements_from_if_statement(stmt) - def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]: + def collect_attributes(self) -> list[TaskOnKartAttribute] | None: """Collect all attributes declared in the task and its parents. All assignments of the form @@ -360,7 +361,7 @@ def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Exp return True, args return False, {} - def _infer_type_from_parameters(self, parameter: Expression) -> Optional[Type]: + def _infer_type_from_parameters(self, parameter: Expression) -> Type | None: """ Generate default type from Parameter. For example, when parameter is `luigi.parameter.Parameter`, this method should return `str` type. @@ -369,7 +370,7 @@ def _infer_type_from_parameters(self, parameter: Expression) -> Optional[Type]: if parameter_name is None: return None - underlying_type: Optional[Type] = None + underlying_type: Type | None = None if parameter_name in ['luigi.parameter.Parameter', 'luigi.parameter.OptionalParameter']: underlying_type = self._api.named_type('builtins.str', []) elif parameter_name in ['luigi.parameter.IntParameter', 'luigi.parameter.OptionalIntParameter']: @@ -422,7 +423,7 @@ def _infer_type_from_parameters(self, parameter: Expression) -> Optional[Type]: return underlying_type - def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Optional[Type]: + def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Type | None: """ get type from parameter arguments. @@ -452,7 +453,7 @@ def is_parameter_call(expr: Expression) -> bool: return PARAMETER_FULLNAME_MATCHER.match(parameter_name) is not None -def _extract_parameter_name(expr: Expression) -> Optional[str]: +def _extract_parameter_name(expr: Expression) -> str | None: """Extract name if the expression is a call to luigi.Parameter()""" if not isinstance(expr, CallExpr): return None diff --git a/gokart/object_storage.py b/gokart/object_storage.py index 8655708c..0fded339 100644 --- a/gokart/object_storage.py +++ b/gokart/object_storage.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime import luigi @@ -14,7 +16,7 @@ object_storage_path_prefix = ['s3://', 'gs://'] -class ObjectStorage(object): +class ObjectStorage: @staticmethod def if_object_storage_path(path: str) -> bool: for prefix in object_storage_path_prefix: diff --git a/gokart/pandas_type_config.py b/gokart/pandas_type_config.py index 6b5338e4..f760b1cc 100644 --- a/gokart/pandas_type_config.py +++ b/gokart/pandas_type_config.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from abc import abstractmethod from logging import getLogger -from typing import Any, Dict +from typing import Any import luigi import numpy as np @@ -17,7 +19,7 @@ class PandasTypeError(Exception): class PandasTypeConfig(luigi.Config): @classmethod @abstractmethod - def type_dict(cls) -> Dict[str, Any]: + def type_dict(cls) -> dict[str, Any]: pass @classmethod @@ -39,7 +41,7 @@ class PandasTypeConfigMap(luigi.Config): """To initialize this class only once, this inherits luigi.Config.""" def __init__(self, *args, **kwargs) -> None: - super(PandasTypeConfigMap, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) task_names = Register.task_names() task_classes = [Register.get_task_cls(task_name) for task_name in task_names] self._map = { diff --git a/gokart/parameter.py b/gokart/parameter.py index 1a945ca1..47252b9d 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import bz2 import datetime import json diff --git a/gokart/run.py b/gokart/run.py index 4093a093..f19deb9a 100644 --- a/gokart/run.py +++ b/gokart/run.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import os import sys from logging import getLogger -from typing import List, Optional import luigi import luigi.cmdline @@ -40,7 +41,7 @@ def _try_tree_info(cmdline_args): sys.exit() -def _try_to_delete_unnecessary_output_file(cmdline_args: List[str]): +def _try_to_delete_unnecessary_output_file(cmdline_args: list[str]): with CmdlineParser.global_instance(cmdline_args) as cp: task = cp.get_task_obj() # type: gokart.TaskOnKart if task.delete_unnecessary_output_files: @@ -51,7 +52,7 @@ def _try_to_delete_unnecessary_output_file(cmdline_args: List[str]): sys.exit() -def _try_get_slack_api(cmdline_args: List[str]) -> Optional[gokart.slack.SlackAPI]: +def _try_get_slack_api(cmdline_args: list[str]) -> gokart.slack.SlackAPI | None: with CmdlineParser.global_instance(cmdline_args): config = gokart.slack.SlackConfig() token = os.getenv(config.token_name, '') @@ -64,7 +65,7 @@ def _try_get_slack_api(cmdline_args: List[str]) -> Optional[gokart.slack.SlackAP return None -def _try_to_send_event_summary_to_slack(slack_api: Optional[gokart.slack.SlackAPI], event_aggregator: gokart.slack.EventAggregator, cmdline_args: List[str]): +def _try_to_send_event_summary_to_slack(slack_api: gokart.slack.SlackAPI | None, event_aggregator: gokart.slack.EventAggregator, cmdline_args: list[str]): if slack_api is None: # do nothing return diff --git a/gokart/s3_config.py b/gokart/s3_config.py index 766fae8f..fde845e2 100644 --- a/gokart/s3_config.py +++ b/gokart/s3_config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import luigi diff --git a/gokart/s3_zip_client.py b/gokart/s3_zip_client.py index 7e7c4fc0..015e4ce4 100644 --- a/gokart/s3_zip_client.py +++ b/gokart/s3_zip_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil diff --git a/gokart/slack/event_aggregator.py b/gokart/slack/event_aggregator.py index 84c13851..884570ac 100644 --- a/gokart/slack/event_aggregator.py +++ b/gokart/slack/event_aggregator.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import os from logging import getLogger -from typing import List, TypedDict +from typing import TypedDict import luigi @@ -12,10 +14,10 @@ class FailureEvent(TypedDict): exception: str -class EventAggregator(object): +class EventAggregator: def __init__(self) -> None: - self._success_events: List[str] = [] - self._failure_events: List[FailureEvent] = [] + self._success_events: list[str] = [] + self._failure_events: list[FailureEvent] = [] def set_handlers(self): handlers = [(luigi.Event.SUCCESS, self._success), (luigi.Event.FAILURE, self._failure)] diff --git a/gokart/slack/slack_api.py b/gokart/slack/slack_api.py index 41fe7a8e..2f5ab01c 100644 --- a/gokart/slack/slack_api.py +++ b/gokart/slack/slack_api.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from logging import getLogger import slack_sdk @@ -17,7 +19,7 @@ class FileNotUploadedError(RuntimeError): pass -class SlackAPI(object): +class SlackAPI: def __init__(self, token, channel: str, to_user: str) -> None: self._client = slack_sdk.WebClient(token=token) self._channel_id = self._get_channel_id(channel) diff --git a/gokart/slack/slack_config.py b/gokart/slack/slack_config.py index cbc4d251..603327c7 100644 --- a/gokart/slack/slack_config.py +++ b/gokart/slack/slack_config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import luigi diff --git a/gokart/target.py b/gokart/target.py index 33e14dbb..b4b5d3a0 100644 --- a/gokart/target.py +++ b/gokart/target.py @@ -7,7 +7,7 @@ from datetime import datetime from glob import glob from logging import getLogger -from typing import Any, Optional +from typing import Any import luigi import numpy as np @@ -171,7 +171,7 @@ def _make_temporary_directory(self): os.makedirs(self._temporary_directory, exist_ok=True) -class LargeDataFrameProcessor(object): +class LargeDataFrameProcessor: def __init__(self, max_byte: int): self.max_byte = int(max_byte) @@ -195,14 +195,14 @@ def load(file_path: str) -> pd.DataFrame: return pd.concat([pd.read_pickle(file_path) for file_path in glob(os.path.join(dir_path, 'data_*.pkl'))]) -def _make_file_system_target(file_path: str, processor: Optional[FileProcessor] = None, store_index_in_feather: bool = True) -> luigi.target.FileSystemTarget: +def _make_file_system_target(file_path: str, processor: FileProcessor | None = None, store_index_in_feather: bool = True) -> luigi.target.FileSystemTarget: processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather) if ObjectStorage.if_object_storage_path(file_path): return ObjectStorage.get_object_storage_target(file_path, processor.format()) return luigi.LocalTarget(file_path, format=processor.format()) -def _make_file_path(original_path: str, unique_id: Optional[str] = None) -> str: +def _make_file_path(original_path: str, unique_id: str | None = None) -> str: if unique_id is not None: [base, extension] = os.path.splitext(original_path) return base + '_' + unique_id + extension @@ -219,9 +219,9 @@ def _get_last_modification_time(path: str) -> datetime: def make_target( file_path: str, - unique_id: Optional[str] = None, - processor: Optional[FileProcessor] = None, - task_lock_params: Optional[TaskLockParams] = None, + unique_id: str | None = None, + processor: FileProcessor | None = None, + task_lock_params: TaskLockParams | None = None, store_index_in_feather: bool = True, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) @@ -236,8 +236,8 @@ def make_model_target( temporary_directory: str, save_function, load_function, - unique_id: Optional[str] = None, - task_lock_params: Optional[TaskLockParams] = None, + unique_id: str | None = None, + task_lock_params: TaskLockParams | None = None, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) diff --git a/gokart/task.py b/gokart/task.py index 2271f583..5a671c2d 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -6,9 +6,10 @@ import os import random import types +from collections.abc import Generator, Iterable from importlib import import_module from logging import getLogger -from typing import Any, Callable, Dict, Generator, Generic, Iterable, List, Optional, Set, TypeVar, Union, overload +from typing import Any, Callable, Generic, TypeVar, overload import luigi import pandas as pd @@ -82,8 +83,8 @@ class TaskOnKart(luigi.Task, Generic[T]): default=FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER, description='Fix random seed method value.', significant=False ) # FIXME: should fix with OptionalIntParameter after newer luigi (https://github.com/spotify/luigi/pull/3079) will be released - redis_host: Optional[str] = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False) - redis_port: Optional[int] = luigi.OptionalIntParameter( + redis_host: str | None = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False) + redis_port: int | None = luigi.OptionalIntParameter( default=None, description='Task lock check is deactivated, when None.', significant=False, @@ -116,7 +117,7 @@ def __init__(self, *args, **kwargs): # 'This parameter is dumped into "workspace_directory/log/task_log/" when this task finishes with success.' self.task_log = dict() self.task_unique_id = None - super(TaskOnKart, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._rerun_state = self.rerun self._lock_at_dump = True @@ -141,11 +142,11 @@ def input(self) -> FlattenableItems[TargetOnKart]: def output(self) -> FlattenableItems[TargetOnKart]: return self.make_target() - def requires(self) -> FlattenableItems['TaskOnKart']: + def requires(self) -> FlattenableItems[TaskOnKart]: tasks = self.make_task_instance_dictionary() return tasks or [] # when tasks is empty dict, then this returns empty list. - def make_task_instance_dictionary(self) -> Dict[str, 'TaskOnKart']: + def make_task_instance_dictionary(self) -> dict[str, TaskOnKart]: return {key: var for key, var in vars(self).items() if self.is_task_on_kart(var)} @staticmethod @@ -210,7 +211,7 @@ def clone(self, cls=None, **kwargs): return cls(**new_k) - def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart: + def make_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, processor: FileProcessor | None = None) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl') ) @@ -230,7 +231,7 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather ) - def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: + def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip') ) @@ -287,15 +288,15 @@ def make_model_target( ) @overload - def load(self, target: Union[None, str, TargetOnKart] = None) -> Any: ... + def load(self, target: None | str | TargetOnKart = None) -> Any: ... @overload - def load(self, target: 'TaskOnKart[K]') -> K: ... + def load(self, target: TaskOnKart[K]) -> K: ... @overload - def load(self, target: 'List[TaskOnKart[K]]') -> List[K]: ... + def load(self, target: list[TaskOnKart[K]]) -> list[K]: ... - def load(self, target: Union[None, str, TargetOnKart, 'TaskOnKart[K]', 'List[TaskOnKart[K]]'] = None) -> Any: + def load(self, target: None | str | TargetOnKart | TaskOnKart[K] | list[TaskOnKart[K]] = None) -> Any: def _load(targets): if isinstance(targets, list) or isinstance(targets, tuple): return [_load(t) for t in targets] @@ -306,12 +307,12 @@ def _load(targets): return _load(self._get_input_targets(target)) @overload - def load_generator(self, target: Union[None, str, TargetOnKart] = None) -> Generator[Any, None, None]: ... + def load_generator(self, target: None | str | TargetOnKart = None) -> Generator[Any, None, None]: ... @overload - def load_generator(self, target: 'List[TaskOnKart[K]]') -> Generator[K, None, None]: ... + def load_generator(self, target: list[TaskOnKart[K]]) -> Generator[K, None, None]: ... - def load_generator(self, target: Union[None, str, TargetOnKart, 'List[TaskOnKart[K]]'] = None) -> Generator[Any, None, None]: + def load_generator(self, target: None | str | TargetOnKart | list[TaskOnKart[K]] = None) -> Generator[Any, None, None]: def _load(targets): if isinstance(targets, list) or isinstance(targets, tuple): for t in targets: @@ -328,9 +329,9 @@ def _load(targets): def dump(self, obj: T, target: None = None, custom_labels: dict[Any, Any] | None = None) -> None: ... @overload - def dump(self, obj: Any, target: Union[str, TargetOnKart], custom_labels: dict[Any, Any] | None = None) -> None: ... + def dump(self, obj: Any, target: str | TargetOnKart, custom_labels: dict[Any, Any] | None = None) -> None: ... - def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None, custom_labels: dict[str, Any] | None = None) -> None: + def dump(self, obj: Any, target: None | str | TargetOnKart = None, custom_labels: dict[str, Any] | None = None) -> None: PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace) if self.fail_on_empty_dump: if isinstance(obj, pd.DataFrame) and obj.empty: @@ -344,7 +345,7 @@ def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None, custom_l ) @staticmethod - def get_code(target_class) -> Set[str]: + def get_code(target_class) -> set[str]: def has_sourcecode(obj): return inspect.ismethod(obj) or inspect.isfunction(obj) or inspect.isframe(obj) or inspect.iscode(obj) @@ -379,7 +380,7 @@ def _to_str_params(task): dependencies.append(self.get_own_code()) return hashlib.md5(str(dependencies).encode()).hexdigest() - def _get_input_targets(self, target: Union[None, str, TargetOnKart, 'TaskOnKart', 'List[TaskOnKart]']) -> FlattenableItems[TargetOnKart]: + def _get_input_targets(self, target: None | str | TargetOnKart | TaskOnKart | list[TaskOnKart]) -> FlattenableItems[TargetOnKart]: if target is None: return self.input() if isinstance(target, str): @@ -395,7 +396,7 @@ def _get_input_targets(self, target: Union[None, str, TargetOnKart, 'TaskOnKart' return target.output() return target - def _get_output_target(self, target: Union[None, str, TargetOnKart]) -> TargetOnKart: + def _get_output_target(self, target: None | str | TargetOnKart) -> TargetOnKart: if target is None: output = self.output() assert isinstance(output, TargetOnKart), f'output must be TargetOnKart, but {type(output)} is passed.' @@ -422,7 +423,7 @@ def get_info(self, only_significant=False): def _get_task_log_target(self): return self.make_target(f'log/task_log/{type(self).__name__}.pkl') - def get_task_log(self) -> Dict: + def get_task_log(self) -> dict: target = self._get_task_log_target() if self.task_log: return self.task_log @@ -439,7 +440,7 @@ def _dump_task_log(self): def _get_task_params_target(self): return self.make_target(f'log/task_params/{type(self).__name__}.pkl') - def get_task_params(self) -> Dict: + def get_task_params(self) -> dict: target = self._get_task_log_target() if target.exists(): return self.load(target) @@ -456,8 +457,8 @@ def _get_random_seeds_target(self): return self.make_target(f'log/random_seed/{type(self).__name__}.pkl') @staticmethod - def try_set_seed(methods: List[str], random_seed: int) -> List[str]: - success_methods: List[str] = [] + def try_set_seed(methods: list[str], random_seed: int) -> list[str]: + success_methods: list[str] = [] for method_name in methods: try: for i, x in enumerate(method_name.split('.')): diff --git a/gokart/task_complete_check.py b/gokart/task_complete_check.py index 53c9f92d..8ed3c2c7 100644 --- a/gokart/task_complete_check.py +++ b/gokart/task_complete_check.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools from logging import getLogger from typing import Callable diff --git a/gokart/testing/check_if_run_with_empty_data_frame.py b/gokart/testing/check_if_run_with_empty_data_frame.py index fac31344..e17b83f5 100644 --- a/gokart/testing/check_if_run_with_empty_data_frame.py +++ b/gokart/testing/check_if_run_with_empty_data_frame.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import logging import sys -from typing import List, Optional import luigi from luigi.cmdline_parser import CmdlineParser @@ -15,7 +16,7 @@ class test_run(gokart.TaskOnKart): pandas: bool = luigi.BoolParameter() - namespace: Optional[str] = luigi.OptionalParameter( + namespace: str | None = luigi.OptionalParameter( default=None, description='When task namespace is not defined explicitly, please use "__not_user_specified".' ) @@ -26,7 +27,7 @@ def __init__(self, task: gokart.TaskOnKart) -> None: self.name = type(task).__name__ self.task_id = task.make_unique_id() self.status = 'OK' - self.message: Optional[Exception] = None + self.message: Exception | None = None def format(self) -> str: s = f'status={self.status}; namespace={self.namespace}; name={self.name}; id={self.task_id};' @@ -38,7 +39,7 @@ def fail(self) -> bool: return self.status != 'OK' -def _get_all_tasks(task: gokart.TaskOnKart) -> List[gokart.TaskOnKart]: +def _get_all_tasks(task: gokart.TaskOnKart) -> list[gokart.TaskOnKart]: result = [task] for o in flatten(task.requires()): result.extend(_get_all_tasks(o)) @@ -55,7 +56,7 @@ def _run_with_test_status(task: gokart.TaskOnKart): return test_message -def _test_run_with_empty_data_frame(cmdline_args: List[str], test_run_params: test_run): +def _test_run_with_empty_data_frame(cmdline_args: list[str], test_run_params: test_run): from unittest.mock import patch try: @@ -77,7 +78,7 @@ def _test_run_with_empty_data_frame(cmdline_args: List[str], test_run_params: te sys.exit(1) -def try_to_run_test_for_empty_data_frame(cmdline_args: List[str]): +def try_to_run_test_for_empty_data_frame(cmdline_args: list[str]): with CmdlineParser.global_instance(cmdline_args): test_run_params = test_run() diff --git a/gokart/testing/pandas_assert.py b/gokart/testing/pandas_assert.py index 68a21bc6..b240d4c1 100644 --- a/gokart/testing/pandas_assert.py +++ b/gokart/testing/pandas_assert.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pandas as pd diff --git a/gokart/tree/task_info.py b/gokart/tree/task_info.py index 97cfde57..e8b46526 100644 --- a/gokart/tree/task_info.py +++ b/gokart/tree/task_info.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import os -from typing import List, Optional import pandas as pd @@ -8,7 +9,7 @@ from gokart.tree.task_info_formatter import make_task_info_tree, make_tree_info, make_tree_info_table_list -def make_task_info_as_tree_str(task: TaskOnKart, details: bool = False, abbr: bool = True, ignore_task_names: Optional[List[str]] = None): +def make_task_info_as_tree_str(task: TaskOnKart, details: bool = False, abbr: bool = True, ignore_task_names: list[str] | None = None): """ Return a string representation of the tasks, their statuses/parameters in a dependency tree format @@ -20,7 +21,7 @@ def make_task_info_as_tree_str(task: TaskOnKart, details: bool = False, abbr: bo Whether or not to output details. - abbr: bool Whether or not to simplify tasks information that has already appeared. - - ignore_task_names: Optional[List[str]] + - ignore_task_names: list[str] | None List of task names to ignore. Returns ------- @@ -32,14 +33,14 @@ def make_task_info_as_tree_str(task: TaskOnKart, details: bool = False, abbr: bo return result -def make_task_info_as_table(task: TaskOnKart, ignore_task_names: Optional[List[str]] = None): +def make_task_info_as_table(task: TaskOnKart, ignore_task_names: list[str] | None = None): """Return a table containing information about dependent tasks. Parameters ---------- - task: TaskOnKart Root task. - - ignore_task_names: Optional[List[str]] + - ignore_task_names: list[str] | None List of task names to ignore. Returns ------- @@ -53,7 +54,7 @@ def make_task_info_as_table(task: TaskOnKart, ignore_task_names: Optional[List[s return task_info_table -def dump_task_info_table(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: Optional[List[str]] = None): +def dump_task_info_table(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: list[str] | None = None): """Dump a table containing information about dependent tasks. Parameters @@ -64,7 +65,7 @@ def dump_task_info_table(task: TaskOnKart, task_info_dump_path: str, ignore_task Output target file path. Path destination can be `local`, `S3`, or `GCS`. File extension can be any type that gokart file processor accepts, including `csv`, `pickle`, or `txt`. See `TaskOnKart.make_target module ` for details. - - ignore_task_names: Optional[List[str]] + - ignore_task_names: list[str] | None List of task names to ignore. Returns ------- @@ -78,7 +79,7 @@ def dump_task_info_table(task: TaskOnKart, task_info_dump_path: str, ignore_task task_info_target.dump(obj=task_info_table, lock_at_dump=False) -def dump_task_info_tree(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: Optional[List[str]] = None, use_unique_id: bool = True): +def dump_task_info_tree(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: list[str] | None = None, use_unique_id: bool = True): """Dump the task info tree object (TaskInfo) to a pickle file. Parameters @@ -88,7 +89,7 @@ def dump_task_info_tree(task: TaskOnKart, task_info_dump_path: str, ignore_task_ - task_info_dump_path: str Output target file path. Path destination can be `local`, `S3`, or `GCS`. File extension must be '.pkl'. - - ignore_task_names: Optional[List[str]] + - ignore_task_names: list[str] | None List of task names to ignore. - use_unique_id: bool = True Whether to use unique id to dump target file. Default is True. diff --git a/gokart/tree/task_info_formatter.py b/gokart/tree/task_info_formatter.py index a484d742..9ebb596e 100644 --- a/gokart/tree/task_info_formatter.py +++ b/gokart/tree/task_info_formatter.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import typing import warnings from dataclasses import dataclass -from typing import Dict, List, NamedTuple, Optional, Set +from typing import NamedTuple from gokart.task import TaskOnKart from gokart.utils import FlattenableItems, flatten @@ -11,13 +13,13 @@ class TaskInfo: name: str unique_id: str - output_paths: List[str] + output_paths: list[str] params: dict processing_time: str is_complete: str task_log: dict - requires: FlattenableItems['RequiredTask'] - children_task_infos: List['TaskInfo'] + requires: FlattenableItems[RequiredTask] + children_task_infos: list[TaskInfo] def get_task_id(self): return f'{self.name}_{self.unique_id}' @@ -57,14 +59,14 @@ def _make_requires_info(requires): raise TypeError(f'`requires` has unexpected type {type(requires)}. Must be `TaskOnKart`, `Iterarble[TaskOnKart]`, or `Dict[str, TaskOnKart]`') -def make_task_info_tree(task: TaskOnKart, ignore_task_names: Optional[List[str]] = None, cache: Optional[Dict[str, TaskInfo]] = None) -> TaskInfo: +def make_task_info_tree(task: TaskOnKart, ignore_task_names: list[str] | None = None, cache: dict[str, TaskInfo] | None = None) -> TaskInfo: with warnings.catch_warnings(): warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete() method') is_task_complete = task.complete() name = task.__class__.__name__ unique_id = task.make_unique_id() - output_paths: List[str] = [t.path() for t in flatten(task.output())] + output_paths: list[str] = [t.path() for t in flatten(task.output())] cache = {} if cache is None else cache cache_id = f'{name}_{unique_id}_{is_task_complete}' @@ -80,7 +82,7 @@ def make_task_info_tree(task: TaskOnKart, ignore_task_names: Optional[List[str]] requires = _make_requires_info(task.requires()) children = flatten(task.requires()) - children_task_infos: List[TaskInfo] = [] + children_task_infos: list[TaskInfo] = [] for child in children: if ignore_task_names is None or child.__class__.__name__ not in ignore_task_names: children_task_infos.append(make_task_info_tree(child, ignore_task_names=ignore_task_names, cache=cache)) @@ -99,7 +101,7 @@ def make_task_info_tree(task: TaskOnKart, ignore_task_names: Optional[List[str]] return task_info -def make_tree_info(task_info: TaskInfo, indent: str, last: bool, details: bool, abbr: bool, visited_tasks: Set[str]): +def make_tree_info(task_info: TaskInfo, indent: str, last: bool, details: bool, abbr: bool, visited_tasks: set[str]): result = '\n' + indent if last: result += '└─-' @@ -126,7 +128,7 @@ def make_tree_info(task_info: TaskInfo, indent: str, last: bool, details: bool, return result -def make_tree_info_table_list(task_info: TaskInfo, visited_tasks: Set[str]): +def make_tree_info_table_list(task_info: TaskInfo, visited_tasks: set[str]): task_id = task_info.get_task_id() if task_id in visited_tasks: return [] diff --git a/gokart/utils.py b/gokart/utils.py index 9e1dfba8..df8f53fa 100644 --- a/gokart/utils.py +++ b/gokart/utils.py @@ -2,8 +2,9 @@ import os import sys +from collections.abc import Iterable from io import BytesIO -from typing import Any, Iterable, Protocol, TypeVar, Union +from typing import Any, Protocol, TypeVar, Union import dill import luigi @@ -71,7 +72,7 @@ def flatten(targets: FlattenableItems[T]) -> list[T]: return flat -def load_dill_with_pandas_backward_compatibility(file: Union[FileLike, BytesIO]) -> Any: +def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> Any: """Load binary dumped by dill with pandas backward compatibility. pd.read_pickle can load binary dumped in backward pandas version, and also any objects dumped by pickle. It is unclear whether all objects dumped by dill can be loaded by pd.read_pickle, we use dill.load as a fallback. diff --git a/gokart/worker.py b/gokart/worker.py index 6133b087..eb2d1c15 100644 --- a/gokart/worker.py +++ b/gokart/worker.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # @@ -28,6 +27,8 @@ :py:class:`worker` config class. """ +from __future__ import annotations + import collections import collections.abc import contextlib @@ -48,7 +49,8 @@ import threading import time import traceback -from typing import Any, Dict, Generator, List, Literal, Optional, Set, Tuple +from collections.abc import Generator +from typing import Any, Literal import luigi import luigi.scheduler @@ -87,7 +89,7 @@ def _is_external(task: Task) -> bool: return task.run is None or task.run == NotImplemented -def _get_retry_policy_dict(task: Task) -> Dict[str, Any]: +def _get_retry_policy_dict(task: Task) -> dict[str, Any]: return RetryPolicy(task.retry_count, task.disable_hard_timeout, task.disable_window)._asdict() # type: ignore @@ -129,10 +131,10 @@ def __init__( worker_timeout: int = 0, check_unfulfilled_deps: bool = True, check_complete_on_run: bool = False, - task_completion_cache: Optional[Dict[str, Any]] = None, + task_completion_cache: dict[str, Any] | None = None, task_completion_check_at_run: bool = True, ) -> None: - super(TaskProcess, self).__init__() + super().__init__() self.task = task self.worker_id = worker_id self.result_queue = result_queue @@ -148,13 +150,13 @@ def __init__( # completeness check using the cache self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache) - def _run_task(self) -> Optional[collections.abc.Generator]: + def _run_task(self) -> collections.abc.Generator | None: if self.task_completion_check_at_run and self.check_complete(self.task): logger.warning(f'{self.task} is skipped because the task is already completed.') return None return self.task.run() - def _run_get_new_deps(self) -> Optional[List[Tuple[str, str, Dict[str, str]]]]: + def _run_get_new_deps(self) -> list[tuple[str, str, dict[str, str]]] | None: task_gen = self._run_task() if not isinstance(task_gen, collections.abc.Generator): @@ -191,10 +193,10 @@ def run(self) -> None: currentTime = time.time() random.seed(processID * currentTime) - status: Optional[str] = FAILED + status: str | None = FAILED expl = '' - missing: List[str] = [] - new_deps: Optional[List[Tuple[str, str, Dict[str, str]]]] = [] + missing: list[str] = [] + new_deps: list[tuple[str, str, dict[str, str]]] | None = [] try: # Verify that all the tasks are fulfilled! For external tasks we # don't care about unfulfilled dependencies, because we are just @@ -211,7 +213,7 @@ def run(self) -> None: missing.append(dep.task_id) if missing: deps = 'dependency' if len(missing) == 1 else 'dependencies' - raise RuntimeError('Unfulfilled %s at run time: %s' % (deps, ', '.join(missing))) + raise RuntimeError('Unfulfilled {} at run time: {}'.format(deps, ', '.join(missing))) self.task.trigger_event(Event.START, self.task) t0 = time.time() status = None @@ -269,7 +271,7 @@ def _recursive_terminate(self) -> None: children = parent.children(recursive=True) # terminate parent. Give it a chance to clean up - super(TaskProcess, self).terminate() + super().terminate() parent.wait() # terminate children @@ -287,7 +289,7 @@ def terminate(self) -> None: try: return self._recursive_terminate() except ImportError: - return super(TaskProcess, self).terminate() + return super().terminate() @contextlib.contextmanager def _forward_attributes(self): @@ -306,7 +308,7 @@ def _forward_attributes(self): # Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897 class ContextManagedTaskProcess(TaskProcess): def __init__(self, context, *args, **kwargs) -> None: - super(ContextManagedTaskProcess, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.context = context def run(self) -> None: @@ -317,9 +319,9 @@ def run(self) -> None: cls = getattr(module, class_name) with cls(self): - super(ContextManagedTaskProcess, self).run() + super().run() else: - super(ContextManagedTaskProcess, self).run() + super().run() class gokart_worker(luigi.Config): @@ -396,11 +398,11 @@ class Worker: def __init__( self, - scheduler: Optional[Scheduler] = None, - worker_id: Optional[str] = None, + scheduler: Scheduler | None = None, + worker_id: str | None = None, worker_processes: int = 1, assistant: bool = False, - config: Optional[gokart_worker] = None, + config: gokart_worker | None = None, ) -> None: if scheduler is None: scheduler = Scheduler() @@ -423,17 +425,17 @@ def __init__( self._stop_requesting_work = False self.host = socket.gethostname() - self._scheduled_tasks: Dict[str, Task] = {} - self._suspended_tasks: Dict[str, Task] = {} - self._batch_running_tasks: Dict[str, Any] = {} - self._batch_families_sent: Set[str] = set() + self._scheduled_tasks: dict[str, Task] = {} + self._suspended_tasks: dict[str, Task] = {} + self._batch_running_tasks: dict[str, Any] = {} + self._batch_families_sent: set[str] = set() self._first_task = None self.add_succeeded = True self.run_succeeded = True - self.unfulfilled_counts: Dict[str, int] = collections.defaultdict(int) + self.unfulfilled_counts: dict[str, int] = collections.defaultdict(int) # note that ``signal.signal(signal.SIGUSR1, fn)`` only works inside the main execution thread, which is why we # provide the ability to conditionally install the hook. @@ -446,8 +448,8 @@ def __init__( # Keep info about what tasks are running (could be in other processes) self._task_result_queue: multiprocessing.Queue = multiprocessing.Queue() - self._running_tasks: Dict[str, TaskProcess] = {} - self._idle_since: Optional[datetime.datetime] = None + self._running_tasks: dict[str, TaskProcess] = {} + self._idle_since: datetime.datetime | None = None # mp-safe dictionary for caching completation checks across task processes self._task_completion_cache = None @@ -455,8 +457,8 @@ def __init__( self._task_completion_cache = multiprocessing.Manager().dict() # Stuff for execution_summary - self._add_task_history: List[Any] = [] - self._get_work_response_history: List[Any] = [] + self._add_task_history: list[Any] = [] + self._get_work_response_history: list[Any] = [] def _add_task(self, *args, **kwargs): """ @@ -482,7 +484,7 @@ def _add_task(self, *args, **kwargs): logger.info('Informed scheduler that task %s has status %s', task_id, status) - def __enter__(self) -> 'Worker': + def __enter__(self) -> Worker: """ Start the KeepAliveThread. """ @@ -503,10 +505,10 @@ def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]: self._task_result_queue.close() return False # Don't suppress exception - def _generate_worker_info(self) -> List[Tuple[str, Any]]: + def _generate_worker_info(self) -> list[tuple[str, Any]]: # Generate as much info as possible about the worker # Some of these calls might not be available on all OS's - args = [('salt', '%09d' % random.randrange(0, 10_000_000_000)), ('workers', self.worker_processes)] + args = [('salt', f'{random.randrange(0, 10_000_000_000):09d}'), ('workers', self.worker_processes)] try: args += [('host', socket.gethostname())] except BaseException: @@ -527,26 +529,26 @@ def _generate_worker_info(self) -> List[Tuple[str, Any]]: pass return args - def _generate_worker_id(self, worker_info: List[Any]) -> str: - worker_info_str = ', '.join(['{}={}'.format(k, v) for k, v in worker_info]) - return 'Worker({})'.format(worker_info_str) + def _generate_worker_id(self, worker_info: list[Any]) -> str: + worker_info_str = ', '.join([f'{k}={v}' for k, v in worker_info]) + return f'Worker({worker_info_str})' def _validate_task(self, task: Task) -> None: if not isinstance(task, Task): - raise luigi.worker.TaskException('Can not schedule non-task %s' % task) + raise luigi.worker.TaskException(f'Can not schedule non-task {task}') if not task.initialized(): # we can't get the repr of it since it's not initialized... raise luigi.worker.TaskException( - 'Task of class %s not initialized. Did you override __init__ and forget to call super(...).__init__?' % task.__class__.__name__ + f'Task of class {task.__class__.__name__} not initialized. Did you override __init__ and forget to call super(...).__init__?' ) def _log_complete_error(self, task: Task, tb: str) -> None: - log_msg = 'Will not run {task} or any dependencies due to error in complete() method:\n{tb}'.format(task=task, tb=tb) + log_msg = f'Will not run {task} or any dependencies due to error in complete() method:\n{tb}' logger.warning(log_msg) def _log_dependency_error(self, task: Task, tb: str) -> None: - log_msg = 'Will not run {task} or any dependencies due to error in deps() method:\n{tb}'.format(task=task, tb=tb) + log_msg = f'Will not run {task} or any dependencies due to error in deps() method:\n{tb}' logger.warning(log_msg) def _log_unexpected_error(self, task: Task) -> None: @@ -614,10 +616,10 @@ def _email_error(self, task: Task, formatted_traceback: str, subject: str, headl message = notifications.format_task_error(formatted_headline, task, command, formatted_traceback) notifications.send_error_email(formatted_subject, message, task.owner_email) - def _handle_task_load_error(self, exception: Exception, task_ids: List[str]) -> None: + def _handle_task_load_error(self, exception: Exception, task_ids: list[str]) -> None: msg = 'Cannot find task(s) sent by scheduler: {}'.format(','.join(task_ids)) logger.exception(msg) - subject = 'Luigi: {}'.format(msg) + subject = f'Luigi: {msg}' error_message = notifications.wrap_traceback(exception) for task_id in task_ids: self._add_task( @@ -778,13 +780,13 @@ def _validate_dependency(self, dependency: Task) -> None: if isinstance(dependency, Target): raise Exception('requires() can not return Target objects. Wrap it in an ExternalTask class') elif not isinstance(dependency, Task): - raise Exception('requires() must return Task objects but {} is a {}'.format(dependency, type(dependency))) + raise Exception(f'requires() must return Task objects but {dependency} is a {type(dependency)}') def _check_complete_value(self, is_complete: bool) -> None: if is_complete not in (True, False): if isinstance(is_complete, luigi.worker.TracebackWrapper): raise luigi.workerAsyncCompletionException(is_complete.trace) - raise Exception('Return value of Task.complete() must be boolean (was %r)' % is_complete) + raise Exception(f'Return value of Task.complete() must be boolean (was {is_complete!r})') def _add_worker(self) -> None: self._worker_info.append(('first_task', self._first_task)) @@ -803,7 +805,7 @@ def _log_remote_tasks(self, get_work_response: GetWorkResponse) -> None: if get_work_response.n_pending_last_scheduled: logger.debug('There are %i pending tasks last scheduled by this worker', get_work_response.n_pending_last_scheduled) - def _get_work_task_id(self, get_work_response: Dict[str, Any]) -> Optional[str]: + def _get_work_task_id(self, get_work_response: dict[str, Any]) -> str | None: if get_work_response.get('task_id') is not None: return get_work_response['task_id'] elif 'batch_id' in get_work_response: @@ -885,7 +887,7 @@ def _get_work(self) -> GetWorkResponse: def _run_task(self, task_id: str) -> None: if task_id in self._running_tasks: - logger.debug('Got already running task id {} from scheduler, taking a break'.format(task_id)) + logger.debug(f'Got already running task id {task_id} from scheduler, taking a break') next(self._sleeper()) return @@ -928,11 +930,11 @@ def _purge_children(self) -> None: """ for task_id, p in self._running_tasks.items(): if not p.is_alive() and p.exitcode: - error_msg = 'Task {} died unexpectedly with exit code {}'.format(task_id, p.exitcode) + error_msg = f'Task {task_id} died unexpectedly with exit code {p.exitcode}' p.task.trigger_event(Event.PROCESS_FAILURE, p.task, error_msg) elif p.timeout_time is not None and time.time() > float(p.timeout_time) and p.is_alive(): p.terminate() - error_msg = 'Task {} timed out after {} seconds and was terminated.'.format(task_id, p.worker_timeout) + error_msg = f'Task {task_id} timed out after {p.worker_timeout} seconds and was terminated.' p.task.trigger_event(Event.TIMEOUT, p.task, error_msg) else: continue @@ -1107,8 +1109,8 @@ def run(self) -> bool: return self.run_succeeded - def _handle_rpc_message(self, message: Dict[str, Any]) -> None: - logger.info('Worker %s got message %s' % (self._id, message)) + def _handle_rpc_message(self, message: dict[str, Any]) -> None: + logger.info(f'Worker {self._id} got message {message}') # the message is a dict {'name': , 'kwargs': } name = message['name'] @@ -1119,11 +1121,11 @@ def _handle_rpc_message(self, message: Dict[str, Any]) -> None: func = getattr(self, name, None) tpl = (self._id, name) if not callable(func): - logger.error("Worker %s has no function '%s'" % tpl) + logger.error("Worker {} has no function '{}'".format(*tpl)) elif not getattr(func, 'is_rpc_message_callback', False): - logger.error("Worker %s function '%s' is not available as rpc message callback" % tpl) + logger.error("Worker {} function '{}' is not available as rpc message callback".format(*tpl)) else: - logger.info("Worker %s successfully dispatched rpc message to function '%s'" % tpl) + logger.info("Worker {} successfully dispatched rpc message to function '{}'".format(*tpl)) func(**kwargs) @luigi.worker.rpc_message_callback diff --git a/gokart/workspace_management.py b/gokart/workspace_management.py index 0729d754..af15e2ec 100644 --- a/gokart/workspace_management.py +++ b/gokart/workspace_management.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import os import pathlib diff --git a/gokart/zip_client.py b/gokart/zip_client.py index 0fcecac9..d154252a 100644 --- a/gokart/zip_client.py +++ b/gokart/zip_client.py @@ -1,17 +1,19 @@ +from __future__ import annotations + import os import shutil import zipfile from abc import abstractmethod -from typing import IO, Union +from typing import IO -def _unzip_file(fp: Union[str, IO, os.PathLike], extract_dir: str) -> None: +def _unzip_file(fp: str | IO | os.PathLike, extract_dir: str) -> None: zip_file = zipfile.ZipFile(fp) zip_file.extractall(extract_dir) zip_file.close() -class ZipClient(object): +class ZipClient: @abstractmethod def exists(self) -> bool: pass diff --git a/gokart/zip_client_util.py b/gokart/zip_client_util.py index 1c3a207d..61180e07 100644 --- a/gokart/zip_client_util.py +++ b/gokart/zip_client_util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from gokart.object_storage import ObjectStorage from gokart.zip_client import LocalZipClient, ZipClient diff --git a/pyproject.toml b/pyproject.toml index 8afd9235..92c6f335 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,8 @@ exclude = ["venv/*", "tox/*", "examples/*"] # All the rules are listed on https://docs.astral.sh/ruff/rules/ extend-select = [ "B", # bugbear - "I" # isort + "I", # isort + "UP", # pyupgrade, upgrade syntax for newer versions of the language. ] # B006: Do not use mutable data structures for argument defaults. They are created during function definition time. All calls to the function reuse this one instance of that data structure, persisting changes between them. diff --git a/test/test_build.py b/test/test_build.py index 3cca77cb..981bc356 100644 --- a/test/test_build.py +++ b/test/test_build.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import io import logging import os import sys import unittest from copy import copy -from typing import Dict if sys.version_info >= (3, 11): from typing import assert_type @@ -30,7 +31,7 @@ def run(self): self.dump(self.param) -class _DummyTaskTwoOutputs(gokart.TaskOnKart[Dict[str, str]]): +class _DummyTaskTwoOutputs(gokart.TaskOnKart[dict[str, str]]): task_namespace = __name__ param1: str = luigi.Parameter() param2: str = luigi.Parameter() @@ -114,7 +115,7 @@ def test_build_dict_outputs(self): 'out2': 'test2', } output = gokart.build(_DummyTaskTwoOutputs(param1=param_dict['out1'], param2=param_dict['out2']), reset_register=False) - assert_type(output, Dict[str, str]) + assert_type(output, dict[str, str]) self.assertEqual(output, param_dict) def test_failed_task(self): diff --git a/test/test_file_processor.py b/test/test_file_processor.py index 38545a3f..7832dd6e 100644 --- a/test/test_file_processor.py +++ b/test/test_file_processor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import tempfile import unittest diff --git a/test/test_gcs_obj_metadata_client.py b/test/test_gcs_obj_metadata_client.py index 9e395edc..2fe37f01 100644 --- a/test/test_gcs_obj_metadata_client.py +++ b/test/test_gcs_obj_metadata_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import unittest from typing import Any diff --git a/test/test_pandas_type_check_framework.py b/test/test_pandas_type_check_framework.py index 404e12a2..205c1362 100644 --- a/test/test_pandas_type_check_framework.py +++ b/test/test_pandas_type_check_framework.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import logging import unittest from logging import getLogger -from typing import Any, Dict +from typing import Any from unittest.mock import patch import luigi @@ -19,7 +21,7 @@ class TestPandasTypeConfig(PandasTypeConfig): task_namespace = 'test_pandas_type_check_framework' @classmethod - def type_dict(cls) -> Dict[str, Any]: + def type_dict(cls) -> dict[str, Any]: return {'system_cd': int} diff --git a/test/test_pandas_type_config.py b/test/test_pandas_type_config.py index 34682ecb..d1568e6d 100644 --- a/test/test_pandas_type_config.py +++ b/test/test_pandas_type_config.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from datetime import date, datetime -from typing import Any, Dict +from typing import Any from unittest import TestCase import numpy as np @@ -11,7 +13,7 @@ class _DummyPandasTypeConfig(PandasTypeConfig): @classmethod - def type_dict(cls) -> Dict[str, Any]: + def type_dict(cls) -> dict[str, Any]: return {'int_column': int, 'datetime_column': datetime, 'array_column': np.ndarray} diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index 6e2ccff6..19a845f0 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import os import pathlib import unittest from datetime import datetime -from typing import Any, Dict, List, cast +from typing import Any, cast from unittest.mock import Mock, patch import luigi @@ -355,7 +357,7 @@ def test_load_with_task_on_kart_list(self): # task2 should be in requires' return values task.requires = lambda: {'tasks': [task2, task3]} # type: ignore - load_args: List[gokart.TaskOnKart[int]] = [task2, task3] + load_args: list[gokart.TaskOnKart[int]] = [task2, task3] actual = task.load(load_args) self.assertEqual(actual, [1, 2]) @@ -393,7 +395,7 @@ def test_load_generator_with_list_task_on_kart(self): # task2 should be in requires' return values task.requires = lambda: {'tasks': [task2, task3]} # type: ignore - load_args: List[gokart.TaskOnKart[int]] = [task2, task3] + load_args: list[gokart.TaskOnKart[int]] = [task2, task3] actual = [x for x in task.load_generator(load_args)] self.assertEqual(actual, [1, 2]) @@ -420,7 +422,7 @@ def test_fail_on_empty_dump(self): @patch('luigi.configuration.get_config') def test_add_configuration(self, mock_config: Mock): mock_config.return_value = {'_DummyTask': {'list_param': '["c", "d"]', 'param': '3', 'bool_param': 'True'}} - kwargs: Dict[str, Any] = dict() + kwargs: dict[str, Any] = dict() _DummyTask._add_configuration(kwargs, '_DummyTask') self.assertEqual(3, kwargs['param']) self.assertEqual(['c', 'd'], list(kwargs['list_param'])) diff --git a/test/tree/test_task_info.py b/test/tree/test_task_info.py index 4a6c5ad2..c5a1808e 100644 --- a/test/tree/test_task_info.py +++ b/test/tree/test_task_info.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from typing import Any from unittest.mock import patch