@@ -48,7 +48,10 @@ pub enum DiscordAuthError {
4848 DiscordApiError ,
4949}
5050
51- pub struct DiscordAuth ( pub DiscordUser ) ;
51+ pub struct DiscordAuth {
52+ pub user : DiscordUser ,
53+ pub token : String , // The full "Bearer <token>" string
54+ }
5255
5356#[ rocket:: async_trait]
5457impl < ' r > FromRequest < ' r > for DiscordAuth {
@@ -68,7 +71,10 @@ impl<'r> FromRequest<'r> for DiscordAuth {
6871
6972 // Verify the token with Discord API
7073 match verify_discord_token ( auth_header) . await {
71- Ok ( user) => Outcome :: Success ( DiscordAuth ( user) ) ,
74+ Ok ( user) => Outcome :: Success ( DiscordAuth {
75+ user,
76+ token : auth_header. to_string ( ) ,
77+ } ) ,
7278 Err ( _) => Outcome :: Error ( ( Status :: Forbidden , DiscordAuthError :: Invalid ) ) ,
7379 }
7480 }
@@ -140,6 +146,35 @@ fn verify_discord_token_test(token: &str) -> Result<DiscordUser, DiscordAuthErro
140146 }
141147}
142148
149+ /// Permission flag for MANAGE_SERVER (1 << 5 = 32)
150+ pub const MANAGE_SERVER : u64 = 1 << 5 ;
151+
152+ // Test guild permissions - maps user_id -> guild_id -> permissions
153+ static TEST_GUILD_PERMISSIONS : std:: sync:: atomic:: AtomicPtr < String > =
154+ std:: sync:: atomic:: AtomicPtr :: new ( std:: ptr:: null_mut ( ) ) ;
155+
156+ /// Set test guild permissions for testing
157+ /// Format: "user_id:guild_id:permissions,user_id:guild_id:permissions,..."
158+ #[ allow( dead_code) ]
159+ pub fn set_test_guild_permissions ( permissions : String ) {
160+ let boxed_string = Box :: new ( permissions) ;
161+ let ptr = Box :: into_raw ( boxed_string) ;
162+ TEST_GUILD_PERMISSIONS . store ( ptr, std:: sync:: atomic:: Ordering :: Relaxed ) ;
163+ }
164+
165+ /// Clear test guild permissions
166+ #[ allow( dead_code) ]
167+ pub fn clear_test_guild_permissions ( ) {
168+ let null_ptr = std:: ptr:: null_mut ( ) ;
169+ TEST_GUILD_PERMISSIONS . store ( null_ptr, std:: sync:: atomic:: Ordering :: Relaxed ) ;
170+ }
171+
172+ /// Check if test mode is enabled
173+ #[ allow( dead_code) ]
174+ pub fn is_test_mode ( ) -> bool {
175+ TEST_MODE . load ( std:: sync:: atomic:: Ordering :: Relaxed )
176+ }
177+
143178impl DiscordAuth {
144179 /// Check if the authenticated user is an admin
145180 pub fn is_admin ( & self ) -> bool {
@@ -151,20 +186,20 @@ impl DiscordAuth {
151186 if !test_admin_ids_ptr. is_null ( ) {
152187 let test_admin_ids = unsafe { & * test_admin_ids_ptr } ;
153188 let admin_ids: Vec < & str > = test_admin_ids. split ( ',' ) . collect ( ) ;
154- return admin_ids. contains ( & self . 0 . id . as_str ( ) ) ;
189+ return admin_ids. contains ( & self . user . id . as_str ( ) ) ;
155190 }
156191 }
157192
158193 // Use environment variable
159194 let admin_ids = env:: var ( "ADMIN_IDS" ) . unwrap_or_default ( ) ;
160195 let admin_ids: Vec < & str > = admin_ids. split ( ',' ) . collect ( ) ;
161- admin_ids. contains ( & self . 0 . id . as_str ( ) )
196+ admin_ids. contains ( & self . user . id . as_str ( ) )
162197 }
163198
164199 /// Check if the authenticated user can access data for the given user_id
165200 pub fn can_access_user ( & self , user_id : & str ) -> bool {
166201 // Users can always access their own data
167- if self . 0 . id == user_id {
202+ if self . user . id == user_id {
168203 return true ;
169204 }
170205
@@ -175,4 +210,115 @@ impl DiscordAuth {
175210
176211 false
177212 }
213+
214+ /// Check if the authenticated user has MANAGE_SERVER permission for a guild
215+ /// Uses Discord API to verify permissions via the guilds endpoint
216+ pub async fn has_manage_server_permission ( & self , guild_id : & str ) -> bool {
217+ // Admins always have permission
218+ if self . is_admin ( ) {
219+ return true ;
220+ }
221+
222+ // Check via Discord API
223+ check_guild_permission ( & self . token , guild_id)
224+ . await
225+ . unwrap_or ( false )
226+ }
227+
228+ /// Get the user's ID
229+ pub fn get_user_id ( & self ) -> & str {
230+ & self . user . id
231+ }
232+
233+ /// Get the auth token
234+ pub fn get_token ( & self ) -> & str {
235+ & self . token
236+ }
237+ }
238+
239+ #[ derive( Debug , Serialize , Deserialize ) ]
240+ pub struct DiscordGuildInfo {
241+ pub id : String ,
242+ pub name : String ,
243+ pub icon : Option < String > ,
244+ pub permissions : String , // Discord returns this as a string
245+ }
246+
247+ /// Check if user has MANAGE_SERVER permission for a guild via Discord API
248+ pub async fn check_guild_permission ( token : & str , guild_id : & str ) -> Result < bool , DiscordAuthError > {
249+ // Parse guild_id as i64 and use get_editable_guilds
250+ let guild_id_i64: i64 = guild_id. parse ( ) . map_err ( |_| DiscordAuthError :: Invalid ) ?;
251+ let editable = get_editable_guilds ( token, & [ guild_id_i64] ) . await ?;
252+ Ok ( editable. contains ( & guild_id_i64) )
253+ }
254+
255+ /// Get the list of guilds where user has MANAGE_SERVER permission
256+ pub async fn get_editable_guilds (
257+ token : & str ,
258+ guild_ids : & [ i64 ] ,
259+ ) -> Result < Vec < i64 > , DiscordAuthError > {
260+ // Check if we're in test mode
261+ let test_mode = TEST_MODE . load ( std:: sync:: atomic:: Ordering :: Relaxed ) ;
262+ if test_mode {
263+ // In test mode, use the test permissions set via set_test_guild_permissions
264+ let user = verify_discord_token ( token) . await ?;
265+ let mut editable = Vec :: new ( ) ;
266+ let test_perms_ptr = TEST_GUILD_PERMISSIONS . load ( std:: sync:: atomic:: Ordering :: Relaxed ) ;
267+ if !test_perms_ptr. is_null ( ) {
268+ let test_perms = unsafe { & * test_perms_ptr } ;
269+ for entry in test_perms. split ( ',' ) {
270+ let parts: Vec < & str > = entry. split ( ':' ) . collect ( ) ;
271+ if parts. len ( ) == 3 {
272+ let test_user = parts[ 0 ] ;
273+ let test_guild = parts[ 1 ] ;
274+ let perms: u64 = parts[ 2 ] . parse ( ) . unwrap_or ( 0 ) ;
275+ if test_user == user. id && ( perms & MANAGE_SERVER ) != 0 {
276+ if let Ok ( guild_id) = test_guild. parse :: < i64 > ( ) {
277+ if guild_ids. contains ( & guild_id) {
278+ editable. push ( guild_id) ;
279+ }
280+ }
281+ }
282+ }
283+ }
284+ }
285+ return Ok ( editable) ;
286+ }
287+
288+ let client = reqwest:: Client :: new ( ) ;
289+
290+ let response = client
291+ . get ( "https://discord.com/api/v10/users/@me/guilds" )
292+ . header ( "Authorization" , token)
293+ . send ( )
294+ . await ;
295+
296+ match response {
297+ Ok ( resp) => {
298+ if resp. status ( ) . is_success ( ) {
299+ let response_text = resp. text ( ) . await . unwrap_or_default ( ) ;
300+
301+ match serde_json:: from_str :: < Vec < DiscordGuildInfo > > ( & response_text) {
302+ Ok ( guilds) => {
303+ let mut editable = Vec :: new ( ) ;
304+ for guild in guilds {
305+ if let Ok ( guild_id) = guild. id . parse :: < i64 > ( ) {
306+ if guild_ids. contains ( & guild_id) {
307+ let perms: u64 = guild. permissions . parse ( ) . unwrap_or ( 0 ) ;
308+ if ( perms & MANAGE_SERVER ) != 0 {
309+ editable. push ( guild_id) ;
310+ }
311+ }
312+ }
313+ }
314+ Ok ( editable)
315+ }
316+ Err ( _e) => Err ( DiscordAuthError :: DiscordApiError ) ,
317+ }
318+ } else {
319+ Err ( DiscordAuthError :: Invalid )
320+ }
321+ }
322+ Err ( _e) => Err ( DiscordAuthError :: DiscordApiError ) ,
323+ }
178324}
0 commit comments