From 8e19fb0fe58b6538504340303a3e21e6455c3260 Mon Sep 17 00:00:00 2001 From: linuxdaemon Date: Sun, 14 Apr 2024 18:57:48 +0000 Subject: [PATCH 1/3] Migrate cryptocurrency.py to use pydantic models --- .gitignore | 1 + .pre-commit-config.yaml | 6 + .vscode/settings.json | 3 +- cloudbot/event.py | 4 +- cloudbot/hook.py | 68 ++- cloudbot/util/colors.py | 2 +- cloudbot/util/func_utils.py | 10 +- cloudbot/util/web.py | 93 ++-- plugins/cryptocurrency.py | 550 ++++++---------------- plugins/duckhunt.py | 11 +- plugins/pastebins/sprunge.py | 4 +- pyproject.toml | 25 + requirements.txt | 3 + tests/plugin_tests/regex_chans_test.py | 18 +- tests/plugin_tests/test_cryptocurrency.py | 114 +---- tests/plugin_tests/test_pager_commands.py | 4 +- tests/util/__init__.py | 2 +- tests/util/mock_conn.py | 34 +- 18 files changed, 377 insertions(+), 575 deletions(-) diff --git a/.gitignore b/.gitignore index 0a8cb56f7..7411425e7 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ config.json .vagrant/ .mypy_cache/ .DS_Store +*.bak diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 20536fa52..b8628d50b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,6 +55,12 @@ repos: args: - "--py38-plus" +- repo: https://github.com/PyCQA/autoflake + rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1 + hooks: + - id: autoflake + + - repo: local hooks: - id: mypy diff --git a/.vscode/settings.json b/.vscode/settings.json index 5dc5c5728..3e98f2820 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,5 +14,6 @@ "evenBetterToml.formatter.allowedBlankLines": 1, "evenBetterToml.formatter.arrayAutoCollapse": true, "evenBetterToml.formatter.arrayAutoExpand": false, - "evenBetterToml.formatter.arrayTrailingComma": true + "evenBetterToml.formatter.arrayTrailingComma": true, + "python.analysis.diagnosticMode": "workspace" } diff --git a/cloudbot/event.py b/cloudbot/event.py index b024cf2c4..0cd974508 100644 --- a/cloudbot/event.py +++ b/cloudbot/event.py @@ -2,7 +2,7 @@ import enum import logging from functools import partial -from typing import Any, Iterator, Mapping +from typing import Any, Iterator, Mapping, Optional from irclib.parser import Message @@ -245,7 +245,7 @@ def admin_log(self, message, broadcast=False): if conn and conn.connected: conn.admin_log(message, console=not broadcast) - def reply(self, *messages, target=None): + def reply(self, *messages: str, target: Optional[str] = None) -> None: """sends a message to the current channel/user with a prefix""" reply_ping = self.conn.config.get("reply_ping", True) if target is None: diff --git a/cloudbot/hook.py b/cloudbot/hook.py index 8beec9207..8aea4f605 100644 --- a/cloudbot/hook.py +++ b/cloudbot/hook.py @@ -4,6 +4,18 @@ import re import warnings from enum import Enum, IntEnum, unique +from typing import ( + Any, + Callable, + List, + Optional, + Sequence, + TypeVar, + Union, + overload, +) + +from typing_extensions import ParamSpec from cloudbot.event import EventType from cloudbot.util import HOOK_ATTR @@ -186,10 +198,36 @@ def _hook_warn(): ) -def command(*args, **kwargs): +_T = TypeVar("_T") +_P = ParamSpec("_P") +_Func = Callable[_P, _T] + + +@overload +def command(arg: Callable[_P, _T], /) -> Callable[_P, _T]: ... + + +@overload +def command( + arg: Optional[Union[str, Sequence[str]]] = None, + /, + *args: Union[str, Sequence[str]], + **kwargs: Any, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... + + +def command( + arg: Optional[Union[Callable[_P, _T], str, Sequence[str]]] = None, + /, + *args: Union[str, Sequence[str]], + **kwargs: Any, +) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]: """External command decorator. Can be used directly as a decorator, or with args to return a decorator.""" - def _command_hook(func, alias_param=None): + def _command_hook( + func: Callable[_P, _T], + alias_param: Optional[Sequence[Union[Sequence[str], str]]] = None, + ) -> Callable[_P, _T]: hook = _get_hook(func, "command") if hook is None: hook = _CommandHook(func) @@ -198,13 +236,17 @@ def _command_hook(func, alias_param=None): hook.add_hook(alias_param, kwargs) return func - if len(args) == 1 and callable(args[0]): + if arg is not None and not isinstance(arg, (str, collections.abc.Sequence)): # this decorator is being used directly _hook_warn() - return _command_hook(args[0]) + return _command_hook(arg) + + arg_list: List[Union[str, Sequence[str]]] = list(args) + if arg: + arg_list.insert(0, arg) # this decorator is being used indirectly, so return a decorator function - return lambda func: _command_hook(func, alias_param=args) + return lambda func: _command_hook(func, alias_param=arg_list) def irc_raw(triggers_param, **kwargs): @@ -332,10 +374,22 @@ def _config_hook(func): return _config_hook -def on_start(param=None, **kwargs): +@overload +def on_start( + **kwargs: Any, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... + + +@overload +def on_start(param: Callable[_P, _T], /) -> Callable[_P, _T]: ... + + +def on_start( + param: Optional[Callable[_P, _T]] = None, /, **kwargs: Any +) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]: """External on_start decorator. Can be used directly as a decorator, or with args to return a decorator""" - def _on_start_hook(func): + def _on_start_hook(func: Callable[_P, _T]) -> Callable[_P, _T]: hook = _get_hook(func, "on_start") if hook is None: hook = _Hook(func, "on_start") diff --git a/cloudbot/util/colors.py b/cloudbot/util/colors.py index 360763f50..43a0d949a 100644 --- a/cloudbot/util/colors.py +++ b/cloudbot/util/colors.py @@ -154,7 +154,7 @@ def get_available_colours(): return ret[:-2] -def parse(string): +def parse(string: str) -> str: """ parse: Formats a string, replacing words wrapped in $( ) with actual colours or formatting. example: diff --git a/cloudbot/util/func_utils.py b/cloudbot/util/func_utils.py index 66dc368d9..9e1c2a32d 100644 --- a/cloudbot/util/func_utils.py +++ b/cloudbot/util/func_utils.py @@ -1,4 +1,5 @@ import inspect +from typing import Any, Callable, Mapping, TypeVar class ParameterError(Exception): @@ -12,7 +13,14 @@ def __init__(self, name, valid_args): self.valid_args = list(valid_args) -def call_with_args(func, arg_data): +_T = TypeVar("_T") + + +def call_with_args(func: Callable[..., _T], arg_data: Mapping[str, Any]) -> _T: + """ + >>> call_with_args(lambda a: a, {'a':1, 'b':2}) + 1 + """ sig = inspect.signature(func, follow_wrapped=False) try: args = [ diff --git a/cloudbot/util/web.py b/cloudbot/util/web.py index 99eab8c47..c9387cea1 100644 --- a/cloudbot/util/web.py +++ b/cloudbot/util/web.py @@ -17,7 +17,16 @@ import logging import time from operator import attrgetter -from typing import Dict, Optional, Union +from typing import ( + Dict, + Generic, + Iterable, + Iterator, + Optional, + Tuple, + TypeVar, + Union, +) import requests from requests import ( @@ -41,51 +50,54 @@ # Public API +_T = TypeVar("_T") -class Registry: - class Item: - def __init__(self, item): - self.item = item - self.working = True - self.last_check = 0.0 - self.uses = 0 - def failed(self): - self.working = False - self.last_check = time.time() +class RegistryItem(Generic[_T]): + def __init__(self, item: _T) -> None: + self.item = item + self.working = True + self.last_check = 0.0 + self.uses = 0 + + def failed(self) -> None: + self.working = False + self.last_check = time.time() - @property - def should_use(self): - if self.working: - return True + @property + def should_use(self) -> bool: + if self.working: + return True + + if (time.time() - self.last_check) > (5 * 60): + # It's been 5 minutes, try again + self.working = True + return True - if (time.time() - self.last_check) > (5 * 60): - # It's been 5 minutes, try again - self.working = True - return True + return False - return False +class Registry(Generic[_T]): def __init__(self): - self._items: Dict[str, "Registry.Item"] = {} + self._items: Dict[str, RegistryItem[_T]] = {} - def register(self, name, item): + def register(self, name: str, item: _T) -> None: if name in self._items: raise ValueError("Attempt to register duplicate item") - self._items[name] = self.Item(item) + self._items[name] = RegistryItem(item) - def get(self, name): + def get(self, name: str) -> Optional[_T]: val = self._items.get(name) if val: return val.item - return val + return None - def get_item(self, name): + def get_item(self, name: str) -> Optional[RegistryItem[_T]]: return self._items.get(name) - def get_working(self) -> Optional["Item"]: + def get_working(self) -> Optional[RegistryItem[_T]]: working = [item for item in self._items.values() if item.should_use] if not working: @@ -93,24 +105,24 @@ def get_working(self) -> Optional["Item"]: return min(working, key=attrgetter("uses")) - def remove(self, name): + def remove(self, name: str) -> None: del self._items[name] - def items(self): + def items(self) -> Iterable[Tuple[str, RegistryItem[_T]]]: return self._items.items() - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._items) - def __getitem__(self, item): + def __getitem__(self, item: str) -> _T: return self._items[item].item - def set_working(self): + def set_working(self) -> None: for item in self._items.values(): item.working = True -def shorten(url, custom=None, key=None, service=DEFAULT_SHORTENER): +def shorten(url: str, custom=None, key=None, service=DEFAULT_SHORTENER): impl = shorteners[service] return impl.shorten(url, custom, key) @@ -140,7 +152,12 @@ class NoPasteException(Exception): """No pastebins succeeded""" -def paste(data, ext="txt", service=DEFAULT_PASTEBIN, raise_on_no_paste=False): +def paste( + data: Union[str, bytes], + ext="txt", + service=DEFAULT_PASTEBIN, + raise_on_no_paste=False, +) -> str: if service: impl = pastebins.get_item(service) else: @@ -218,12 +235,12 @@ class Pastebin: def __init__(self): pass - def paste(self, data, ext): + def paste(self, data, ext) -> str: raise NotImplementedError -shorteners = Registry() -pastebins = Registry() +shorteners = Registry[Shortener]() +pastebins = Registry[Pastebin]() # Internal Implementations @@ -346,7 +363,7 @@ def __init__(self, base_url): super().__init__() self.url = base_url - def paste(self, data, ext): + def paste(self, data, ext) -> str: if isinstance(data, str): encoded = data.encode() else: diff --git a/plugins/cryptocurrency.py b/plugins/cryptocurrency.py index a4f9b5c10..5bf41f228 100644 --- a/plugins/cryptocurrency.py +++ b/plugins/cryptocurrency.py @@ -10,284 +10,182 @@ GPL v3 """ -import inspect import time -import typing -import warnings from decimal import Decimal -from numbers import Real from operator import itemgetter from threading import RLock -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Generic, + List, + Optional, + Set, + Type, + TypeVar, + Union, + cast, +) import requests +from pydantic import BaseModel, Field, computed_field from requests import Response +from typing_extensions import Self from yarl import URL from cloudbot import hook +from cloudbot.bot import AbstractBot +from cloudbot.event import CommandEvent from cloudbot.util import colors, web from cloudbot.util.func_utils import call_with_args class APIError(Exception): - def __init__(self, msg: str): + def __init__(self, msg: str) -> None: super().__init__(msg) self.msg = msg class UnknownSymbolError(APIError): - def __init__(self, name: str): + def __init__(self, name: str) -> None: super().__init__(name) self.name = name class UnknownFiatCurrencyError(APIError): - def __init__(self, name: str): + def __init__(self, name: str) -> None: super().__init__(name) self.name = name class APIResponse: def __init__( - self, api, data: "UntypedResponse", response: Response + self, + api: "CoinMarketCapAPI", + data: "UntypedResponse", + response: Response, ) -> None: self.api = api self.data = data self.response = response @classmethod - def from_response(cls, api: "CoinMarketCapAPI", response: Response): - return cls(api, read_data(response.json(), UntypedResponse), response) - - -class SchemaField: - empty = object() - - def __init__(self, name: str, field_type: Type, default=empty): - self.name = name - self.field_type = field_type - self.default = default - - -def _get_fields(init_func): - signature = inspect.signature(init_func) - for parameter in signature.parameters.values(): - if parameter.annotation is parameter.empty: - continue - - if parameter.default is parameter.empty: - default = SchemaField.empty - else: - default = parameter.default - - yield SchemaField(parameter.name, parameter.annotation, default) - - -class SchemaMeta(type): - def __new__(cls, name, bases, members): - if members.setdefault("_abstract", False): - super_fields = () - for base in bases: - if not getattr(base, "_abstract", False) and isinstance( - base, cls - ): - super_fields = getattr(base, "_fields") - break - - members["_fields"] = super_fields - else: - members["_fields"] = tuple(_get_fields(members["__init__"])) - - return type.__new__(cls, name, bases, members) - - -T = TypeVar("T", bound="Schema") - - -class Schema(metaclass=SchemaMeta): - # noinspection PyUnusedName - _abstract = True - _fields = () + def from_response(cls, api: "CoinMarketCapAPI", response: Response) -> Self: + return cls( + api, UntypedResponse.model_validate(response.json()), response + ) - def __init__(self, **kwargs): - self.unknown_fields = {} - self.unknown_fields.update(kwargs) - def cast_to(self, new_type: Type[T]) -> T: - return read_data(serialize(self), new_type) +_ModelT = TypeVar("_ModelT", bound="ApiModel") -class ResponseStatus(Schema): - def __init__( - self, - timestamp: str, - error_code: int, - elapsed: int, - credit_count: int, - error_message: str = None, - notice: str = None, - ): - super().__init__() - self.timestamp = timestamp - self.error_code = error_code - self.error_message = error_message - self.elapsed = elapsed - self.credit_count = credit_count - self.notice = notice - - -class APIRequestResponse(Schema): - def __init__(self, status: ResponseStatus): - super().__init__() - self.status = status - - -class UntypedResponse(APIRequestResponse): - def __init__(self, status: ResponseStatus, data: Any = None): - super().__init__(status) - self.data = data - - -class Platform(Schema): - # noinspection PyShadowingBuiltins - def __init__( - self, id: int, name: str, symbol: str, slug: str, token_address: str - ): - super().__init__() - self.id = id - self.name = name - self.symbol = symbol - self.slug = slug - self.token_address = token_address +class ApiModel(BaseModel, extra="forbid"): + def cast_to(self, new_type: Type[_ModelT]) -> _ModelT: + return new_type.model_validate( + self.model_dump(mode="json", by_alias=True) + ) -class Quote(Schema): - def __init__( - self, - price: Real, - volume_24h: Real, - market_cap: Real, - percent_change_1h: Real, - percent_change_24h: Real, - percent_change_7d: Real, - last_updated: str, - volume_24h_reported: Real = None, - volume_7d: Real = None, - volume_7d_reported: Real = None, - volume_30d: Real = None, - volume_30d_reported: Real = None, - ): - super().__init__() - self.price = price - self.volume_24h = volume_24h - self.volume_24h_reported = volume_24h_reported - self.volume_7d = volume_7d - self.volume_7d_reported = volume_7d_reported - self.volume_30d = volume_30d - self.volume_30d_reported = volume_30d_reported - self.market_cap = market_cap - self.percent_change_1h = percent_change_1h - self.percent_change_24h = percent_change_24h - self.percent_change_7d = percent_change_7d - self.last_updated = last_updated - - -class CryptoCurrency(Schema): - # noinspection PyShadowingBuiltins - def __init__( - self, - id: int, - name: str, - symbol: str, - slug: str, - circulating_supply: Real, - total_supply: Real, - date_added: str, - num_market_pairs: int, - cmc_rank: int, - last_updated: str, - tags: List[str], - quote: Dict[str, Quote], - max_supply: Real = None, - market_cap_by_total_supply: Real = None, - platform: Platform = None, - ): - super().__init__() - self.id = id - self.name = name - self.symbol = symbol - self.slug = slug - self.circulating_supply = circulating_supply - self.total_supply = total_supply - self.max_supply = max_supply - self.market_cap_by_total_supply = market_cap_by_total_supply - self.date_added = date_added - self.num_market_pairs = num_market_pairs - self.cmc_rank = cmc_rank - self.last_updated = last_updated - self.tags = tags - self.platform = platform - self.quote = quote +class ResponseStatus(ApiModel): + timestamp: str + error_code: int + elapsed: int + credit_count: int + error_message: Optional[str] = None + notice: Optional[str] = None + + +class APIRequestResponse(ApiModel): + status: ResponseStatus + + +class UntypedResponse(APIRequestResponse, extra="allow"): + data: Any = None + + +class Platform(ApiModel): + platform_id: int = Field(alias="id") + name: str + symbol: str + slug: str + token_address: str + + +class Quote(ApiModel): + price: float + volume_24h: float + market_cap: float + percent_change_1h: Union[int, float] + percent_change_24h: Union[int, float] + percent_change_7d: Union[int, float] + last_updated: str + volume_24h_reported: Optional[float] = None + volume_7d: Optional[float] = None + volume_7d_reported: Optional[float] = None + volume_30d: Optional[float] = None + volume_30d_reported: Optional[float] = None + + +class CryptoCurrency(ApiModel): + currency_id: int = Field(alias="id") + name: str + symbol: str + slug: str + circulating_supply: float + total_supply: float + date_added: str + num_market_pairs: int + cmc_rank: int + last_updated: str + tags: List[str] + quote: Dict[str, Quote] + max_supply: Optional[float] = None + market_cap_by_total_supply: Optional[float] = None + platform: Optional[Platform] = None class QuoteRequestResponse(APIRequestResponse): - def __init__(self, data: Dict[str, CryptoCurrency], status: ResponseStatus): - super().__init__(status) - self.data = data + data: Dict[str, CryptoCurrency] -class FiatCurrency(Schema): - # noinspection PyShadowingBuiltins - def __init__(self, id: int, name: str, sign: str, symbol: str): - super().__init__() - self.id = id - self.name = name - self.sign = sign - self.symbol = symbol +class FiatCurrency(ApiModel): + id: int + name: str + sign: str + symbol: str class FiatCurrencyMap(APIRequestResponse): - def __init__(self, data: List[FiatCurrency], status: ResponseStatus): - super().__init__(status) - self.data = data + data: List[FiatCurrency] - self.symbols = { - currency.symbol: currency.sign for currency in self.data - } + @computed_field # type: ignore[misc] + @property + def symbols(self) -> Dict[str, str]: + return {currency.symbol: currency.sign for currency in self.data} -class CryptoCurrencyEntry(Schema): - # noinspection PyShadowingBuiltins - def __init__( - self, - id: int, - name: str, - symbol: str, - slug: str, - is_active: int, - first_historical_data: str = None, - last_historical_data: str = None, - platform: Platform = None, - status: str = None, - ) -> None: - super().__init__() - self.id = id - self.name = name - self.symbol = symbol - self.slug = slug - self.is_active = is_active - self.status = status - self.first_historical_data = first_historical_data - self.last_historical_data = last_historical_data - self.platform = platform +class CryptoCurrencyEntry(ApiModel): + id: int + name: str + symbol: str + slug: str + is_active: int + first_historical_data: Optional[str] = None + last_historical_data: Optional[str] = None + platform: Optional[Platform] = None + status: Optional[str] = None class CryptoCurrencyMap(APIRequestResponse): - def __init__(self, data: List[CryptoCurrencyEntry], status: ResponseStatus): - super().__init__(status) - self.data = data + data: List[CryptoCurrencyEntry] + status: ResponseStatus - self.names = {currency.symbol for currency in self.data} + @computed_field # type: ignore[misc] + @property + def names(self) -> Set[str]: + return {currency.symbol for currency in self.data} BAD_FIELD_TYPE_MSG = ( @@ -295,173 +193,31 @@ def __init__(self, data: List[CryptoCurrencyEntry], status: ResponseStatus): ) -def sentinel(name: str): - try: - storage = getattr(sentinel, "_sentinels") - except AttributeError: - storage = {} - setattr(sentinel, "_sentinels", storage) - - try: - return storage[name] - except KeyError: - storage[name] = obj = object() - return obj - - -_unset = sentinel("unset") - - -class TypeAssertError(TypeError): - def __init__(self, obj, cls): - super().__init__() - self.cls = cls - self.obj = obj - - -class MissingSchemaField(KeyError): - pass - - -class ParseError(ValueError): - pass - - -def _assert_type(obj, cls, display_cls=_unset): - if display_cls is _unset: - display_cls = cls - - if not isinstance(obj, cls): - raise TypeAssertError(obj, display_cls) - - -def _hydrate_object(_value, _cls): - if _cls is Any: - return _value - - if isinstance(_cls, type) and issubclass(_cls, Schema): - _assert_type(_value, dict) - return read_data(_value, _cls) - - typing_cls = typing.get_origin(_cls) - if typing_cls is not None: - type_args = typing.get_args(_cls) - if issubclass(typing_cls, list): - _assert_type(_value, list, _cls) - - return [_hydrate_object(v, type_args[0]) for v in _value] - - if issubclass(typing_cls, dict): - _assert_type(_value, dict, _cls) - - return { - _hydrate_object(k, type_args[0]): _hydrate_object( - v, type_args[1] - ) - for k, v in _value.items() - } - - # pragma: no cover - raise TypeError(f"Can't match typing alias {typing_cls!r}") - - _assert_type(_value, _cls) - - return _value - - -def read_data(data: Dict, schema_cls: Type[T]) -> T: - fields: Tuple[SchemaField, ...] = schema_cls._fields +_T = TypeVar("_T") +_K = TypeVar("_K") +_V = TypeVar("_V") - out: Dict[str, Any] = {} - field_names: List[str] = [] - for schema_field in fields: - try: - param_type = schema_field.field_type - name = schema_field.name - field_names.append(name) - try: - value = data[name] - except KeyError as e: - if schema_field.default is schema_field.empty: - raise MissingSchemaField(name) from e - - value = schema_field.default - - if value is None and schema_field.default is None: - out[name] = value - continue - - try: - out[name] = _hydrate_object(value, param_type) - except TypeAssertError as e: - raise TypeError( - BAD_FIELD_TYPE_MSG.format( - field=name, exp_type=e.cls, act_type=type(e.obj) - ) - ) from e - except (MissingSchemaField, TypeAssertError, ParseError) as e: - raise ParseError( - f"Unable to parse schema {schema_cls.__name__!r}" - ) from e - - obj = schema_cls(**out) - - obj.unknown_fields.update( - {key: data[key] for key in data if key not in field_names} - ) - - if obj.unknown_fields: - warnings.warn( - "Unknown fields: {} while parsing schema {!r}".format( - list(obj.unknown_fields.keys()), schema_cls.__name__ - ) - ) - - return obj - - -def serialize(obj): - if isinstance(obj, Schema): - out = {} - for field in obj._fields: # type: SchemaField - val = getattr(obj, field.name) - out[field.name] = serialize(val) - - if obj.unknown_fields: - out.update(obj.unknown_fields) - - return out - - if isinstance(obj, list): - return [serialize(o) for o in obj] - - if isinstance(obj, dict): - return {k: serialize(v) for k, v in obj.items()} - - return obj - - -class CacheEntry: - def __init__(self, value, expire): +class CacheEntry(Generic[_T]): + def __init__(self, value: _T, expire: float) -> None: self.value = value self.expire = expire -class Cache: - def __init__(self, lock_cls=RLock): - self._data = {} +class Cache(Generic[_K, _V]): + def __init__(self, lock_cls: Type[ContextManager[Any]] = RLock) -> None: + self._data: Dict[_K, CacheEntry[_V]] = {} self._lock = lock_cls() - def clear(self): + def clear(self) -> None: self._data.clear() - def put(self, key, value, ttl) -> CacheEntry: + def put(self, key: _K, value: _V, ttl: float) -> CacheEntry[_V]: with self._lock: self._data[key] = out = CacheEntry(value, time.time() + ttl) return out - def get(self, key: str) -> Optional[CacheEntry]: + def get(self, key: _K) -> Optional[CacheEntry[_V]]: with self._lock: try: entry = self._data[key] @@ -478,16 +234,16 @@ def get(self, key: str) -> Optional[CacheEntry]: class CoinMarketCapAPI: def __init__( self, - api_key: str = None, + api_key: Optional[str] = None, api_url: str = "https://pro-api.coinmarketcap.com/v1/", ) -> None: self.api_key = api_key self.show_btc = False self.api_url = URL(api_url) - self.cache = Cache() + self.cache = Cache[str, ApiModel]() @property - def request_headers(self): + def request_headers(self) -> Dict[str, Optional[str]]: return { "Accepts": "application/json", "X-CMC_PRO_API_KEY": self.api_key, @@ -514,6 +270,7 @@ def get_quote(self, symbol: str, currency: str = "USD") -> CryptoCurrency: symbol=symbol.upper(), convert=convert, ).data.cast_to(QuoteRequestResponse) + _, out = data.data.popitem() return out @@ -531,16 +288,16 @@ def get_crypto_currency_map(self) -> CryptoCurrencyMap: ) def _request_cache( - self, name: str, endpoint: str, fmt: Type[T], ttl: int - ) -> T: + self, name: str, endpoint: str, fmt: Type[_ModelT], ttl: int + ) -> _ModelT: out = self.cache.get(name) if out is None: currencies = self.request(endpoint).data.cast_to(fmt) out = self.cache.put(name, currencies, ttl) - return out.value + return cast(_ModelT, out.value) - def request(self, endpoint: str, **params) -> APIResponse: + def request(self, endpoint: str, **params: str) -> APIResponse: url = str(self.api_url / endpoint) with requests.get( url, headers=self.request_headers, params=params @@ -559,15 +316,15 @@ def check(self, response: APIResponse) -> None: api = CoinMarketCapAPI() -def get_plugin_config(conf, name, default): +def get_plugin_config(conf: Dict[str, Any], name: str, default: _T) -> _T: try: - return conf["plugins"]["cryptocurrency"][name] + return cast(_T, conf["plugins"]["cryptocurrency"][name]) except LookupError: return default -@hook.onload() -def init_api(bot): +@hook.on_start() +def init_api(bot: "AbstractBot") -> None: api.api_key = bot.config.get_api_key("coinmarketcap") # Enabling this requires a paid CoinMarketCap API plan @@ -577,7 +334,7 @@ def init_api(bot): class Alias: __slots__ = ("name", "cmds") - def __init__(self, symbol, *cmds): + def __init__(self, symbol: str, *cmds: str) -> None: self.name = symbol if symbol not in cmds: cmds += (symbol,) @@ -585,15 +342,8 @@ def __init__(self, symbol, *cmds): self.cmds = cmds -ALIASES = ( - Alias("btc", "bitcoin"), - Alias("ltc", "litecoin"), - Alias("doge", "dogecoin"), -) - - -def alias_wrapper(alias): - def func(text, event): +def alias_wrapper(alias: Alias) -> Callable[[str, CommandEvent], str]: + def func(text: str, event: CommandEvent) -> str: event.text = alias.name + " " + text return call_with_args(crypto_command, event) @@ -605,7 +355,7 @@ def func(text, event): # main command @hook.command("crypto", "cryptocurrency") -def crypto_command(text, event): +def crypto_command(text: str, event: CommandEvent) -> str: """ [currency] - Returns current value of a cryptocurrency""" args = text.split() ticker = args.pop(0) @@ -626,7 +376,7 @@ def crypto_command(text, event): raise quote = data.quote[currency] - change = cast(Union[int, float], quote.percent_change_24h) + change = quote.percent_change_24h if change > 0: change_str = colors.parse("$(dark_green)+{}%$(clear)").format(change) elif change < 0: @@ -656,7 +406,7 @@ def crypto_command(text, event): ) -def format_price(price: Union[int, float, Real]) -> str: +def format_price(price: Union[int, float]) -> str: price = float(price) if price < 1: precision = max(2, min(10, len(str(Decimal(str(price)))) - 2)) @@ -668,7 +418,7 @@ def format_price(price: Union[int, float, Real]) -> str: @hook.command("currencies", "currencylist", autohelp=False) -def currency_list(): +def currency_list() -> str: """- List all available currencies from the API""" currency_map = api.get_crypto_currency_map() currencies = sorted( @@ -681,11 +431,11 @@ def currency_list(): return "Available currencies: " + web.paste("\n".join(lst)) -def make_alias(alias): +def make_alias(alias: Alias) -> Callable[[str, CommandEvent], str]: _hook = alias_wrapper(alias) return hook.command(*alias.cmds, autohelp=False)(_hook) btc_alias = make_alias(Alias("btc", "bitcoin")) -ltc_alias = make_alias(Alias("ltc", "litecoin")) +ltc_alias = make_alias(Alias("ltc", "litecoin", "ltc")) doge_alias = make_alias(Alias("doge", "dogecoin")) diff --git a/plugins/duckhunt.py b/plugins/duckhunt.py index a4acbf6fc..150d3f504 100644 --- a/plugins/duckhunt.py +++ b/plugins/duckhunt.py @@ -3,7 +3,7 @@ from collections import defaultdict from threading import Lock from time import sleep, time -from typing import Dict, List, NamedTuple, TypeVar +from typing import Callable, Dict, List, NamedTuple, Optional, TypeVar from sqlalchemy import ( Boolean, @@ -18,6 +18,7 @@ from sqlalchemy.sql import select from cloudbot import hook +from cloudbot.client import Client from cloudbot.event import EventType from cloudbot.util import database from cloudbot.util.formatting import pluralize_auto, truncate @@ -607,19 +608,21 @@ def get_average_scores(db, score_type: ScoreType, conn): return scores_dict -SCORE_TYPES = { +SCORE_TYPES: Dict[str, ScoreType] = { "friend": ScoreType("befriend", "befriend", "friend", "friended"), "killer": ScoreType("killer", "shot", "killer", "killed"), } -DISPLAY_FUNCS = { +DISPLAY_FUNCS: Dict[Optional[str], Callable[..., Optional[Dict[str, int]]]] = { "average": get_average_scores, "global": get_global_scores, None: get_channel_scores, } -def display_scores(score_type: ScoreType, event, text, chan, conn, db): +def display_scores( + score_type: ScoreType, event, text: str, chan: str, conn: "Client", db +): if is_opt_out(conn.name, chan): return None diff --git a/plugins/pastebins/sprunge.py b/plugins/pastebins/sprunge.py index 9d4a4c720..567d93551 100644 --- a/plugins/pastebins/sprunge.py +++ b/plugins/pastebins/sprunge.py @@ -11,11 +11,11 @@ class Sprunge(Pastebin): - def __init__(self, base_url): + def __init__(self, base_url) -> None: super().__init__() self.url = base_url - def paste(self, data, ext): + def paste(self, data, ext) -> str: if isinstance(data, str): encoded = data.encode() else: diff --git a/pyproject.toml b/pyproject.toml index 674c90e84..f749a49ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,16 @@ show_error_codes = true warn_unused_ignores = true warn_redundant_casts = true # strict_equality = true +plugins = [ + "pydantic.mypy" +] + +follow_imports = "silent" +# disallow_any_generics = true +no_implicit_reexport = true + +# for strict mypy: (this is the tricky one :-)) +# disallow_untyped_defs = true [[tool.mypy.overrides]] module = 'cloudbot.*' @@ -120,6 +130,21 @@ check_untyped_defs = true warn_return_any = true disallow_untyped_defs = true +[[tool.mypy.overrides]] +module = 'plugins.cryptocurrency' +strict_optional = true +check_untyped_defs = true +warn_return_any = true + +[tool.pydantic-mypy] +init_forbid_extra = true +init_typed = true +warn_required_dynamic_aliases = true + +[tool.autoflake] +in-place = true +remove-all-unused-imports = true + [tool.commitizen] name = "cz_conventional_commits" tag_format = "v$version" diff --git a/requirements.txt b/requirements.txt index 068bbaefe..b14737074 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,9 @@ mcstatus == 11.1.1 multidict == 6.0.5 nltk == 3.8.1 psutil == 5.9.8 +pydantic[email] == 2.7.0 +pydantic-extra-types == 2.6.0 +typing_extensions py-irclib == 0.3.0 pyowm == 3.3.0 requests == 2.31.0 diff --git a/tests/plugin_tests/regex_chans_test.py b/tests/plugin_tests/regex_chans_test.py index e3bb9518f..3e6def116 100644 --- a/tests/plugin_tests/regex_chans_test.py +++ b/tests/plugin_tests/regex_chans_test.py @@ -9,7 +9,7 @@ from cloudbot.plugin_hooks import CommandHook, EventHook, RegexHook from plugins.core import regex_chans from tests.util import wrap_hook_response -from tests.util.mock_conn import MockConn +from tests.util.mock_conn import MockClient from tests.util.mock_db import MockDB @@ -75,7 +75,7 @@ def test_delete_status(mock_db: MockDB): assert regex_chans.status_cache == {} -def test_listregex(mock_db: MockDB): +def test_listregex(mock_db: MockDB, mock_bot): regex_chans.table.create(mock_db.engine) with mock_db.session() as session: mock_db.load_data( @@ -103,7 +103,7 @@ def test_listregex(mock_db: MockDB): regex_chans.load_cache(session) - conn = MockConn(name="net") + conn = MockClient(bot=mock_bot, name="net") assert ( regex_chans.listregex(conn) == "#chan: DISABLED, #chan1: ENABLED, #chan2: DISABLED, #chan3: DISABLED" @@ -111,7 +111,7 @@ def test_listregex(mock_db: MockDB): class TestRegexStatus: - def test_current_chan(self, mock_db: MockDB): + def test_current_chan(self, mock_db: MockDB, mock_bot): regex_chans.table.create(mock_db.engine) with mock_db.session() as session: mock_db.load_data( @@ -147,13 +147,13 @@ def test_current_chan(self, mock_db: MockDB): regex_chans.load_cache(session) - conn = MockConn(name="net") + conn = MockClient(bot=mock_bot, name="net") assert ( regex_chans.regexstatus("", conn, "#chan") == "Regex status for #chan: DISABLED" ) - def test_other_chan(self, mock_db: MockDB): + def test_other_chan(self, mock_db: MockDB, mock_bot): regex_chans.table.create(mock_db.engine) with mock_db.session() as session: mock_db.load_data( @@ -189,13 +189,13 @@ def test_other_chan(self, mock_db: MockDB): regex_chans.load_cache(session) - conn = MockConn(name="net") + conn = MockClient(bot=mock_bot, name="net") assert ( regex_chans.regexstatus("#chan1", conn, "#chan") == "Regex status for #chan1: ENABLED" ) - def test_other_chan_no_prefix(self, mock_db: MockDB): + def test_other_chan_no_prefix(self, mock_db: MockDB, mock_bot): regex_chans.table.create(mock_db.engine) with mock_db.session() as session: mock_db.load_data( @@ -231,7 +231,7 @@ def test_other_chan_no_prefix(self, mock_db: MockDB): regex_chans.load_cache(session) - conn = MockConn(name="net") + conn = MockClient(bot=mock_bot, name="net") assert ( regex_chans.regexstatus("chan2", conn, "#chan") == "Regex status for #chan2: DISABLED" diff --git a/tests/plugin_tests/test_cryptocurrency.py b/tests/plugin_tests/test_cryptocurrency.py index 045417b08..5cca0669d 100644 --- a/tests/plugin_tests/test_cryptocurrency.py +++ b/tests/plugin_tests/test_cryptocurrency.py @@ -1,4 +1,3 @@ -import re from datetime import datetime, timedelta from typing import Any, Dict, List from unittest.mock import MagicMock @@ -12,32 +11,6 @@ from tests.util import HookResult, wrap_hook_response -def test_parse(): - assert cryptocurrency.ResponseStatus._fields != cryptocurrency.Quote._fields - cryptocurrency.Platform( # nosec - id=1, - name="name", - symbol="symbol", - slug="slug", - token_address="foobar", - ) - assert len(cryptocurrency.Platform._fields) == 5 - data = { - "status": { - "timestamp": "ts", - "error_code": 200, - "error_message": None, - "elapsed": 1, - "credit_count": 1, - "notice": None, - } - } - - obj = cryptocurrency.read_data(data, cryptocurrency.APIRequestResponse) - assert obj.status.credit_count == 1 - assert cryptocurrency.serialize(obj) == data - - class MatchAPIKey(Response): def __init__(self, method, url, api_key=None, **kwargs): super().__init__(method, url, **kwargs) @@ -61,7 +34,9 @@ def init_response( price=50000000000.0, ): if check_api_key: - cryptocurrency.init_api(bot.get()) + b = bot.get() + assert b is not None + cryptocurrency.init_api(b) cryptocurrency.api.cache.clear() cryptocurrency.api.show_btc = show_btc @@ -189,93 +164,20 @@ def test_api(mock_requests, mock_api_keys): result = cryptocurrency.api.get_quote("BTC", "USD") assert result.name == "Bitcoin" - assert not result.unknown_fields + assert not result.model_extra assert result.total_supply == 1000 assert result.circulating_supply == 100 -class SomeSchema(cryptocurrency.Schema): - def __init__(self, a: List[List[Dict[str, List[str]]]]): - super().__init__() - self.a = a - - -def test_schema(): - cryptocurrency.read_data({"a": [[{"a": ["1"]}]]}, SomeSchema) - - -class ConcreteSchema(cryptocurrency.Schema): - def __init__(self, a: str) -> None: - super().__init__() - self.a = a - - -class AbstractSchema(ConcreteSchema): - _abstract = True - - -class OtherConcreteSchema(AbstractSchema): - def __init__(self, a: str, b: str): - super().__init__(a) - self.b = b - - -def test_complex_schema(): - cryptocurrency.read_data({"a": "hello", "b": "world"}, OtherConcreteSchema) - - -def test_invalid_schema_type(): - with pytest.raises( - TypeError, - match="field 'a' expected type , got type ", - ): - cryptocurrency.read_data({"a": 1, "b": "world"}, OtherConcreteSchema) - - -def test_schema_missing_field(): - with pytest.raises(cryptocurrency.ParseError) as exc: - cryptocurrency.read_data({"b": "hello"}, OtherConcreteSchema) - - assert isinstance(exc.value.__cause__, cryptocurrency.MissingSchemaField) - - -class NestedSchema(cryptocurrency.Schema): - def __init__(self, a: OtherConcreteSchema) -> None: - super().__init__() - self.a = a - - -def test_schema_nested_exceptions(): - with pytest.raises(cryptocurrency.ParseError) as exc: - cryptocurrency.read_data({"a": {"b": "hello"}}, NestedSchema) - - assert isinstance(exc.value.__cause__, cryptocurrency.ParseError) - assert isinstance( - exc.value.__cause__.__cause__, cryptocurrency.MissingSchemaField - ) - - -def test_schema_unknown_fields(): - input_data = {"a": {"a": "hello", "b": "world"}, "c": 1} - with pytest.warns( - UserWarning, - match=re.escape( - "Unknown fields: ['c'] while parsing schema 'NestedSchema'" - ), - ): - obj = cryptocurrency.read_data(input_data, NestedSchema) - - assert cryptocurrency.serialize(obj) == input_data - - def test_cache(freeze_time): - c = cryptocurrency.Cache() + c = cryptocurrency.Cache[str, str]() c.put("foo", "bar", 30) # Object with a lifespan of 30 seconds should die at 30 seconds freeze_time.tick(timedelta(seconds=29)) - assert c.get("foo") is not None - assert c.get("foo").value == "bar" + entry = c.get("foo") + assert entry is not None + assert entry.value == "bar" freeze_time.tick() assert c.get("foo") is None diff --git a/tests/plugin_tests/test_pager_commands.py b/tests/plugin_tests/test_pager_commands.py index 3fa218b83..cc1242e16 100644 --- a/tests/plugin_tests/test_pager_commands.py +++ b/tests/plugin_tests/test_pager_commands.py @@ -81,9 +81,9 @@ def test_profile_pager(): pages = profile.cat_pages - def call(*args): + def call(text, chan, nick): notice = CaptureCalls() - hook(*args, notice=notice) + hook(text, chan, nick, notice=notice) return notice.lines no_grabs = "There are no category pages to show." diff --git a/tests/util/__init__.py b/tests/util/__init__.py index cdb11b95e..08cdb8b8d 100644 --- a/tests/util/__init__.py +++ b/tests/util/__init__.py @@ -1,5 +1,5 @@ import inspect -from collections.abc import Awaitable, Mapping +from collections.abc import Mapping from pathlib import Path from unittest.mock import patch diff --git a/tests/util/mock_conn.py b/tests/util/mock_conn.py index 46962c358..b31fb2c7a 100644 --- a/tests/util/mock_conn.py +++ b/tests/util/mock_conn.py @@ -1,4 +1,36 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, _CallList + +from cloudbot.bot import AbstractBot +from cloudbot.client import Client + + +class MockClient(Client): + def __init__(self, *, bot: "AbstractBot", nick=None, name=None): + super().__init__( + bot=bot, + _type="mock", + name=name or "testconn", + nick=nick or "TestBot", + ) + self._mock = MagicMock(spec=Client) + + def reload(self): + return self._mock.reload() + + async def try_connect(self): + return self._mock.try_connect() + + def join(self, channel, key=None): + return self._mock.join(channel, key) + + def notice(self, target, text): + return self._mock.notice(target, text) + + def is_nick_valid(self, nick): + return True + + def mock_calls(self) -> _CallList: + return self._mock.mock_calls class MockConn: From cc8f65510040e69c5764b939d160e99d23134c17 Mon Sep 17 00:00:00 2001 From: linuxdaemon Date: Sun, 14 Apr 2024 20:07:40 +0000 Subject: [PATCH 2/3] Add overload coverage exemption --- .coveragerc | 20 -------------------- pyproject.toml | 24 ++++++++++++++++++++++++ 2 files changed, 24 insertions(+), 20 deletions(-) delete mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index a511e9b60..000000000 --- a/.coveragerc +++ /dev/null @@ -1,20 +0,0 @@ -[report] -fail_under = 60 -exclude_lines = - if TYPE_CHECKING: - pragma: no cover - def __repr__ - raise AssertionError - raise NotImplementedError - if __name__ == .__main__.: - if sys.version_info - class .*\(.*(Error|Exception)\): - ^ *\.\.\.$ - -[run] -branch = true -omit = - tests/data/* - tests/util/* - .* - venv/* diff --git a/pyproject.toml b/pyproject.toml index f749a49ab..68aaa1e66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,30 @@ warn_required_dynamic_aliases = true in-place = true remove-all-unused-imports = true +[tool.coverage.report] +fail_under = 60 +exclude_lines = [ + "@overload", + "if TYPE_CHECKING:", + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + 'if __name__ == .__main__.:', + "if sys.version_info", + 'class .*\(.*(Error|Exception)\):', + '^ *\.\.\.$', +] + +[tool.coverage.run] +branch = true +omit = [ + "tests/data/*", + "tests/util/*", + ".*", + "venv/*", +] + [tool.commitizen] name = "cz_conventional_commits" tag_format = "v$version" From 9445248200e87c4a8096ff39a45ad2654e1cb369 Mon Sep 17 00:00:00 2001 From: linuxdaemon Date: Tue, 7 May 2024 12:02:55 +0000 Subject: [PATCH 3/3] test: add type annotations for web.Registry instances --- tests/core_tests/util_tests/test_web.py | 21 ++++++++++++--------- tests/plugin_tests/test_admin_bot.py | 3 +-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/core_tests/util_tests/test_web.py b/tests/core_tests/util_tests/test_web.py index 54a4cc27a..9ddc9d6a5 100644 --- a/tests/core_tests/util_tests/test_web.py +++ b/tests/core_tests/util_tests/test_web.py @@ -48,7 +48,7 @@ def test_paste_error(mock_requests): def test_registry_items(): - registry = web.Registry() + registry = web.Registry[object]() obj = object() registry.register("test", obj) item = registry.get_item("test") @@ -57,9 +57,10 @@ def test_registry_items(): def test_registry_item_working(freeze_time): - registry = web.Registry() + registry = web.Registry[object]() registry.register("test", object()) item = registry.get_item("test") + assert item is not None assert item.should_use item.failed() @@ -271,18 +272,20 @@ def test_expand(mock_requests): def test_register_duplicate_paste(): obj = object() obj1 = object() + registry = web.Registry[object]() - web.pastebins.register("test", obj) + registry.register("test", obj) with pytest.raises(ValueError): - web.pastebins.register("test", obj1) + registry.register("test", obj1) - web.pastebins.remove("test") + registry.remove("test") def test_remove_paste(): obj = object() + registry = web.Registry[object]() - web.pastebins.register("test", obj) - assert web.pastebins.get("test") is obj - web.pastebins.remove("test") - assert web.pastebins.get("test") is None + registry.register("test", obj) + assert registry.get("test") is obj + registry.remove("test") + assert registry.get("test") is None diff --git a/tests/plugin_tests/test_admin_bot.py b/tests/plugin_tests/test_admin_bot.py index 421b0920c..ce5800abd 100644 --- a/tests/plugin_tests/test_admin_bot.py +++ b/tests/plugin_tests/test_admin_bot.py @@ -85,8 +85,7 @@ def f(self, attr): event.__getitem__ = f event.event = event - res = await func_utils.call_with_args(admin_bot.me, event) - assert res is None + await func_utils.call_with_args(admin_bot.me, event) assert event.mock_calls == [ call.admin_log('bar used ME to make me ACT "do thing" in #foo.'), call.conn.ctcp("#foo", "ACTION", "do thing"),