7676Coro : TypeAlias = Coroutine [Any , Any , None ]
7777CoroC : TypeAlias = Coroutine [Any , Any , bool ]
7878
79+
7980DT = TypeVar ("DT" )
8081VT = TypeVar ("VT" )
8182
@@ -204,8 +205,8 @@ def __init__(
204205 self ._name : str = name
205206 self .callback = callback
206207 self ._aliases : list [str ] = kwargs .get ("aliases" , [])
207- self ._guards : list [Callable [..., bool ] | Callable [..., CoroC ]] = getattr (self . _callback , "__command_guards__" , [])
208- self ._buckets : list [Bucket [Context ]] = getattr (self . _callback , "__command_cooldowns__" , [])
208+ self ._guards : list [Callable [..., bool ] | Callable [..., CoroC ]] = getattr (callback , "__command_guards__" , [])
209+ self ._buckets : list [Bucket [Context ]] = getattr (callback , "__command_cooldowns__" , [])
209210 self ._guards_after_parsing = kwargs .get ("guards_after_parsing" , False )
210211 self ._cooldowns_first = kwargs .get ("cooldowns_before_guards" , False )
211212
@@ -215,6 +216,9 @@ def __init__(
215216 self ._parent : Group [Component_T , P ] | None = kwargs .get ("parent" )
216217 self ._bypass_global_guards : bool = kwargs .get ("bypass_global_guards" , False )
217218
219+ self ._before_hook : Callable [[Component_T , Context ], Coro ] | Callable [[Context ], Coro ] | None = None
220+ self ._after_hook : Callable [[Component_T , Context ], Coro ] | Callable [[Context ], Coro ] | None = None
221+
218222 def __repr__ (self ) -> str :
219223 return f"Command(name={ self ._name } , parent={ self .parent } )"
220224
@@ -542,8 +546,11 @@ async def _invoke(self, context: Context) -> Any:
542546 context ._args = args
543547 context ._kwargs = kwargs
544548
549+ base_args : list [Any ] = [context ]
545550 args : list [Any ] = [context , * args ]
551+
546552 args .insert (0 , self ._injected ) if self ._injected else None
553+ base_args .insert (0 , self ._injected ) if self ._injected else None
547554
548555 if self ._guards_after_parsing :
549556 await self ._run_guards (context )
@@ -556,6 +563,9 @@ async def _invoke(self, context: Context) -> Any:
556563 await context .bot .before_invoke (context )
557564 if self ._injected is not None :
558565 await self ._injected .component_before_invoke (context )
566+
567+ if self ._before_hook :
568+ await self ._before_hook (* base_args )
559569 except Exception as e :
560570 raise CommandHookError (str (e ), e ) from e
561571
@@ -608,9 +618,81 @@ def error(self, func: Any) -> Any:
608618 self ._error = func
609619 return func
610620
611- def before_invoke (self ) -> None : ...
621+ def before_invoke (self , func : Any ) -> Any :
622+ """|deco|
623+
624+ A decorator which adds a local ``before_invoke`` callback to this command.
625+
626+ Similar to :meth:`~twitchio.ext.commands.Bot.before_invoke` except local to this command.
627+
628+ Example
629+ -------
630+
631+ .. code:: python3
632+
633+ @commands.command()
634+ async def test(ctx: commands.Context) -> None:
635+ ...
636+
637+ @test.before_invoke
638+ async def test_before(ctx: commands.Context) -> None:
639+ # Open a database connection for example
640+ ...
641+
642+ ctx.connection = connection
643+ ...
644+
645+ @test.after_invoke
646+ async def test_after(ctx: commands.Context) -> None:
647+ # Close a database connection for example
648+ ...
649+
650+ await ctx.connection.close()
651+ ...
652+ """
653+ if not asyncio .iscoroutinefunction (func ):
654+ raise TypeError (f'The Command "before_invoke" callback for "{ self ._name } " must be a coroutine function.' )
655+
656+ self ._before_hook = func
657+ return func
658+
659+ def after_invoke (self , func : Any ) -> Any :
660+ """|deco|
661+
662+ A decorator which adds a local ``after_invoke`` callback to this command.
663+
664+ Similar to :meth:`~twitchio.ext.commands.Bot.after_invoke` except local to this command.
665+
666+ Example
667+ -------
668+
669+ .. code:: python3
612670
613- def after_invoke (self ) -> None : ...
671+ @commands.command()
672+ async def test(ctx: commands.Context) -> None:
673+ ...
674+
675+ @test.before_invoke
676+ async def test_before(ctx: commands.Context) -> None:
677+ # Open a database connection for example
678+ ...
679+
680+ ctx.connection = connection
681+ ...
682+
683+ @test.after_invoke
684+ async def test_after(ctx: commands.Context) -> None:
685+ # Close a database connection for example
686+ ...
687+
688+ await ctx.connection.close()
689+ ...
690+ """
691+ if not asyncio .iscoroutinefunction (func ):
692+ raise TypeError (f'The Command "after_invoke" callback for "{ self ._name } " must be a coroutine function.' )
693+
694+ self ._after_hook = func
695+ return func
614696
615697
616698class RewardCommand (Command [Component_T , P ]):
0 commit comments