1+ use chrono:: { DateTime , Utc } ;
12use rocket:: http:: Status ;
23use rocket:: request:: { FromRequest , Outcome , Request } ;
34use serde:: { Deserialize , Serialize } ;
5+ use std:: collections:: HashMap ;
46use std:: env;
7+ use std:: sync:: RwLock ;
8+ use std:: time:: { Duration , Instant } ;
9+
10+ // Token cache: maps "Bearer <token>" -> CachedToken
11+ // Guild permission cache: maps "Bearer <token>" -> CachedGuildPermissions
12+ lazy_static:: lazy_static! {
13+ static ref TOKEN_CACHE : RwLock <HashMap <String , CachedToken >> = RwLock :: new( HashMap :: new( ) ) ;
14+ static ref GUILD_PERM_CACHE : RwLock <HashMap <String , CachedGuildPermissions >> = RwLock :: new( HashMap :: new( ) ) ;
15+ }
16+
17+ /// Fallback TTL used only in test mode where there is no real Discord expiry
18+ const TEST_TOKEN_TTL : Duration = Duration :: from_secs ( 600 ) ;
19+
20+ /// How long guild permission data is cached (permissions may change while token is still valid)
21+ const GUILD_PERM_TTL : Duration = Duration :: from_secs ( 300 ) ; // 5 minutes
22+
23+ /// Response from Discord's GET /oauth2/@me endpoint
24+ #[ derive( Debug , Deserialize ) ]
25+ struct OAuth2MeResponse {
26+ expires : String ,
27+ user : DiscordUser ,
28+ }
29+
30+ #[ derive( Clone ) ]
31+ struct CachedToken {
32+ user : DiscordUser ,
33+ expires_at : Instant ,
34+ }
35+
36+ #[ derive( Clone ) ]
37+ struct CachedGuildPermissions {
38+ /// Guild IDs where the user has MANAGE_SERVER permission
39+ editable_guild_ids : Vec < i64 > ,
40+ expires_at : Instant ,
41+ }
542
643// Test mode flag - set to true during tests
744static TEST_MODE : std:: sync:: atomic:: AtomicBool = std:: sync:: atomic:: AtomicBool :: new ( false ) ;
@@ -22,6 +59,8 @@ pub fn disable_test_mode() {
2259 // Clear test admin IDs
2360 let null_ptr = std:: ptr:: null_mut ( ) ;
2461 TEST_ADMIN_IDS . store ( null_ptr, std:: sync:: atomic:: Ordering :: Relaxed ) ;
62+ // Clear the token cache when leaving test mode
63+ clear_token_cache ( ) ;
2564}
2665
2766/// Set test admin IDs for testing
@@ -32,7 +71,7 @@ pub fn set_test_admin_ids(admin_ids: String) {
3271 TEST_ADMIN_IDS . store ( ptr, std:: sync:: atomic:: Ordering :: Relaxed ) ;
3372}
3473
35- #[ derive( Debug , Serialize , Deserialize ) ]
74+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
3675pub struct DiscordUser {
3776 pub id : String ,
3877 pub username : String ,
@@ -81,16 +120,35 @@ impl<'r> FromRequest<'r> for DiscordAuth {
81120}
82121
83122async fn verify_discord_token ( token : & str ) -> Result < DiscordUser , DiscordAuthError > {
123+ // Check cache first (works in both test and production mode)
124+ if let Some ( user) = get_cached_token ( token) {
125+ return Ok ( user) ;
126+ }
127+
84128 // Check if we're in test mode
85129 let test_mode = TEST_MODE . load ( std:: sync:: atomic:: Ordering :: Relaxed ) ;
130+
86131 if test_mode {
87- return verify_discord_token_test ( token) ;
132+ let user = verify_discord_token_test ( token) ?;
133+ cache_token_with_expiry ( token, & user, Instant :: now ( ) + TEST_TOKEN_TTL ) ;
134+ return Ok ( user) ;
88135 }
89136
137+ // Production: call /oauth2/@me which returns user data + real token expiry
138+ let ( user, expires_at) = verify_discord_token_prod ( token) . await ?;
139+ cache_token_with_expiry ( token, & user, expires_at) ;
140+ Ok ( user)
141+ }
142+
143+ /// Calls Discord's /oauth2/@me to verify the token and retrieve both the user
144+ /// data and the token's actual expiry timestamp.
145+ async fn verify_discord_token_prod (
146+ token : & str ,
147+ ) -> Result < ( DiscordUser , Instant ) , DiscordAuthError > {
90148 let client = reqwest:: Client :: new ( ) ;
91149
92150 let response = client
93- . get ( "https://discord.com/api/v10/users /@me" )
151+ . get ( "https://discord.com/api/v10/oauth2 /@me" )
94152 . header ( "Authorization" , token)
95153 . send ( )
96154 . await ;
@@ -100,8 +158,11 @@ async fn verify_discord_token(token: &str) -> Result<DiscordUser, DiscordAuthErr
100158 if resp. status ( ) . is_success ( ) {
101159 let response_text = resp. text ( ) . await . unwrap_or_default ( ) ;
102160
103- match serde_json:: from_str :: < DiscordUser > ( & response_text) {
104- Ok ( user) => Ok ( user) ,
161+ match serde_json:: from_str :: < OAuth2MeResponse > ( & response_text) {
162+ Ok ( oauth_resp) => {
163+ let expires_at = parse_discord_expiry ( & oauth_resp. expires ) ;
164+ Ok ( ( oauth_resp. user , expires_at) )
165+ }
105166 Err ( _e) => Err ( DiscordAuthError :: DiscordApiError ) ,
106167 }
107168 } else {
@@ -112,6 +173,105 @@ async fn verify_discord_token(token: &str) -> Result<DiscordUser, DiscordAuthErr
112173 }
113174}
114175
176+ /// Convert Discord's ISO-8601 `expires` string into a `std::time::Instant`.
177+ fn parse_discord_expiry ( expires : & str ) -> Instant {
178+ if let Ok ( expiry_dt) = expires. parse :: < DateTime < Utc > > ( ) {
179+ let remaining = expiry_dt
180+ . signed_duration_since ( Utc :: now ( ) )
181+ . to_std ( )
182+ . unwrap_or ( Duration :: ZERO ) ;
183+ Instant :: now ( ) + remaining
184+ } else {
185+ // If parsing fails, fall back to a short TTL so we re-verify soon
186+ Instant :: now ( ) + Duration :: from_secs ( 60 )
187+ }
188+ }
189+
190+ // ===== Token Cache Helpers =====
191+
192+ /// Look up a token in the cache, returning the user if it's present and not expired
193+ fn get_cached_token ( token : & str ) -> Option < DiscordUser > {
194+ let cache = TOKEN_CACHE . read ( ) . ok ( ) ?;
195+ if let Some ( cached) = cache. get ( token) {
196+ if cached. expires_at > Instant :: now ( ) {
197+ return Some ( cached. user . clone ( ) ) ;
198+ }
199+ }
200+ None
201+ }
202+
203+ /// Store a verified token with an explicit expiry instant
204+ fn cache_token_with_expiry ( token : & str , user : & DiscordUser , expires_at : Instant ) {
205+ if let Ok ( mut cache) = TOKEN_CACHE . write ( ) {
206+ cache. insert (
207+ token. to_string ( ) ,
208+ CachedToken {
209+ user : user. clone ( ) ,
210+ expires_at,
211+ } ,
212+ ) ;
213+ }
214+ }
215+
216+ /// Public helper that caches with the test-mode fallback TTL (used by tests)
217+ #[ allow( dead_code) ]
218+ pub fn cache_token ( token : & str , user : & DiscordUser ) {
219+ cache_token_with_expiry ( token, user, Instant :: now ( ) + TEST_TOKEN_TTL ) ;
220+ }
221+
222+ /// Remove a specific token from the cache (used by the logout endpoint)
223+ pub fn invalidate_token ( token : & str ) {
224+ if let Ok ( mut cache) = TOKEN_CACHE . write ( ) {
225+ cache. remove ( token) ;
226+ }
227+ if let Ok ( mut cache) = GUILD_PERM_CACHE . write ( ) {
228+ cache. remove ( token) ;
229+ }
230+ }
231+
232+ /// Clear all tokens and guild permissions from the cache
233+ #[ allow( dead_code) ]
234+ pub fn clear_token_cache ( ) {
235+ if let Ok ( mut cache) = TOKEN_CACHE . write ( ) {
236+ cache. clear ( ) ;
237+ }
238+ if let Ok ( mut cache) = GUILD_PERM_CACHE . write ( ) {
239+ cache. clear ( ) ;
240+ }
241+ }
242+
243+ /// Check if a token is currently cached (for testing)
244+ #[ allow( dead_code) ]
245+ pub fn is_token_cached ( token : & str ) -> bool {
246+ get_cached_token ( token) . is_some ( )
247+ }
248+
249+ // ===== Guild Permission Cache Helpers =====
250+
251+ /// Look up cached guild permissions for a token
252+ fn get_cached_guild_permissions ( token : & str ) -> Option < Vec < i64 > > {
253+ let cache = GUILD_PERM_CACHE . read ( ) . ok ( ) ?;
254+ if let Some ( cached) = cache. get ( token) {
255+ if cached. expires_at > Instant :: now ( ) {
256+ return Some ( cached. editable_guild_ids . clone ( ) ) ;
257+ }
258+ }
259+ None
260+ }
261+
262+ /// Store guild permission data in the cache
263+ fn cache_guild_permissions ( token : & str , editable_guild_ids : & [ i64 ] ) {
264+ if let Ok ( mut cache) = GUILD_PERM_CACHE . write ( ) {
265+ cache. insert (
266+ token. to_string ( ) ,
267+ CachedGuildPermissions {
268+ editable_guild_ids : editable_guild_ids. to_vec ( ) ,
269+ expires_at : Instant :: now ( ) + GUILD_PERM_TTL ,
270+ } ,
271+ ) ;
272+ }
273+ }
274+
115275fn verify_discord_token_test ( token : & str ) -> Result < DiscordUser , DiscordAuthError > {
116276 // Remove "Bearer " prefix if present
117277 let token = if let Some ( stripped) = token. strip_prefix ( "Bearer " ) {
@@ -257,6 +417,15 @@ pub async fn get_editable_guilds(
257417 token : & str ,
258418 guild_ids : & [ i64 ] ,
259419) -> Result < Vec < i64 > , DiscordAuthError > {
420+ // Check guild permission cache first
421+ if let Some ( cached_editable) = get_cached_guild_permissions ( token) {
422+ let filtered: Vec < i64 > = cached_editable
423+ . into_iter ( )
424+ . filter ( |id| guild_ids. contains ( id) )
425+ . collect ( ) ;
426+ return Ok ( filtered) ;
427+ }
428+
260429 // Check if we're in test mode
261430 let test_mode = TEST_MODE . load ( std:: sync:: atomic:: Ordering :: Relaxed ) ;
262431 if test_mode {
@@ -274,15 +443,19 @@ pub async fn get_editable_guilds(
274443 let perms: u64 = parts[ 2 ] . parse ( ) . unwrap_or ( 0 ) ;
275444 if test_user == user. id && ( perms & MANAGE_SERVER ) != 0 {
276445 if let Ok ( guild_id) = test_guild. parse :: < i64 > ( ) {
277- if guild_ids. contains ( & guild_id) {
278- editable. push ( guild_id) ;
279- }
446+ editable. push ( guild_id) ;
280447 }
281448 }
282449 }
283450 }
284451 }
285- return Ok ( editable) ;
452+ // Cache the full list of editable guilds, then filter for the requested ones
453+ cache_guild_permissions ( token, & editable) ;
454+ let filtered: Vec < i64 > = editable
455+ . into_iter ( )
456+ . filter ( |id| guild_ids. contains ( id) )
457+ . collect ( ) ;
458+ return Ok ( filtered) ;
286459 }
287460
288461 let client = reqwest:: Client :: new ( ) ;
@@ -300,18 +473,28 @@ pub async fn get_editable_guilds(
300473
301474 match serde_json:: from_str :: < Vec < DiscordGuildInfo > > ( & response_text) {
302475 Ok ( guilds) => {
303- let mut editable = Vec :: new ( ) ;
304- for guild in guilds {
305- if let Ok ( guild_id) = guild. id . parse :: < i64 > ( ) {
306- if guild_ids. contains ( & guild_id) {
307- let perms: u64 = guild. permissions . parse ( ) . unwrap_or ( 0 ) ;
308- if ( perms & MANAGE_SERVER ) != 0 {
309- editable. push ( guild_id) ;
310- }
476+ // Collect ALL guilds the user can manage, then cache them
477+ let all_editable: Vec < i64 > = guilds
478+ . iter ( )
479+ . filter_map ( |guild| {
480+ let guild_id = guild. id . parse :: < i64 > ( ) . ok ( ) ?;
481+ let perms: u64 = guild. permissions . parse ( ) . unwrap_or ( 0 ) ;
482+ if ( perms & MANAGE_SERVER ) != 0 {
483+ Some ( guild_id)
484+ } else {
485+ None
311486 }
312- }
313- }
314- Ok ( editable)
487+ } )
488+ . collect ( ) ;
489+
490+ cache_guild_permissions ( token, & all_editable) ;
491+
492+ // Return only the requested guild IDs
493+ let filtered: Vec < i64 > = all_editable
494+ . into_iter ( )
495+ . filter ( |id| guild_ids. contains ( id) )
496+ . collect ( ) ;
497+ Ok ( filtered)
315498 }
316499 Err ( _e) => Err ( DiscordAuthError :: DiscordApiError ) ,
317500 }
0 commit comments