diff --git a/changelog/1201.feature.rst b/changelog/1201.feature.rst new file mode 100644 index 0000000000..c76694f989 --- /dev/null +++ b/changelog/1201.feature.rst @@ -0,0 +1 @@ +|commands| Add support for ``Range[LargeInt, ...]`` in slash command parameters diff --git a/disnake/app_commands.py b/disnake/app_commands.py index aa41f8a347..ad1892a21c 100644 --- a/disnake/app_commands.py +++ b/disnake/app_commands.py @@ -287,9 +287,9 @@ def __init__( self.required: bool = required self.options: List[Option] = options or [] - if min_value and self.type is OptionType.integer: + if min_value is not None and self.type is OptionType.integer: min_value = math.ceil(min_value) - if max_value and self.type is OptionType.integer: + if max_value is not None and self.type is OptionType.integer: max_value = math.floor(max_value) self.min_value: Optional[float] = min_value diff --git a/disnake/ext/commands/errors.py b/disnake/ext/commands/errors.py index a3e888a03e..1755ad5004 100644 --- a/disnake/ext/commands/errors.py +++ b/disnake/ext/commands/errors.py @@ -572,6 +572,37 @@ def __init__(self, argument: str) -> None: super().__init__(f"{argument} is not able to be converted to an integer") +class LargeIntOutOfRange(BadArgument): + """Exception raised when an argument to a large integer option exceeds given range. + + This inherits from :exc:`BadArgument` + + .. versionadded:: 2.11 + + Attributes + ---------- + argument: :class:`str` + The argument that exceeded the defined range. + min_value: Optional[Union[:class:`int`, :class:`float`]] + The minimum allowed value. + max_value: Optional[Union[:class:`int`, :class:`float`]] + The maximum allowed value. + """ + + def __init__( + self, + argument: str, + min_value: Union[int, float, None], + max_value: Union[int, float, None], + ) -> None: + self.argument: str = argument + self.min_value: Union[int, float, None] = min_value + self.max_value: Union[int, float, None] = max_value + a = "..." if min_value is None else min_value + b = "..." if max_value is None else max_value + super().__init__(f"{argument} is not in range [{a}, {b}]") + + class DisabledCommand(CommandError): """Exception raised when the command being invoked is disabled. diff --git a/disnake/ext/commands/params.py b/disnake/ext/commands/params.py index 014b3bb010..3466e8941d 100644 --- a/disnake/ext/commands/params.py +++ b/disnake/ext/commands/params.py @@ -26,6 +26,7 @@ Generic, List, Literal, + Mapping, NoReturn, Optional, Sequence, @@ -33,10 +34,11 @@ Type, TypeVar, Union, - cast, get_origin, ) +from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeGuard + import disnake from disnake.app_commands import Option, OptionChoice from disnake.channel import _channel_type_factory @@ -54,11 +56,10 @@ from . import errors from .converter import CONVERTER_MAPPING -T_ = TypeVar("T_") +T = TypeVar("T") +P = ParamSpec("P") if TYPE_CHECKING: - from typing_extensions import Concatenate, ParamSpec, Self, TypeGuard - from disnake.app_commands import Choices from disnake.i18n import LocalizationValue, LocalizedOptional from disnake.types.interactions import ApplicationCommandOptionChoiceValue @@ -70,11 +71,9 @@ AnySlashCommand = Union[InvokableSlashCommand, SubCommand] - P = ParamSpec("P") - InjectionCallback = Union[ - Callable[Concatenate[CogT, P], T_], - Callable[P, T_], + Callable[Concatenate[CogT, P], T], + Callable[P, T], ] AnyAutocompleter = Union[ Sequence[Any], @@ -82,23 +81,18 @@ Callable[Concatenate[CogT, ApplicationCommandInteraction, str, P], Any], ] - TChoice = TypeVar("TChoice", bound=ApplicationCommandOptionChoiceValue) -else: - P = TypeVar("P") - if sys.version_info >= (3, 10): from types import EllipsisType, UnionType -elif TYPE_CHECKING: - EllipsisType = type(Ellipsis) - UnionType = NoReturn + + UnionTypes = (Union, UnionType) else: - UnionType = object() - EllipsisType = type(Ellipsis) + # using 'type' and not 'object', as 'type' is disjoint with 'int | float' + EllipsisType: TypeAlias = type + UnionTypes = (Union,) -T = TypeVar("T", bound=Any) -TypeT = TypeVar("TypeT", bound=Type[Any]) +TypeT = TypeVar("TypeT", bound=type) BotT = TypeVar("BotT", bound="disnake.Client", covariant=True) __all__ = ( @@ -132,7 +126,7 @@ def issubclass_(obj: Any, tp: Union[TypeT, Tuple[TypeT, ...]]) -> TypeGuard[Type if (origin := get_origin(obj)) is None: return False - if origin in (Union, UnionType): + if origin in UnionTypes: # If we have a Union, try matching any of its args # (recursively, to handle possibly generic types inside this union) return any(issubclass_(o, tp) for o in obj.__args__) @@ -141,8 +135,8 @@ def issubclass_(obj: Any, tp: Union[TypeT, Tuple[TypeT, ...]]) -> TypeGuard[Type def remove_optionals(annotation: Any) -> Any: - """Remove unwanted optionals from an annotation""" - if get_origin(annotation) in (Union, UnionType): + """Remove unwanted optionals from an annotation.""" + if get_origin(annotation) in UnionTypes: args = tuple(i for i in annotation.__args__ if i not in (None, type(None))) if len(args) == 1: annotation = args[0] @@ -170,7 +164,47 @@ def _xt_to_xe(xe: Optional[float], xt: Optional[float], direction: float = 1) -> return None -class Injection(Generic[P, T_]): +def _int_to_str_len(number: int) -> int: + """Returns `len(str(number))`, i.e. character count of base 10 signed repr of `number`.""" + # Desmos equivalent: floor(log(max(abs(x), 1))) + 1 + max(-sign(x), 0) + return ( + int(math.log10(abs(number) or 1)) + # 0 -> 0, 1 -> 0, 9 -> 0, 10 -> 1 + + 1 + + (number < 0) + ) + + +def _range_to_str_len(min_value: int, max_value: int) -> Tuple[int, int]: + min_ = _int_to_str_len(min_value) + max_ = _int_to_str_len(max_value) + opposite_sign = (min_value < 0) ^ (max_value < 0) + # both bounds positive: len(str(min_value)) <= len(str(max_value)) + # smaller bound negative: the range includes 0, which sets the minimum length to 1 + # both bounds negative: len(str(min_value)) >= len(str(max_value)) + if opposite_sign: + return 1, max(min_, max_) + return min(min_, max_), max(min_, max_) + + +def _unbound_range_to_str_len( + min_value: Optional[int], max_value: Optional[int] +) -> Tuple[Optional[int], Optional[int]]: + if min_value is not None and max_value is not None: + return _range_to_str_len(min_value, max_value) + + elif min_value is not None and min_value > 0: + # 0 < min_value <= max_value == inf + return _int_to_str_len(min_value), None + + elif max_value is not None and max_value < 0: + # -inf == min_value <= max_value < 0 + return None, _int_to_str_len(max_value) + + return None, None + + +class Injection(Generic[P, T]): """Represents a slash command injection. .. versionadded:: 2.3 @@ -192,7 +226,7 @@ class Injection(Generic[P, T_]): def __init__( self, - function: InjectionCallback[CogT, P, T_], + function: InjectionCallback[CogT, P, T], *, autocompleters: Optional[Dict[str, Callable]] = None, ) -> None: @@ -200,11 +234,11 @@ def __init__( for autocomp in autocompleters.values(): classify_autocompleter(autocomp) - self.function: InjectionCallback[Any, P, T_] = function + self.function: InjectionCallback[Any, P, T] = function self.autocompleters: Dict[str, Callable] = autocompleters or {} self._injected: Optional[Cog] = None - def __get__(self, obj: Optional[Any], _: Type[Any]) -> Self: + def __get__(self, obj: Optional[Any], _: type) -> Self: if obj is None: return self @@ -214,7 +248,7 @@ def __get__(self, obj: Optional[Any], _: Type[Any]) -> Self: return copy - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T_: + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: """Calls the underlying function that the injection holds. .. versionadded:: 2.6 @@ -227,11 +261,11 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T_: @classmethod def register( cls, - function: InjectionCallback[CogT, P, T_], + function: InjectionCallback[CogT, P, T], annotation: Any, *, autocompleters: Optional[Dict[str, Callable]] = None, - ) -> Injection[P, T_]: + ) -> Injection[P, T]: self = cls(function, autocompleters=autocompleters) cls._registered[annotation] = self return self @@ -269,17 +303,24 @@ def decorator(func: FuncT) -> FuncT: return decorator +NumT = TypeVar("NumT", bound=Union[int, float]) + + @dataclass(frozen=True) -class _BaseRange(ABC): +class _BaseRange(ABC, Generic[NumT]): """Internal base type for supporting ``Range[...]`` and ``String[...]``.""" - _allowed_types: ClassVar[Tuple[Type[Any], ...]] + _allowed_types: ClassVar[Tuple[type, ...]] - underlying_type: Type[Any] - min_value: Optional[Union[int, float]] - max_value: Optional[Union[int, float]] + underlying_type: type + min_value: Optional[NumT] + max_value: Optional[NumT] def __class_getitem__(cls, params: Tuple[Any, ...]) -> Self: + if cls is _BaseRange: + # needed since made generic + return super().__class_getitem__(params) # pyright: ignore[reportAttributeAccessIssue] + # deconstruct type arguments if not isinstance(params, tuple): params = (params,) @@ -297,12 +338,11 @@ def __class_getitem__(cls, params: Tuple[Any, ...]) -> Self: f"Use `{name}[, , ]` instead.", stacklevel=2, ) - # infer type from min/max values params = (cls._infer_type(params), *params) if len(params) != 3: - msg = f"`{name}` expects 3 type arguments ({name}[, , ]), got {len(params)}" + msg = f"`{name}` expects 3 arguments ({name}[, , ]), got {len(params)}" raise TypeError(msg) underlying_type, min_value, max_value = params @@ -312,7 +352,7 @@ def __class_getitem__(cls, params: Tuple[Any, ...]) -> Self: msg = f"First `{name}` argument must be a type, not `{underlying_type!r}`" raise TypeError(msg) - if not issubclass(underlying_type, cls._allowed_types): + if not issubclass_(underlying_type, cls._allowed_types): allowed = "/".join(t.__name__ for t in cls._allowed_types) msg = f"First `{name}` argument must be {allowed}, not `{underlying_type!r}`" raise TypeError(msg) @@ -334,8 +374,8 @@ def __class_getitem__(cls, params: Tuple[Any, ...]) -> Self: return cls(underlying_type=underlying_type, min_value=min_value, max_value=max_value) @staticmethod - def _coerce_bound(value: Any, name: str) -> Optional[Union[int, float]]: - if value is None or isinstance(value, EllipsisType): + def _coerce_bound(value: Union[NumT, EllipsisType, None], name: str) -> Optional[NumT]: + if value is None or value is ...: return None elif isinstance(value, (int, float)): if not math.isfinite(value): @@ -351,9 +391,9 @@ def __repr__(self) -> str: b = "..." if self.max_value is None else self.max_value return f"{type(self).__name__}[{self.underlying_type.__name__}, {a}, {b}]" - @classmethod + @staticmethod @abstractmethod - def _infer_type(cls, params: Tuple[Any, ...]) -> Type[Any]: + def _infer_type(params: Tuple[Any, ...]) -> type: raise NotImplementedError # hack to get `typing._type_check` to pass, e.g. when using `Range` as a generic parameter @@ -363,7 +403,7 @@ def __call__(self) -> NoReturn: # support new union syntax for `Range[int, 1, 2] | None` if sys.version_info >= (3, 10): - def __or__(self, other): + def __or__(self, other: type) -> UnionType: return Union[self, other] @@ -373,7 +413,7 @@ def __or__(self, other): else: @dataclass(frozen=True, repr=False) - class Range(_BaseRange): + class Range(_BaseRange[Union[int, float]]): """Type representing a number with a limited range of allowed values. See :ref:`param_ranges` for more information. @@ -392,18 +432,27 @@ def __post_init__(self) -> None: if value is None: continue - if self.underlying_type is int and not isinstance(value, int): + if self.underlying_type is not float and not isinstance(value, int): msg = "Range[int, ...] bounds must be int, not float" raise TypeError(msg) - @classmethod - def _infer_type(cls, params: Tuple[Any, ...]) -> Type[Any]: + if self.underlying_type is int and abs(value) >= 2**53: + msg = ( + "Discord has upper input limit on integer input type of +/-2**53.\n" + " For larger values, use Range[commands.LargeInt, ...], which will use" + " string input type with length limited to the minimum and maximum string" + " representations of the range bounds." + ) + raise ValueError(msg) + + @staticmethod + def _infer_type(params: Tuple[Any, ...]) -> type: if any(isinstance(p, float) for p in params): return float return int @dataclass(frozen=True, repr=False) - class String(_BaseRange): + class String(_BaseRange[int]): """Type representing a string option with a limited length. See :ref:`string_lengths` for more information. @@ -429,13 +478,13 @@ def __post_init__(self) -> None: msg = "String bounds may not be negative" raise ValueError(msg) - @classmethod - def _infer_type(cls, params: Tuple[Any, ...]) -> Type[Any]: + @staticmethod + def _infer_type(params: Tuple[Any, ...]) -> type: return str class LargeInt(int): - """Type for large integers in slash commands.""" + """Type representing integers `<=-2**53`, `>=2**53` in slash commands.""" # option types that require additional handling in verify_type @@ -490,23 +539,29 @@ class ParamInfo: .. versionadded:: 2.6 """ - TYPES: ClassVar[Dict[Union[type, UnionType], int]] = { + if sys.version_info >= (3, 10): + TYPES: ClassVar[Mapping[Union[type, UnionType], int]] + + else: + TYPES: ClassVar[Mapping[Union[type, object], int]] + + TYPES = { # noqa: RUF012 str: OptionType.string.value, int: OptionType.integer.value, bool: OptionType.boolean.value, + float: OptionType.number.value, disnake.abc.User: OptionType.user.value, disnake.User: OptionType.user.value, disnake.Member: OptionType.user.value, Union[disnake.User, disnake.Member]: OptionType.user.value, - # channels handled separately - disnake.abc.GuildChannel: OptionType.channel.value, disnake.Role: OptionType.role.value, disnake.abc.Snowflake: OptionType.mentionable.value, Union[disnake.Member, disnake.Role]: OptionType.mentionable.value, Union[disnake.User, disnake.Role]: OptionType.mentionable.value, Union[disnake.User, disnake.Member, disnake.Role]: OptionType.mentionable.value, - float: OptionType.number.value, disnake.Attachment: OptionType.attachment.value, + # channels handled separately + disnake.abc.GuildChannel: OptionType.channel.value, } # fmt: skip _registered_converters: ClassVar[Dict[type, Callable[..., Any]]] = {} @@ -522,10 +577,10 @@ def __init__( choices: Optional[Choices] = None, type: Optional[type] = None, channel_types: Optional[List[ChannelType]] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, + lt: Union[int, float, None] = None, + le: Union[int, float, None] = None, + gt: Union[int, float, None] = None, + ge: Union[int, float, None] = None, large: bool = False, min_length: Optional[int] = None, max_length: Optional[int] = None, @@ -550,10 +605,10 @@ def __init__( self.choices = choices or [] self.type = type or str self.channel_types = channel_types or [] - self.max_value = _xt_to_xe(le, lt, -1) - self.min_value = _xt_to_xe(ge, gt, 1) - self.min_length = min_length - self.max_length = max_length + self.min_value: Union[int, float, None] = _xt_to_xe(ge, gt, 1) + self.max_value: Union[int, float, None] = _xt_to_xe(le, lt, -1) + self.min_length: Optional[int] = min_length + self.max_length: Optional[int] = max_length self.large = large def copy(self) -> Self: @@ -635,7 +690,7 @@ def __repr__(self) -> str: return f"{type(self).__name__}({args})" async def get_default(self, inter: ApplicationCommandInteraction) -> Any: - """Gets the default for an interaction""" + """Gets the default for an interaction.""" default = self.default if callable(self.default): default = self.default(inter) @@ -667,13 +722,19 @@ async def verify_type(self, inter: ApplicationCommandInteraction, argument: Any) return argument async def convert_argument(self, inter: ApplicationCommandInteraction, argument: Any) -> Any: - """Convert a value if a converter is given""" + """Convert a value if a converter is given.""" if self.large: try: argument = int(argument) except ValueError: raise errors.LargeIntConversionFailure(argument) from None + min_value = -math.inf if self.min_value is None else self.min_value + max_value = math.inf if self.max_value is None else self.max_value + + if not min_value <= argument <= max_value: + raise errors.LargeIntOutOfRange(argument, self.min_value, self.max_value) from None + if self.converter is None: # TODO: Custom validators return await self.verify_type(inter, argument) @@ -744,19 +805,26 @@ def parse_annotation(self, annotation: Any, converter_mode: bool = False) -> boo self.min_value = annotation.min_value self.max_value = annotation.max_value annotation = annotation.underlying_type - if isinstance(annotation, _String): - self.min_length = cast("Optional[int]", annotation.min_value) - self.max_length = cast("Optional[int]", annotation.max_value) + + elif isinstance(annotation, _String): + self.min_length = annotation.min_value + self.max_length = annotation.max_value annotation = annotation.underlying_type + if issubclass_(annotation, LargeInt): self.large = True annotation = int if self.large: - self.type = str if annotation is not int: msg = "Large integers must be annotated with int or LargeInt" raise TypeError(msg) + self.type = str + self.min_length, self.max_length = _unbound_range_to_str_len( + self.min_value, # pyright: ignore[reportArgumentType] + self.max_value, # pyright: ignore[reportArgumentType] + ) + elif annotation in self.TYPES: self.type = annotation elif ( @@ -764,7 +832,7 @@ def parse_annotation(self, annotation: Any, converter_mode: bool = False) -> boo or get_origin(annotation) is Literal ): self._parse_enum(annotation) - elif get_origin(annotation) in (Union, UnionType): + elif get_origin(annotation) in UnionTypes: args = annotation.__args__ if all( issubclass_(channel, (disnake.abc.GuildChannel, disnake.Thread)) for channel in args @@ -855,8 +923,8 @@ def to_option(self) -> Option: choices=self.choices or None, channel_types=self.channel_types, autocomplete=self.autocomplete is not None, - min_value=self.min_value, - max_value=self.max_value, + min_value=None if self.large else self.min_value, + max_value=None if self.large else self.max_value, min_length=self.min_length, max_length=self.max_length, ) @@ -1353,6 +1421,9 @@ def decorator(function: Callable[..., Any]) -> Injection: return decorator +TChoice = TypeVar("TChoice", bound="ApplicationCommandOptionChoiceValue") + + def option_enum( choices: Union[Dict[str, TChoice], List[TChoice]], **kwargs: TChoice ) -> Type[TChoice]: @@ -1401,10 +1472,10 @@ def converter_method(function: Any) -> ConverterMethod: def register_injection( - function: InjectionCallback[CogT, P, T_], + function: InjectionCallback[CogT, P, T], *, autocompleters: Optional[Dict[str, Callable]] = None, -) -> Injection[P, T_]: +) -> Injection[P, T]: """A decorator to register a global injection. .. versionadded:: 2.3 diff --git a/tests/ext/commands/test_params.py b/tests/ext/commands/test_params.py index 0f811a7204..80eb95a2b7 100644 --- a/tests/ext/commands/test_params.py +++ b/tests/ext/commands/test_params.py @@ -71,7 +71,7 @@ async def test_verify_type__invalid_member(self, annotation, arg_types) -> None: class TestBaseRange: @pytest.mark.parametrize("args", [int, (int,), (int, 1, 2, 3)]) def test_param_count(self, args) -> None: - with pytest.raises(TypeError, match=r"`Range` expects 3 type arguments"): + with pytest.raises(TypeError, match=r"`Range` expects 3 arguments"): commands.Range[args] @pytest.mark.parametrize("value", ["int", 42, Optional[int], Union[int, float]])