Skip to content

Commit 1f0b73a

Browse files
committed
Add command specific before and after invoke
1 parent c4e4ed0 commit 1f0b73a

File tree

2 files changed

+92
-4
lines changed

2 files changed

+92
-4
lines changed

twitchio/ext/commands/context.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,12 @@ async def invoke(self) -> bool | None:
449449
await self._bot.after_invoke(self)
450450
if self._component:
451451
await self._component.component_after_invoke(self)
452+
453+
if self._command._after_hook:
454+
base_args: list[Any] = [self]
455+
base_args.insert(0, self._component) if self._component else None
456+
457+
await self._command._after_hook(*base_args)
452458
except Exception as e:
453459
payload = CommandErrorPayload(context=self, exception=CommandHookError(str(e), e))
454460
self.bot.dispatch("command_error", payload=payload)

twitchio/ext/commands/core.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
Coro: TypeAlias = Coroutine[Any, Any, None]
7777
CoroC: TypeAlias = Coroutine[Any, Any, bool]
7878

79+
7980
DT = TypeVar("DT")
8081
VT = TypeVar("VT")
8182

@@ -204,8 +205,8 @@ def __init__(
204205
self._name: str = name
205206
self.callback = callback
206207
self._aliases: list[str] = kwargs.get("aliases", [])
207-
self._guards: list[Callable[..., bool] | Callable[..., CoroC]] = getattr(self._callback, "__command_guards__", [])
208-
self._buckets: list[Bucket[Context]] = getattr(self._callback, "__command_cooldowns__", [])
208+
self._guards: list[Callable[..., bool] | Callable[..., CoroC]] = getattr(callback, "__command_guards__", [])
209+
self._buckets: list[Bucket[Context]] = getattr(callback, "__command_cooldowns__", [])
209210
self._guards_after_parsing = kwargs.get("guards_after_parsing", False)
210211
self._cooldowns_first = kwargs.get("cooldowns_before_guards", False)
211212

@@ -215,6 +216,9 @@ def __init__(
215216
self._parent: Group[Component_T, P] | None = kwargs.get("parent")
216217
self._bypass_global_guards: bool = kwargs.get("bypass_global_guards", False)
217218

219+
self._before_hook: Callable[[Component_T, Context], Coro] | Callable[[Context], Coro] | None = None
220+
self._after_hook: Callable[[Component_T, Context], Coro] | Callable[[Context], Coro] | None = None
221+
218222
def __repr__(self) -> str:
219223
return f"Command(name={self._name}, parent={self.parent})"
220224

@@ -542,8 +546,11 @@ async def _invoke(self, context: Context) -> Any:
542546
context._args = args
543547
context._kwargs = kwargs
544548

549+
base_args: list[Any] = [context]
545550
args: list[Any] = [context, *args]
551+
546552
args.insert(0, self._injected) if self._injected else None
553+
base_args.insert(0, self._injected) if self._injected else None
547554

548555
if self._guards_after_parsing:
549556
await self._run_guards(context)
@@ -556,6 +563,9 @@ async def _invoke(self, context: Context) -> Any:
556563
await context.bot.before_invoke(context)
557564
if self._injected is not None:
558565
await self._injected.component_before_invoke(context)
566+
567+
if self._before_hook:
568+
await self._before_hook(*base_args)
559569
except Exception as e:
560570
raise CommandHookError(str(e), e) from e
561571

@@ -608,9 +618,81 @@ def error(self, func: Any) -> Any:
608618
self._error = func
609619
return func
610620

611-
def before_invoke(self) -> None: ...
621+
def before_invoke(self, func: Any) -> Any:
622+
"""|deco|
623+
624+
A decorator which adds a local ``before_invoke`` callback to this command.
625+
626+
Similar to :meth:`~twitchio.ext.commands.Bot.before_invoke` except local to this command.
627+
628+
Example
629+
-------
630+
631+
.. code:: python3
632+
633+
@commands.command()
634+
async def test(ctx: commands.Context) -> None:
635+
...
636+
637+
@test.before_invoke
638+
async def test_before(ctx: commands.Context) -> None:
639+
# Open a database connection for example
640+
...
641+
642+
ctx.connection = connection
643+
...
644+
645+
@test.after_invoke
646+
async def test_after(ctx: commands.Context) -> None:
647+
# Close a database connection for example
648+
...
649+
650+
await ctx.connection.close()
651+
...
652+
"""
653+
if not asyncio.iscoroutinefunction(func):
654+
raise TypeError(f'The Command "before_invoke" callback for "{self._name}" must be a coroutine function.')
655+
656+
self._before_hook = func
657+
return func
658+
659+
def after_invoke(self, func: Any) -> Any:
660+
"""|deco|
661+
662+
A decorator which adds a local ``after_invoke`` callback to this command.
663+
664+
Similar to :meth:`~twitchio.ext.commands.Bot.after_invoke` except local to this command.
665+
666+
Example
667+
-------
668+
669+
.. code:: python3
612670
613-
def after_invoke(self) -> None: ...
671+
@commands.command()
672+
async def test(ctx: commands.Context) -> None:
673+
...
674+
675+
@test.before_invoke
676+
async def test_before(ctx: commands.Context) -> None:
677+
# Open a database connection for example
678+
...
679+
680+
ctx.connection = connection
681+
...
682+
683+
@test.after_invoke
684+
async def test_after(ctx: commands.Context) -> None:
685+
# Close a database connection for example
686+
...
687+
688+
await ctx.connection.close()
689+
...
690+
"""
691+
if not asyncio.iscoroutinefunction(func):
692+
raise TypeError(f'The Command "after_invoke" callback for "{self._name}" must be a coroutine function.')
693+
694+
self._after_hook = func
695+
return func
614696

615697

616698
class RewardCommand(Command[Component_T, P]):

0 commit comments

Comments
 (0)