2525import net .dv8tion .jda .api .interactions .modals .Modal ;
2626import net .dv8tion .jda .api .requests .RestAction ;
2727import net .dv8tion .jda .api .requests .restaction .AuditableRestAction ;
28+ import net .dv8tion .jda .api .requests .restaction .MessageCreateAction ;
2829import net .dv8tion .jda .api .requests .restaction .MessageEditAction ;
2930import net .dv8tion .jda .api .requests .restaction .WebhookMessageCreateAction ;
3031import net .dv8tion .jda .api .requests .restaction .WebhookMessageEditAction ;
5253import static org .mockito .ArgumentMatchers .anyBoolean ;
5354import static org .mockito .ArgumentMatchers .anyCollection ;
5455import static org .mockito .ArgumentMatchers .anyList ;
56+ import static org .mockito .ArgumentMatchers .anyLong ;
5557import static org .mockito .ArgumentMatchers .anyString ;
5658import static org .mockito .Mockito .doAnswer ;
5759import static org .mockito .Mockito .mock ;
@@ -63,9 +65,12 @@ public class MockJDA {
6365
6466 private static final Map <Long , String > messages = new HashMap <>();
6567 private static final Map <Long , MessageEmbed []> embeds = new HashMap <>();
68+ private static final Map <Long , Member > members = new HashMap <>();
69+
6670 private static long CURRENT_ID = 0 ;
6771
6872 public static final JDA JDA = JDAObjects .getJDA ();
73+ public static final Member SELF = mockMember ("Bot" );
6974 public static final Guild GUILD = mockGuild ();
7075 public static Modal CURRENT_MODAL = null ;
7176
@@ -80,6 +85,14 @@ public class MockJDA {
8085 public static final Role AUTHOR = mockRole ("Author" , CONFIG .getLong ("author_role" ), 1 );
8186 public static final List <Role > ROLES = List .of (ADMINISTRATOR , MAINTAINER , REVIEWER , AUTHOR );
8287
88+ public static String getMessage (long id ) {
89+ return messages .get (id );
90+ }
91+
92+ public static List <MessageEmbed > getEmbeds (long id ) {
93+ return Arrays .asList (embeds .get (id ));
94+ }
95+
8396 public static InteractionHook mockInteractionHook (Member user , MessageChannel channel , InteractionType type ) {
8497 return mockInteractionHook (mockInteraction (user , channel , type ));
8598 }
@@ -97,12 +110,12 @@ public static InteractionHook mockInteractionHook(Interaction interaction) {
97110 when (hook .editOriginal (anyString ())).thenAnswer (inv -> {
98111 String content = inv .getArgument (0 );
99112 messages .put (hook .getIdLong (), content );
100- return mockReply (WebhookMessageEditAction .class , hook , mockMessage (content , channel ));
113+ return mockWebhookReply (WebhookMessageEditAction .class , hook , mockMessage (content , channel ));
101114 });
102- when (hook .editOriginalFormat (anyString (), any ())).thenAnswer (inv -> {
103- String content = String .format (inv .getArgument (0 ), (Object []) inv .getArguments ()[1 ]);
115+ when (hook .editOriginalFormat (anyString (), any (Object []. class ))).thenAnswer (inv -> {
116+ String content = String .format (inv .getArgument (0 ), (Object []) inv .getRawArguments ()[1 ]);
104117 messages .put (hook .getIdLong (), content );
105- return mockReply (WebhookMessageEditAction .class , hook , mockMessage (content , channel ));
118+ return mockWebhookReply (WebhookMessageEditAction .class , hook , mockMessage (content , channel ));
106119 });
107120
108121 when (hook .editOriginalEmbeds (any (MessageEmbed [].class ))).thenAnswer (inv -> {
@@ -112,7 +125,7 @@ public static InteractionHook mockInteractionHook(Interaction interaction) {
112125 else
113126 embeds .put (hook .getIdLong (), new MessageEmbed [] { (MessageEmbed ) obj });
114127
115- return mockReply (WebhookMessageEditAction .class , hook , mockMessage (null , Arrays .asList (embeds .get (hook .getIdLong ())), channel ));
128+ return mockWebhookReply (WebhookMessageEditAction .class , hook , mockMessage (null , Arrays .asList (embeds .get (hook .getIdLong ())), channel ));
116129 });
117130
118131 when (hook .sendMessageEmbeds (any (), any (MessageEmbed [].class ))).thenAnswer (inv -> {
@@ -129,7 +142,7 @@ public static InteractionHook mockInteractionHook(Interaction interaction) {
129142 embeds .add ((MessageEmbed ) obj );
130143 }
131144
132- return mockReply (WebhookMessageCreateAction .class , hook , mockMessage (null , embeds , channel ));
145+ return mockWebhookReply (WebhookMessageCreateAction .class , hook , mockMessage (null , embeds , channel ));
133146 });
134147
135148 return hook ;
@@ -140,6 +153,7 @@ public static Interaction mockInteraction(Member user, MessageChannel channel, I
140153 when (interaction .getJDA ()).thenReturn (JDA );
141154 when (interaction .getChannel ()).thenReturn (channel );
142155 when (interaction .getMember ()).thenReturn (user );
156+ when (interaction .getUser ()).thenAnswer (inv -> user .getUser ());
143157 when (interaction .getGuild ()).thenReturn (GUILD );
144158 when (interaction .getTypeRaw ()).thenReturn (type .getKey ());
145159 when (interaction .getIdLong ()).thenReturn (CURRENT_ID );
@@ -149,19 +163,66 @@ public static Interaction mockInteraction(Member user, MessageChannel channel, I
149163
150164 public static TextChannel mockChannel (String configName ) {
151165 long id = CONFIG .getLong ("channels" , configName );
152- return (TextChannel ) JDAObjects .getMessageChannel (configName .replace ('_' , '-' ), id , Callback .single ());
166+ TextChannel channel = (TextChannel ) JDAObjects .getMessageChannel (configName .replace ('_' , '-' ), id , Callback .single ());
167+
168+ when (channel .getGuild ()).thenReturn (GUILD );
169+ when (channel .getJDA ()).thenReturn (JDA );
170+
171+ when (channel .retrieveMessageById (anyLong ())).thenAnswer (inv -> {
172+ long messageId = inv .getArgument (0 );
173+ String content = messages .get (messageId );
174+ Message message = mockMessage (
175+ content , Arrays .asList (embeds .getOrDefault (messageId , new MessageEmbed [0 ])), channel
176+ );
177+
178+ return mockAction (message );
179+ });
180+ when (channel .sendMessage (any (CharSequence .class ))).thenAnswer (inv -> {
181+ String content = inv .getArgument (0 );
182+ return mockReply (MessageCreateAction .class , mockMessage (content , channel ));
183+ });
184+ when (channel .sendMessage (any (MessageCreateData .class ))).thenAnswer (inv -> {
185+ MessageCreateData data = inv .getArgument (0 , MessageCreateData .class );
186+ String content = data .getContent ();
187+ List <MessageEmbed > embeds = data .getEmbeds ();
188+
189+ return mockReply (MessageCreateAction .class , mockMessage (content , embeds , channel ));
190+ });
191+ when (channel .sendMessageFormat (anyString (), any (Object [].class ))).thenAnswer (inv -> {
192+ String content = String .format (inv .getArgument (0 ), (Object []) inv .getRawArguments ()[1 ]);
193+ return mockReply (MessageCreateAction .class , mockMessage (content , channel ));
194+ });
195+ when (channel .sendMessageEmbeds (any (), any (MessageEmbed [].class ))).thenAnswer (inv -> {
196+ MessageEmbed first = inv .getArgument (0 );
197+
198+ List <MessageEmbed > embeds = new ArrayList <>();
199+ embeds .add (first );
200+ if (inv .getArguments ().length > 1 ) {
201+ Object obj = inv .getArgument (1 );
202+
203+ if (obj instanceof MessageEmbed [] allEmbeds )
204+ embeds .addAll (Arrays .asList (allEmbeds ));
205+ else
206+ embeds .add ((MessageEmbed ) obj );
207+ }
208+
209+ return mockReply (MessageCreateAction .class , mockMessage (null , embeds , channel ));
210+ });
211+
212+ return channel ;
153213 }
154214
155215 public static Message mockMessage (String content , MessageChannel channel ) {
156216 Message message = JDAObjects .getMessage (content , channel );
157217 messages .put (message .getIdLong (), content );
158218
219+ when (message .getContentRaw ()).thenAnswer (inv -> messages .get (message .getIdLong ()));
159220 when (message .getIdLong ()).thenReturn (CURRENT_ID );
160221 when (message .getGuild ()).thenReturn (GUILD );
161222
162223 when (message .editMessage (anyString ())).thenAnswer (inv -> {
163224 messages .put (message .getIdLong (), inv .getArgument (0 ));
164- return mockReply (MessageEditAction .class , mockInteractionHook (message .getMember (), channel , InteractionType .COMMAND ), message );
225+ return mockWebhookReply (MessageEditAction .class , mockInteractionHook (message .getMember (), channel , InteractionType .COMMAND ), message );
165226 });
166227
167228 return message ;
@@ -171,7 +232,13 @@ public static Message mockMessage(String content, List<MessageEmbed> embeds, Mes
171232 Message message = mockMessage (content , channel );
172233 MockJDA .embeds .put (message .getIdLong (), embeds .toArray (new MessageEmbed [0 ]));
173234
174- when (message .getEmbeds ()).thenReturn (embeds );
235+ when (message .getEmbeds ()).thenAnswer (inv -> Arrays .asList (MockJDA .embeds .get (message .getIdLong ())));
236+ when (message .getStartedThread ()).thenReturn (null );
237+ when (message .delete ()).thenAnswer (inv -> {
238+ messages .remove (message .getIdLong ());
239+ MockJDA .embeds .remove (message .getIdLong ());
240+ return mockAuditLog ();
241+ });
175242
176243 return message ;
177244 }
@@ -184,6 +251,7 @@ private static Guild mockGuild() {
184251 when (guild .getJDA ()).thenReturn (JDA );
185252 when (guild .getTextChannels ()).thenReturn (CHANNELS );
186253 when (guild .getRoles ()).thenReturn (ROLES );
254+ when (guild .getSelfMember ()).thenReturn (SELF );
187255
188256 when (guild .addRoleToMember (any (UserSnowflake .class ), any (Role .class ))).thenAnswer (inv -> {
189257 Member member = inv .getArgument (0 );
@@ -199,14 +267,26 @@ private static Guild mockGuild() {
199267 member .getRoles ().remove (role );
200268 return mockAuditLog ();
201269 });
202- when (guild .getRoleById (any ( Long . class ))).thenAnswer (inv -> {
270+ when (guild .getRoleById (anyLong ( ))).thenAnswer (inv -> {
203271 long id = inv .getArgument (0 );
204272 return ROLES .stream ().filter (role -> role .getIdLong () == id ).findFirst ().orElse (null );
205273 });
206- when (guild .getChannelById (any (), any ( Long . class ))).thenAnswer (inv -> {
274+ when (guild .getChannelById (any (), anyLong ( ))).thenAnswer (inv -> {
207275 long id = inv .getArgument (1 );
208276 return CHANNELS .stream ().filter (channel -> channel .getIdLong () == id ).findFirst ().orElse (null );
209277 });
278+ when (guild .getTextChannelById (anyLong ())).thenAnswer (inv -> {
279+ long id = inv .getArgument (0 );
280+ return CHANNELS .stream ().filter (channel -> channel .getIdLong () == id ).findFirst ().orElse (null );
281+ });
282+ when (guild .getMemberById (anyString ())).thenAnswer (inv -> {
283+ long id = Long .parseLong (inv .getArgument (0 ));
284+ return members .get (id );
285+ });
286+ when (guild .getMemberById (anyLong ())).thenAnswer (inv -> {
287+ long id = inv .getArgument (0 );
288+ return members .get (id );
289+ });
210290
211291 return guild ;
212292 }
@@ -215,14 +295,19 @@ public static Member mockMember(String username) {
215295 Member member = JDAObjects .getMember (username , "0000" );
216296
217297 long id = new SecureRandom ().nextLong ();
298+ members .put (id , member );
299+
300+ when (member .getJDA ()).thenReturn (JDA );
218301 when (member .getGuild ()).thenReturn (GUILD );
302+ when (member .getId ()).thenReturn (Long .toString (id ));
219303 when (member .getIdLong ()).thenReturn (id );
304+ when (member .getAsMention ()).thenReturn ("<@" + id + ">" );
220305
221306 List <Role > roles = new ArrayList <>();
222307 when (member .getRoles ()).thenReturn (roles );
223-
224308 when (member .getUser ().getEffectiveName ()).thenReturn (username );
225309 when (member .getUser ().getIdLong ()).thenReturn (id );
310+ when (member .getUser ().getAsMention ()).thenReturn ("<@" + id + ">" );
226311
227312 return member ;
228313 }
@@ -241,7 +326,7 @@ private static Role mockRole(String name, long id, int position) {
241326 return role ;
242327 }
243328
244- private static void assertEmbeds (long id , List <MessageEmbed > expectedOutputs , boolean ignoreTimestamp ) {
329+ public static void assertEmbeds (long id , List <MessageEmbed > expectedOutputs , boolean ignoreTimestamp ) {
245330 MessageEmbed [] embeds = MockJDA .embeds .get (id );
246331 if (embeds == null && expectedOutputs .isEmpty ()) return ;
247332
@@ -321,6 +406,8 @@ public static SlashCommandEvent mockSlashCommandEvent(MessageChannel channel, Bo
321406 GUILD .addRoleToMember (user , ADMINISTRATOR );
322407
323408 when (event .getMember ()).thenReturn (user );
409+ when (event .getUser ()).thenAnswer (inv -> user .getUser ());
410+ when (event .getOptions ()).thenAnswer (invocation -> options );
324411
325412 when (event .getOption (anyString ())).thenAnswer (invocation -> {
326413 if (options == null ) return null ;
@@ -429,7 +516,7 @@ private static ModalCallbackAction mockModalReply(Member user, MessageChannel ch
429516 return action ;
430517 }
431518
432- private static <T extends MessageRequest <?> & RestAction <?>> T mockReply (Class <T > clazz , InteractionHook hook , Message message ) {
519+ private static <T extends MessageRequest <?> & RestAction <?>> T mockWebhookReply (Class <T > clazz , InteractionHook hook , Message message ) {
433520 T action = mock (clazz );
434521
435522 doAnswer (inv -> {
@@ -461,10 +548,71 @@ private static <T extends MessageRequest<?> & RestAction<?>> T mockReply(Class<T
461548 return action ;
462549 }
463550
551+ private static <T extends MessageRequest <?> & RestAction <?>> T mockReply (Class <T > clazz , Message message ) {
552+ T action = mock (clazz );
553+
554+ doAnswer (inv -> {
555+ Consumer <Message > messageConsumer = inv .getArgument (0 );
556+ messageConsumer .accept (message );
557+ return null ;
558+ }).when (action ).queue (any ());
559+
560+ when (action .getEmbeds ()).thenAnswer (inv -> message .getEmbeds ());
561+ when (action .setEmbeds (anyCollection ())).thenAnswer (inv -> {
562+ Collection <MessageEmbed > embed = inv .getArgument (0 );
563+ embeds .put (message .getIdLong (), embed .toArray (new MessageEmbed [0 ]));
564+ return action ;
565+ });
566+
567+ return action ;
568+ }
569+
570+ @ SuppressWarnings ("unchecked" )
571+ private static <T > RestAction <T > mockAction (T object ) {
572+ RestAction <T > action = mock (RestAction .class );
573+
574+ doAnswer (inv -> {
575+ Consumer <T > consumer = inv .getArgument (0 );
576+ consumer .accept (object );
577+ return null ;
578+ }).when (action ).queue (any ());
579+
580+ doAnswer (inv -> {
581+ Consumer <T > consumer = inv .getArgument (0 );
582+ Consumer <Throwable > error = inv .getArgument (1 );
583+
584+ try {
585+ consumer .accept (object );
586+ } catch (Exception e ) {
587+ error .accept (e );
588+ }
589+ return null ;
590+ }).when (action ).queue (any (), any ());
591+
592+ return action ;
593+ }
594+
464595 @ SuppressWarnings ("unchecked" )
465596 private static AuditableRestAction <Void > mockAuditLog () {
466597 AuditableRestAction <Void > action = mock (AuditableRestAction .class );
467- doAnswer (inv -> null ).when (action ).queue (any ());
598+
599+ doAnswer (inv -> {
600+ Consumer <Void > consumer = inv .getArgument (0 );
601+ consumer .accept (null );
602+ return null ;
603+ }).when (action ).queue (any ());
604+
605+ doAnswer (inv -> {
606+ Consumer <Void > consumer = inv .getArgument (0 );
607+ Consumer <Throwable > error = inv .getArgument (1 );
608+
609+ try {
610+ consumer .accept (null );
611+ } catch (Exception e ) {
612+ error .accept (e );
613+ }
614+ return null ;
615+ }).when (action ).queue (any (), any ());
468616
469617 when (action .reason (any ())).thenReturn (action );
470618
0 commit comments