Skip to content

Commit c51d011

Browse files
committed
Implemented converter
1 parent 6a2f393 commit c51d011

File tree

5 files changed

+231
-10
lines changed

5 files changed

+231
-10
lines changed

dico_command/bot.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import traceback
66
import importlib
77
import dico
8+
from contextlib import suppress
89
from .command import Command
910
from .context import Context
11+
from .converter import AVAILABLE_CONVERTERS, ConverterBase
1012
from .exception import *
1113
from .utils import smart_split, is_coro
1214

@@ -26,6 +28,7 @@ def __init__(self,
2628
application_id: typing.Optional[dico.Snowflake.TYPING] = None,
2729
monoshard: bool = False,
2830
shard_count: typing.Optional[int] = None,
31+
shard_id: typing.Optional[int] = None,
2932
**cache_max_sizes: int):
3033
super().__init__(token,
3134
intents=intents,
@@ -35,6 +38,7 @@ def __init__(self,
3538
application_id=application_id,
3639
monoshard=monoshard,
3740
shard_count=shard_count,
41+
shard_id=shard_id,
3842
**cache_max_sizes)
3943
self.prefixes = [prefix] if not isinstance(prefix, list) else prefix
4044
self.commands = {}
@@ -98,13 +102,71 @@ async def execute_handler(self, message: dico.Message):
98102
try:
99103
try:
100104
args, kwargs = smart_split(ipt[1] if len(ipt) > 1 else "", cmd.args_data, subcommand=bool(cmd.subcommands))
105+
if not cmd.subcommands:
106+
args, kwargs = await self.convert_args(context, cmd.args_data, args, kwargs)
101107
except Exception as ex:
102108
raise InvalidArgument from ex
103109
self.logger.debug(f"Command {name} executed.")
104110
await cmd.invoke(context, *args, **kwargs)
105111
except Exception as ex:
106112
await self.handle_command_error(context, ex)
107113

114+
def get_converter(self, convert_type: typing.Any):
115+
if convert_type in [str, int, float, bool]:
116+
return convert_type
117+
elif convert_type == dico.Snowflake:
118+
return dico.Snowflake.ensure_snowflake
119+
elif issubclass(convert_type, ConverterBase):
120+
return convert_type(self)
121+
elif convert_type in AVAILABLE_CONVERTERS:
122+
return AVAILABLE_CONVERTERS[convert_type](self)
123+
124+
@staticmethod
125+
async def convert(context, value, *converters, safe: bool = False) -> typing.Optional[typing.Any]:
126+
orig = value
127+
for x in converters:
128+
if isinstance(x, ConverterBase):
129+
value = await x(context, orig)
130+
elif is_coro(x):
131+
value = await x(orig)
132+
else:
133+
value = x(orig)
134+
if value:
135+
return value
136+
if not safe:
137+
raise ConversionFailed(value=orig)
138+
return value
139+
140+
async def convert_args(self, context: Context, args_data: dict, args: typing.List[str], kwargs: typing.Dict[str, str]) \
141+
-> typing.Tuple[typing.List[typing.Any], typing.Dict[str, typing.Any]]:
142+
for i, x in enumerate(args.copy()):
143+
convert_type = [*args_data.values()][i]["annotation"]
144+
if not convert_type:
145+
continue
146+
converters = []
147+
if hasattr(convert_type, "__origin__") and convert_type.__origin__ is typing.Union:
148+
for t in convert_type.__args__:
149+
if t is not None:
150+
converters.append(self.get_converter(t))
151+
else:
152+
converters.append(self.get_converter(convert_type))
153+
resp = await self.convert(context, x, *[x for x in converters if x])
154+
args[i] = resp
155+
for k, v in kwargs.items():
156+
convert_type = args_data[k]["annotation"]
157+
if not convert_type:
158+
continue
159+
converters = []
160+
if hasattr(convert_type, "__origin__") and convert_type.__origin__ is typing.Union:
161+
for t in convert_type.__args__:
162+
if t is not None:
163+
converters.append(self.get_converter(t))
164+
else:
165+
converters.append(self.get_converter(convert_type))
166+
resp = await self.convert(context, v, *[x for x in converters if x])
167+
kwargs[k] = resp
168+
return args, kwargs
169+
108170
def add_command(self, command: Command):
109171
if command.name in self.commands:
110172
raise CommandAlreadyExists(name=command.name)

dico_command/command.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ def __init__(self,
2525
self.is_subcommand = is_subcommand
2626

2727
def subcommand(self, *args, **kwargs):
28+
if self.is_subcommand:
29+
raise AttributeError("unable to create subcommand in subcommand.")
30+
2831
def wrap(coro):
2932
cmd = command(*args, **kwargs)(coro)
3033
cmd.is_subcommand = True
@@ -68,11 +71,13 @@ async def invoke(self, ctx: Context, *args, **kwargs):
6871
del args[0]
6972
subcommand_invoking = True
7073
elif kwargs and [*kwargs.values()][0] in self.subcommands:
71-
subcommand = kwargs.pop([*kwargs.keys()][0])
74+
subcommand = self.subcommands[kwargs.pop([*kwargs.keys()][0])]
7275
tgt = subcommand.invoke
7376
subcommand_invoking = True
74-
elif kwargs or args:
77+
elif (args or kwargs) and not self.args_data:
7578
raise InvalidArgument("unknown subcommand or invalid argument passed.")
79+
else:
80+
args, kwargs = await ctx.bot.convert_args(ctx, self.args_data, args, kwargs)
7681
elif (args or kwargs) and not self.args_data:
7782
raise InvalidArgument("invalid argument data.")
7883
if subcommand_invoking:
@@ -81,6 +86,8 @@ async def invoke(self, ctx: Context, *args, **kwargs):
8186
ipt = msg.split(maxsplit=1)
8287
ipt = ipt[1].split(maxsplit=1) if len(ipt) > 1 else []
8388
args, kwargs = smart_split(ipt[1] if len(ipt) > 1 else "", subcommand.args_data, subcommand=bool(subcommand.subcommands))
89+
if not subcommand.subcommands:
90+
args, kwargs = await ctx.bot.convert_args(ctx, subcommand.args_data, args, kwargs)
8491
init_args = (ctx,) if self.addon is None or subcommand_invoking else (self.addon, ctx)
8592
return await tgt(*init_args, *args, **kwargs)
8693

dico_command/converter.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import re
2+
3+
from abc import ABC, abstractmethod
4+
from contextlib import suppress
5+
from typing import TypeVar, Sequence, Generic, Type, Optional, TYPE_CHECKING, List
6+
7+
import dico
8+
from dico.exception import HTTPError
9+
10+
from .utils import search, maybe_fmt
11+
12+
if TYPE_CHECKING:
13+
from .bot import Bot
14+
from .context import Context
15+
16+
17+
T = TypeVar("T")
18+
19+
20+
class ConverterBase(ABC, Generic[T]):
21+
CONVERT_TYPE: Type[T]
22+
23+
def __init__(self, bot: "Bot"):
24+
if not hasattr(self, "CONVERT_TYPE"):
25+
raise TypeError("Converter must have CONVERT_TYPE attribute")
26+
self.bot: "Bot" = bot
27+
self.cache_type = self.CONVERT_TYPE._cache_type if hasattr(self.CONVERT_TYPE, "_cache_type") else None
28+
29+
def dump_from_cache(self, guild_id: Optional[dico.Snowflake] = None) -> List[T]:
30+
if not self.cache_type:
31+
raise TypeError("dump_from_cache can only be used with DiscordObjectBase")
32+
if self.bot.has_cache:
33+
cache = self.bot.cache if not guild_id else self.bot.cache.get_guild_container(guild_id)
34+
objects = cache.get_storage(self.cache_type) if cache else []
35+
return [x["value"] for x in objects] if objects else []
36+
return []
37+
38+
def __call__(self, *args, **kwargs):
39+
return self.convert(*args, **kwargs)
40+
41+
@abstractmethod
42+
async def convert(self, ctx: "Context", value: str) -> Optional[T]:
43+
pass
44+
45+
46+
class UserConverter(ConverterBase):
47+
CONVERT_TYPE = dico.User
48+
49+
async def convert(self, ctx: "Context", value: str) -> Optional[T]:
50+
cached = self.dump_from_cache()
51+
cached.extend([x.user if isinstance(x, dico.GuildMember) else x for x in ctx.mentions])
52+
maybe_mention = maybe_fmt(value)
53+
maybe_id = value if re.match(r"^\d+$", value) else maybe_mention
54+
with suppress(HTTPError):
55+
if maybe_id:
56+
return search(cached, id=maybe_id) or await self.bot.http.request_user(maybe_id)
57+
from_username = search(cached, username=value)
58+
if from_username:
59+
return from_username
60+
from_fullname = search(cached, __str__=value)
61+
if from_fullname:
62+
return from_fullname
63+
64+
65+
class GuildMemberConverter(ConverterBase):
66+
CONVERT_TYPE =dico. GuildMember
67+
68+
def __init__(self, bot: "Bot"):
69+
super().__init__(bot)
70+
self.cache_type = "member"
71+
72+
async def convert(self, ctx: "Context", value: str) -> Optional[T]:
73+
cached = self.dump_from_cache(ctx.guild_id)
74+
cached.extend([x for x in ctx.mentions if isinstance(x, dico.GuildMember)])
75+
maybe_mention = maybe_fmt(value)
76+
maybe_id = value if re.match(r"^\d+$", value) else maybe_mention
77+
with suppress(HTTPError):
78+
if maybe_id:
79+
return search(cached, id=maybe_id) or await self.bot.http.request_user(maybe_id)
80+
from_name = search(cached, __str__=value)
81+
if from_name:
82+
return from_name
83+
84+
85+
class ChannelConverter(ConverterBase):
86+
CONVERT_TYPE = dico.Channel
87+
88+
async def convert(self, ctx: "Context", value: str) -> Optional[T]:
89+
cached = self.dump_from_cache()
90+
cached.append(ctx.channel)
91+
maybe_mention = maybe_fmt(value)
92+
maybe_id = value if re.match(r"^\d+$", value) else maybe_mention
93+
with suppress(HTTPError):
94+
if maybe_id:
95+
return search(cached, id=maybe_id) or await self.bot.http.request_user(maybe_id)
96+
from_name = search(cached, name=value)
97+
if from_name:
98+
return from_name
99+
100+
101+
class RoleConverter(ConverterBase):
102+
CONVERT_TYPE = dico.Role
103+
104+
async def convert(self, ctx: "Context", value: str) -> Optional[T]:
105+
cached = self.dump_from_cache()
106+
maybe_mention = maybe_fmt(value)
107+
maybe_id = value if re.match(r"^\d+$", value) else maybe_mention
108+
with suppress(HTTPError):
109+
if maybe_id:
110+
return search(cached, id=maybe_id) or await self.bot.http.request_user(maybe_id)
111+
from_name = search(cached, name=value)
112+
if from_name:
113+
return from_name
114+
115+
116+
AVAILABLE_CONVERTERS = {
117+
dico.User: UserConverter,
118+
dico.GuildMember: GuildMemberConverter,
119+
dico.Channel: ChannelConverter,
120+
dico.Role: RoleConverter
121+
}

dico_command/exception.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ class MissingUnloadFunction(CommandException):
4545

4646
class AddonAlreadyLoaded(CommandException):
4747
"""Addon {name} is already loaded."""
48+
49+
50+
class ConversionFailed(CommandException):
51+
"""Converting {value} has failed."""

dico_command/utils.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
import re
12
import typing
23
import inspect
34

45

6+
SPLIT_PATTERN = re.compile(r'((".+")|[.\S]+)')
7+
FMT_REGEX = re.compile(r'^<[at]?(:[^:]*)?(@[!&]?|#|:)(\d+)(:)?.*>$')
8+
T = typing.TypeVar("T")
9+
10+
511
def is_coro(coro):
612
return inspect.iscoroutinefunction(coro) or inspect.isawaitable(coro) or inspect.iscoroutine(coro)
713

@@ -16,31 +22,36 @@ def read_function(func):
1622
ret[x.name] = {
1723
"required": x.default == inspect._empty, # noqa
1824
"default": x.default,
19-
"annotation": x.annotation,
25+
"annotation": x.annotation if x.annotation != inspect._empty else None, # noqa
2026
"kind": x.kind
2127
}
2228
return ret
2329

2430

25-
def smart_split(ipt: str, args_data: dict, splitter: str = " ", subcommand: bool = False) -> typing.Tuple[list, dict]:
31+
def smart_split(ipt: str, args_data: dict, subcommand: bool = False) -> typing.Tuple[list, dict]:
2632
if len(args_data) == 0:
2733
if subcommand and ipt:
28-
return [*ipt.split(splitter)], {}
34+
return [*[x[0] for x in re.findall(SPLIT_PATTERN, ipt)]], {}
2935
return [], {}
30-
raw_split = ipt.split(splitter)
31-
# TODO: handle "..."
36+
raw_split = [x[0] for x in re.findall(SPLIT_PATTERN, ipt)]
3237
initial_split = raw_split
3338
args_name = [*args_data.keys()]
3439
last_arg = args_data[args_name[-1]]
3540
var_positional_in = True in [*map(lambda n: n["kind"] == n["kind"].VAR_POSITIONAL, args_data.values())]
3641
keyword_only_count = len([x for x in args_data.values() if x["kind"] == x["kind"].KEYWORD_ONLY])
3742
if len(args_data) == 1:
3843
if last_arg["kind"] == last_arg["kind"].VAR_POSITIONAL:
39-
return [ipt], {}
44+
if ipt:
45+
return [ipt], {}
46+
else:
47+
return [], {}
4048
elif last_arg["kind"] == last_arg["kind"].KEYWORD_ONLY:
41-
return [], {args_name[-1]: ipt}
49+
if ipt:
50+
return [], {args_name[-1]: ipt}
51+
else:
52+
return [], {}
4253
else:
43-
return [initial_split[0]], {}
54+
return [initial_split[0]] if initial_split else [], {}
4455
if (len(initial_split) == len(args_data) and not keyword_only_count) or last_arg["kind"] == last_arg["kind"].VAR_POSITIONAL: # assuming this matches
4556
return initial_split, {}
4657
if len(initial_split) != len(args_data) and not var_positional_in and not keyword_only_count:
@@ -62,3 +73,19 @@ def smart_split(ipt: str, args_data: dict, splitter: str = " ", subcommand: bool
6273
args.append(initial_split[i])
6374
ipt = ipt.split(initial_split[i], 1)[-1].lstrip()
6475
return args, kwargs
76+
77+
78+
def maybe_fmt(value: str) -> typing.Optional[str]:
79+
match = FMT_REGEX.match(value)
80+
if match:
81+
return match.group(3)
82+
83+
84+
def search(items: typing.Sequence[T], **attributes) -> typing.Optional[T]:
85+
for x in items:
86+
for k, v in attributes.items():
87+
resp = getattr(x, k)
88+
if inspect.ismethod(resp):
89+
resp = resp()
90+
if resp == v:
91+
return x

0 commit comments

Comments
 (0)