6
6
import re
7
7
import sys
8
8
import typing
9
-
10
9
from datetime import datetime
10
+ from itertools import zip_longest
11
11
from types import SimpleNamespace
12
12
13
13
import discord
30
30
31
31
from core .clients import ApiClient , PluginDatabaseClient
32
32
from core .config import ConfigManager
33
- from core .utils import human_join , strtobool
33
+ from core .utils import human_join , strtobool , parse_alias
34
34
from core .models import PermissionLevel , ModmailLogger
35
35
from core .thread import ThreadManager
36
36
from core .time import human_timedelta
@@ -669,24 +669,76 @@ async def _process_blocked(self, message: discord.Message) -> bool:
669
669
logger .warning ("Failed to add reaction %s." , reaction , exc_info = True )
670
670
return str (message .author .id ) in self .blocked_users
671
671
672
- async def process_modmail (self , message : discord .Message ) -> None :
672
+ async def process_dm_modmail (self , message : discord .Message ) -> None :
673
673
"""Processes messages sent to the bot."""
674
- await self .wait_for_connected ()
675
-
676
674
blocked = await self ._process_blocked (message )
677
675
if not blocked :
678
676
thread = await self .threads .find_or_create (message .author )
679
677
await thread .send (message )
680
678
679
+ async def get_contexts (self , message , * , cls = commands .Context ):
680
+ """
681
+ Returns all invocation contexts from the message.
682
+ Supports getting the prefix from database as well as command aliases.
683
+ """
684
+
685
+ view = StringView (message .content )
686
+ ctx = cls (prefix = self .prefix , view = view , bot = self , message = message )
687
+ ctx .thread = await self .threads .find (channel = ctx .channel )
688
+
689
+ if self ._skip_check (message .author .id , self .user .id ):
690
+ return [ctx ]
691
+
692
+ prefixes = await self .get_prefix ()
693
+
694
+ invoked_prefix = discord .utils .find (view .skip_string , prefixes )
695
+ if invoked_prefix is None :
696
+ return [ctx ]
697
+
698
+ invoker = view .get_word ().lower ()
699
+
700
+ # Check if there is any aliases being called.
701
+ alias = self .aliases .get (invoker )
702
+ if alias is not None :
703
+ aliases = parse_alias (alias )
704
+ if not aliases :
705
+ logger .warning ("Alias %s is invalid, removing." , invoker )
706
+ self .aliases .pop (invoker )
707
+ else :
708
+ len_ = len (f"{ invoked_prefix } { invoker } " )
709
+ contents = parse_alias (message .content [len_ :])
710
+ if not contents :
711
+ contents = [message .content [len_ :]]
712
+
713
+ ctxs = []
714
+ for alias , content in zip_longest (aliases , contents ):
715
+ if alias is None :
716
+ break
717
+ ctx = cls (prefix = self .prefix , view = view , bot = self , message = message )
718
+ ctx .thread = await self .threads .find (channel = ctx .channel )
719
+
720
+ if content is not None :
721
+ view = StringView (f"{ alias } { content .strip ()} " )
722
+ else :
723
+ view = StringView (alias )
724
+ ctx .view = view
725
+ ctx .invoked_with = view .get_word ()
726
+ ctx .command = self .all_commands .get (ctx .invoked_with )
727
+ ctxs += [ctx ]
728
+ return ctxs
729
+
730
+ ctx .invoked_with = invoker
731
+ ctx .command = self .all_commands .get (invoker )
732
+ return [ctx ]
733
+
681
734
async def get_context (self , message , * , cls = commands .Context ):
682
735
"""
683
736
Returns the invocation context from the message.
684
- Supports getting the prefix from database as well as command aliases .
737
+ Supports getting the prefix from database.
685
738
"""
686
- await self .wait_for_connected ()
687
739
688
740
view = StringView (message .content )
689
- ctx = cls (prefix = None , view = view , bot = self , message = message )
741
+ ctx = cls (prefix = self . prefix , view = view , bot = self , message = message )
690
742
691
743
if self ._skip_check (message .author .id , self .user .id ):
692
744
return ctx
@@ -701,17 +753,7 @@ async def get_context(self, message, *, cls=commands.Context):
701
753
702
754
invoker = view .get_word ().lower ()
703
755
704
- # Check if there is any aliases being called.
705
- alias = self .aliases .get (invoker )
706
- if alias is not None :
707
- ctx ._alias_invoked = True # pylint: disable=W0212
708
- len_ = len (f"{ invoked_prefix } { invoker } " )
709
- view = StringView (f"{ alias } { ctx .message .content [len_ :]} " )
710
- ctx .view = view
711
- invoker = view .get_word ()
712
-
713
756
ctx .invoked_with = invoker
714
- ctx .prefix = self .prefix # Sane prefix (No mentions)
715
757
ctx .command = self .all_commands .get (invoker )
716
758
717
759
return ctx
@@ -739,47 +781,52 @@ async def update_perms(
739
781
740
782
async def on_message (self , message ):
741
783
await self .wait_for_connected ()
742
-
743
784
if message .type == discord .MessageType .pins_add and message .author == self .user :
744
785
await message .delete ()
786
+ await self .process_commands (message )
745
787
788
+ async def process_commands (self , message ):
746
789
if message .author .bot :
747
790
return
748
791
749
792
if isinstance (message .channel , discord .DMChannel ):
750
- return await self .process_modmail (message )
793
+ return await self .process_dm_modmail (message )
751
794
752
- prefix = self .prefix
795
+ if message .content .startswith (self .prefix ):
796
+ cmd = message .content [len (self .prefix ) :].strip ()
753
797
754
- if message .content .startswith (prefix ):
755
- cmd = message .content [len (prefix ) :].strip ()
798
+ # Process snippets
756
799
if cmd in self .snippets :
757
800
thread = await self .threads .find (channel = message .channel )
758
801
snippet = self .snippets [cmd ]
759
802
if thread :
760
803
snippet = snippet .format (recipient = thread .recipient )
761
- message .content = f"{ prefix } reply { snippet } "
804
+ message .content = f"{ self . prefix } reply { snippet } "
762
805
763
- ctx = await self .get_context (message )
764
- if ctx .command :
765
- return await self .invoke (ctx )
806
+ ctxs = await self .get_contexts (message )
807
+ for ctx in ctxs :
808
+ if ctx .command :
809
+ await self .invoke (ctx )
810
+ continue
766
811
767
- thread = await self .threads .find (channel = ctx .channel )
768
- if thread is not None :
769
- try :
770
- reply_without_command = strtobool (self .config ["reply_without_command" ])
771
- except ValueError :
772
- reply_without_command = self .config .remove ("reply_without_command" )
812
+ thread = await self .threads .find (channel = ctx .channel )
813
+ if thread is not None :
814
+ try :
815
+ reply_without_command = strtobool (
816
+ self .config ["reply_without_command" ]
817
+ )
818
+ except ValueError :
819
+ reply_without_command = self .config .remove ("reply_without_command" )
773
820
774
- if reply_without_command :
775
- await thread .reply (message )
776
- else :
777
- await self .api .append_log (message , type_ = "internal" )
778
- elif ctx .invoked_with :
779
- exc = commands .CommandNotFound (
780
- 'Command "{}" is not found' .format (ctx .invoked_with )
781
- )
782
- self .dispatch ("command_error" , ctx , exc )
821
+ if reply_without_command :
822
+ await thread .reply (message )
823
+ else :
824
+ await self .api .append_log (message , type_ = "internal" )
825
+ elif ctx .invoked_with :
826
+ exc = commands .CommandNotFound (
827
+ 'Command "{}" is not found' .format (ctx .invoked_with )
828
+ )
829
+ self .dispatch ("command_error" , ctx , exc )
783
830
784
831
async def on_typing (self , channel , user , _ ):
785
832
await self .wait_for_connected ()
0 commit comments