@@ -5,6 +5,9 @@ use std::env;
55
66// Test mode flag - set to true during tests
77static TEST_MODE : std:: sync:: atomic:: AtomicBool = std:: sync:: atomic:: AtomicBool :: new ( false ) ;
8+ // Test admin IDs - set during tests
9+ static TEST_ADMIN_IDS : std:: sync:: atomic:: AtomicPtr < String > =
10+ std:: sync:: atomic:: AtomicPtr :: new ( std:: ptr:: null_mut ( ) ) ;
811
912/// Enable test mode for Discord authentication
1013#[ allow( dead_code) ]
@@ -16,6 +19,17 @@ pub fn enable_test_mode() {
1619#[ allow( dead_code) ]
1720pub fn disable_test_mode ( ) {
1821 TEST_MODE . store ( false , std:: sync:: atomic:: Ordering :: Relaxed ) ;
22+ // Clear test admin IDs
23+ let null_ptr = std:: ptr:: null_mut ( ) ;
24+ TEST_ADMIN_IDS . store ( null_ptr, std:: sync:: atomic:: Ordering :: Relaxed ) ;
25+ }
26+
27+ /// Set test admin IDs for testing
28+ #[ allow( dead_code) ]
29+ pub fn set_test_admin_ids ( admin_ids : String ) {
30+ let boxed_string = Box :: new ( admin_ids) ;
31+ let ptr = Box :: into_raw ( boxed_string) ;
32+ TEST_ADMIN_IDS . store ( ptr, std:: sync:: atomic:: Ordering :: Relaxed ) ;
1933}
2034
2135#[ derive( Debug , Serialize , Deserialize ) ]
@@ -129,6 +143,19 @@ fn verify_discord_token_test(token: &str) -> Result<DiscordUser, DiscordAuthErro
129143impl DiscordAuth {
130144 /// Check if the authenticated user is an admin
131145 pub fn is_admin ( & self ) -> bool {
146+ let test_mode = TEST_MODE . load ( std:: sync:: atomic:: Ordering :: Relaxed ) ;
147+
148+ if test_mode {
149+ // Use test admin IDs
150+ let test_admin_ids_ptr = TEST_ADMIN_IDS . load ( std:: sync:: atomic:: Ordering :: Relaxed ) ;
151+ if !test_admin_ids_ptr. is_null ( ) {
152+ let test_admin_ids = unsafe { & * test_admin_ids_ptr } ;
153+ let admin_ids: Vec < & str > = test_admin_ids. split ( ',' ) . collect ( ) ;
154+ return admin_ids. contains ( & self . 0 . id . as_str ( ) ) ;
155+ }
156+ }
157+
158+ // Use environment variable
132159 let admin_ids = env:: var ( "ADMIN_IDS" ) . unwrap_or_default ( ) ;
133160 let admin_ids: Vec < & str > = admin_ids. split ( ',' ) . collect ( ) ;
134161 admin_ids. contains ( & self . 0 . id . as_str ( ) )
0 commit comments