Skip to content

Commit f15e1d6

Browse files
committed
Add API token cache
1 parent 8deee7f commit f15e1d6

File tree

9 files changed

+942
-585
lines changed

9 files changed

+942
-585
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ I would like to give a big thank you to everyone who has helped me on this journ
136136

137137
> Found a lot of hug images and also tracked down most artists of previously used ones to give them credit making Killua a much more considerate bot.
138138
139-
* [Scientia](https://github.com/ScientiaEtVeritas)
139+
* [Scientia](https://github.com/retkowski)
140140

141141
> Gave me the original idea for this bot and helped me in the early stages, enduring lots of pings and stupid questions.
142142

api/src/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use routes::news::{delete_news, edit_news, get_news, get_news_by_id, like_news,
2323
use routes::stats::get_stats;
2424
use routes::update::{update, update_cors};
2525
use routes::user::{edit_user, edit_user_by_id, get_userinfo, get_userinfo_by_id};
26+
use routes::auth::logout;
2627
use routes::vote::register_vote;
2728

2829
use fairings::cors::Cors;
@@ -67,6 +68,7 @@ fn rocket() -> _ {
6768
create_tag,
6869
edit_tag,
6970
delete_tag,
71+
logout,
7072
],
7173
)
7274
.attach(db::init())

api/src/routes/auth.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use rocket::http::Status;
2+
use rocket::serde::json::Json;
3+
use serde_json::{json, Value};
4+
5+
use super::common::discord_auth::{invalidate_token, DiscordAuth};
6+
7+
/// POST /logout - Invalidate the caller's cached OAuth token
8+
#[post("/logout")]
9+
pub fn logout(auth: DiscordAuth) -> (Status, Json<Value>) {
10+
invalidate_token(&auth.token);
11+
(Status::Ok, Json(json!({ "message": "Successfully logged out" })))
12+
}

api/src/routes/common/discord_auth.rs

Lines changed: 203 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,44 @@
1+
use chrono::{DateTime, Utc};
12
use rocket::http::Status;
23
use rocket::request::{FromRequest, Outcome, Request};
34
use serde::{Deserialize, Serialize};
5+
use std::collections::HashMap;
46
use std::env;
7+
use std::sync::RwLock;
8+
use std::time::{Duration, Instant};
9+
10+
// Token cache: maps "Bearer <token>" -> CachedToken
11+
// Guild permission cache: maps "Bearer <token>" -> CachedGuildPermissions
12+
lazy_static::lazy_static! {
13+
static ref TOKEN_CACHE: RwLock<HashMap<String, CachedToken>> = RwLock::new(HashMap::new());
14+
static ref GUILD_PERM_CACHE: RwLock<HashMap<String, CachedGuildPermissions>> = RwLock::new(HashMap::new());
15+
}
16+
17+
/// Fallback TTL used only in test mode where there is no real Discord expiry
18+
const TEST_TOKEN_TTL: Duration = Duration::from_secs(600);
19+
20+
/// How long guild permission data is cached (permissions may change while token is still valid)
21+
const GUILD_PERM_TTL: Duration = Duration::from_secs(300); // 5 minutes
22+
23+
/// Response from Discord's GET /oauth2/@me endpoint
24+
#[derive(Debug, Deserialize)]
25+
struct OAuth2MeResponse {
26+
expires: String,
27+
user: DiscordUser,
28+
}
29+
30+
#[derive(Clone)]
31+
struct CachedToken {
32+
user: DiscordUser,
33+
expires_at: Instant,
34+
}
35+
36+
#[derive(Clone)]
37+
struct CachedGuildPermissions {
38+
/// Guild IDs where the user has MANAGE_SERVER permission
39+
editable_guild_ids: Vec<i64>,
40+
expires_at: Instant,
41+
}
542

643
// Test mode flag - set to true during tests
744
static TEST_MODE: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
@@ -22,6 +59,8 @@ pub fn disable_test_mode() {
2259
// Clear test admin IDs
2360
let null_ptr = std::ptr::null_mut();
2461
TEST_ADMIN_IDS.store(null_ptr, std::sync::atomic::Ordering::Relaxed);
62+
// Clear the token cache when leaving test mode
63+
clear_token_cache();
2564
}
2665

2766
/// Set test admin IDs for testing
@@ -32,7 +71,7 @@ pub fn set_test_admin_ids(admin_ids: String) {
3271
TEST_ADMIN_IDS.store(ptr, std::sync::atomic::Ordering::Relaxed);
3372
}
3473

35-
#[derive(Debug, Serialize, Deserialize)]
74+
#[derive(Debug, Clone, Serialize, Deserialize)]
3675
pub struct DiscordUser {
3776
pub id: String,
3877
pub username: String,
@@ -81,16 +120,35 @@ impl<'r> FromRequest<'r> for DiscordAuth {
81120
}
82121

83122
async fn verify_discord_token(token: &str) -> Result<DiscordUser, DiscordAuthError> {
123+
// Check cache first (works in both test and production mode)
124+
if let Some(user) = get_cached_token(token) {
125+
return Ok(user);
126+
}
127+
84128
// Check if we're in test mode
85129
let test_mode = TEST_MODE.load(std::sync::atomic::Ordering::Relaxed);
130+
86131
if test_mode {
87-
return verify_discord_token_test(token);
132+
let user = verify_discord_token_test(token)?;
133+
cache_token_with_expiry(token, &user, Instant::now() + TEST_TOKEN_TTL);
134+
return Ok(user);
88135
}
89136

137+
// Production: call /oauth2/@me which returns user data + real token expiry
138+
let (user, expires_at) = verify_discord_token_prod(token).await?;
139+
cache_token_with_expiry(token, &user, expires_at);
140+
Ok(user)
141+
}
142+
143+
/// Calls Discord's /oauth2/@me to verify the token and retrieve both the user
144+
/// data and the token's actual expiry timestamp.
145+
async fn verify_discord_token_prod(
146+
token: &str,
147+
) -> Result<(DiscordUser, Instant), DiscordAuthError> {
90148
let client = reqwest::Client::new();
91149

92150
let response = client
93-
.get("https://discord.com/api/v10/users/@me")
151+
.get("https://discord.com/api/v10/oauth2/@me")
94152
.header("Authorization", token)
95153
.send()
96154
.await;
@@ -100,8 +158,11 @@ async fn verify_discord_token(token: &str) -> Result<DiscordUser, DiscordAuthErr
100158
if resp.status().is_success() {
101159
let response_text = resp.text().await.unwrap_or_default();
102160

103-
match serde_json::from_str::<DiscordUser>(&response_text) {
104-
Ok(user) => Ok(user),
161+
match serde_json::from_str::<OAuth2MeResponse>(&response_text) {
162+
Ok(oauth_resp) => {
163+
let expires_at = parse_discord_expiry(&oauth_resp.expires);
164+
Ok((oauth_resp.user, expires_at))
165+
}
105166
Err(_e) => Err(DiscordAuthError::DiscordApiError),
106167
}
107168
} else {
@@ -112,6 +173,105 @@ async fn verify_discord_token(token: &str) -> Result<DiscordUser, DiscordAuthErr
112173
}
113174
}
114175

176+
/// Convert Discord's ISO-8601 `expires` string into a `std::time::Instant`.
177+
fn parse_discord_expiry(expires: &str) -> Instant {
178+
if let Ok(expiry_dt) = expires.parse::<DateTime<Utc>>() {
179+
let remaining = expiry_dt
180+
.signed_duration_since(Utc::now())
181+
.to_std()
182+
.unwrap_or(Duration::ZERO);
183+
Instant::now() + remaining
184+
} else {
185+
// If parsing fails, fall back to a short TTL so we re-verify soon
186+
Instant::now() + Duration::from_secs(60)
187+
}
188+
}
189+
190+
// ===== Token Cache Helpers =====
191+
192+
/// Look up a token in the cache, returning the user if it's present and not expired
193+
fn get_cached_token(token: &str) -> Option<DiscordUser> {
194+
let cache = TOKEN_CACHE.read().ok()?;
195+
if let Some(cached) = cache.get(token) {
196+
if cached.expires_at > Instant::now() {
197+
return Some(cached.user.clone());
198+
}
199+
}
200+
None
201+
}
202+
203+
/// Store a verified token with an explicit expiry instant
204+
fn cache_token_with_expiry(token: &str, user: &DiscordUser, expires_at: Instant) {
205+
if let Ok(mut cache) = TOKEN_CACHE.write() {
206+
cache.insert(
207+
token.to_string(),
208+
CachedToken {
209+
user: user.clone(),
210+
expires_at,
211+
},
212+
);
213+
}
214+
}
215+
216+
/// Public helper that caches with the test-mode fallback TTL (used by tests)
217+
#[allow(dead_code)]
218+
pub fn cache_token(token: &str, user: &DiscordUser) {
219+
cache_token_with_expiry(token, user, Instant::now() + TEST_TOKEN_TTL);
220+
}
221+
222+
/// Remove a specific token from the cache (used by the logout endpoint)
223+
pub fn invalidate_token(token: &str) {
224+
if let Ok(mut cache) = TOKEN_CACHE.write() {
225+
cache.remove(token);
226+
}
227+
if let Ok(mut cache) = GUILD_PERM_CACHE.write() {
228+
cache.remove(token);
229+
}
230+
}
231+
232+
/// Clear all tokens and guild permissions from the cache
233+
#[allow(dead_code)]
234+
pub fn clear_token_cache() {
235+
if let Ok(mut cache) = TOKEN_CACHE.write() {
236+
cache.clear();
237+
}
238+
if let Ok(mut cache) = GUILD_PERM_CACHE.write() {
239+
cache.clear();
240+
}
241+
}
242+
243+
/// Check if a token is currently cached (for testing)
244+
#[allow(dead_code)]
245+
pub fn is_token_cached(token: &str) -> bool {
246+
get_cached_token(token).is_some()
247+
}
248+
249+
// ===== Guild Permission Cache Helpers =====
250+
251+
/// Look up cached guild permissions for a token
252+
fn get_cached_guild_permissions(token: &str) -> Option<Vec<i64>> {
253+
let cache = GUILD_PERM_CACHE.read().ok()?;
254+
if let Some(cached) = cache.get(token) {
255+
if cached.expires_at > Instant::now() {
256+
return Some(cached.editable_guild_ids.clone());
257+
}
258+
}
259+
None
260+
}
261+
262+
/// Store guild permission data in the cache
263+
fn cache_guild_permissions(token: &str, editable_guild_ids: &[i64]) {
264+
if let Ok(mut cache) = GUILD_PERM_CACHE.write() {
265+
cache.insert(
266+
token.to_string(),
267+
CachedGuildPermissions {
268+
editable_guild_ids: editable_guild_ids.to_vec(),
269+
expires_at: Instant::now() + GUILD_PERM_TTL,
270+
},
271+
);
272+
}
273+
}
274+
115275
fn verify_discord_token_test(token: &str) -> Result<DiscordUser, DiscordAuthError> {
116276
// Remove "Bearer " prefix if present
117277
let token = if let Some(stripped) = token.strip_prefix("Bearer ") {
@@ -257,6 +417,15 @@ pub async fn get_editable_guilds(
257417
token: &str,
258418
guild_ids: &[i64],
259419
) -> Result<Vec<i64>, DiscordAuthError> {
420+
// Check guild permission cache first
421+
if let Some(cached_editable) = get_cached_guild_permissions(token) {
422+
let filtered: Vec<i64> = cached_editable
423+
.into_iter()
424+
.filter(|id| guild_ids.contains(id))
425+
.collect();
426+
return Ok(filtered);
427+
}
428+
260429
// Check if we're in test mode
261430
let test_mode = TEST_MODE.load(std::sync::atomic::Ordering::Relaxed);
262431
if test_mode {
@@ -274,15 +443,19 @@ pub async fn get_editable_guilds(
274443
let perms: u64 = parts[2].parse().unwrap_or(0);
275444
if test_user == user.id && (perms & MANAGE_SERVER) != 0 {
276445
if let Ok(guild_id) = test_guild.parse::<i64>() {
277-
if guild_ids.contains(&guild_id) {
278-
editable.push(guild_id);
279-
}
446+
editable.push(guild_id);
280447
}
281448
}
282449
}
283450
}
284451
}
285-
return Ok(editable);
452+
// Cache the full list of editable guilds, then filter for the requested ones
453+
cache_guild_permissions(token, &editable);
454+
let filtered: Vec<i64> = editable
455+
.into_iter()
456+
.filter(|id| guild_ids.contains(id))
457+
.collect();
458+
return Ok(filtered);
286459
}
287460

288461
let client = reqwest::Client::new();
@@ -300,18 +473,28 @@ pub async fn get_editable_guilds(
300473

301474
match serde_json::from_str::<Vec<DiscordGuildInfo>>(&response_text) {
302475
Ok(guilds) => {
303-
let mut editable = Vec::new();
304-
for guild in guilds {
305-
if let Ok(guild_id) = guild.id.parse::<i64>() {
306-
if guild_ids.contains(&guild_id) {
307-
let perms: u64 = guild.permissions.parse().unwrap_or(0);
308-
if (perms & MANAGE_SERVER) != 0 {
309-
editable.push(guild_id);
310-
}
476+
// Collect ALL guilds the user can manage, then cache them
477+
let all_editable: Vec<i64> = guilds
478+
.iter()
479+
.filter_map(|guild| {
480+
let guild_id = guild.id.parse::<i64>().ok()?;
481+
let perms: u64 = guild.permissions.parse().unwrap_or(0);
482+
if (perms & MANAGE_SERVER) != 0 {
483+
Some(guild_id)
484+
} else {
485+
None
311486
}
312-
}
313-
}
314-
Ok(editable)
487+
})
488+
.collect();
489+
490+
cache_guild_permissions(token, &all_editable);
491+
492+
// Return only the requested guild IDs
493+
let filtered: Vec<i64> = all_editable
494+
.into_iter()
495+
.filter(|id| guild_ids.contains(id))
496+
.collect();
497+
Ok(filtered)
315498
}
316499
Err(_e) => Err(DiscordAuthError::DiscordApiError),
317500
}

api/src/routes/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod auth;
12
pub mod cards;
23
pub mod commands;
34
pub mod common;

0 commit comments

Comments
 (0)