@@ -12,6 +12,7 @@ class Database(object):
1212
1313 _instance = None
1414 _initialized = False
15+ _banned_users = set ()
1516
1617 def __new__ (cls ):
1718 if not Database ._instance :
@@ -37,6 +38,8 @@ def __init__(self):
3738 self .connection .text_factory = lambda x : str (x , 'utf-8' , "ignore" )
3839 self .cursor = self .connection .cursor ()
3940
41+ self .load_banned_users ()
42+
4043 self ._initialized = True
4144
4245 @staticmethod
@@ -77,6 +80,22 @@ def create_database(database_path):
7780 connection .commit ()
7881 connection .close ()
7982
83+ def load_banned_users (self ):
84+ """Loads all banned users from the database into a list"""
85+ self .cursor .execute ("SELECT user_id FROM users WHERE banned=1;" )
86+ result = self .cursor .fetchall ()
87+
88+ if not result :
89+ return
90+
91+ for row in result :
92+ print (int (row ["user_id" ]))
93+ self ._banned_users .add (int (row ["user_id" ]))
94+
95+ def get_banned_users (self ):
96+ """Returns a list of all banned user_ids"""
97+ return self ._banned_users
98+
8099 def get_user (self , user_id ):
81100 self .cursor .execute ("SELECT user_id, first_name, last_name, username, games_played, games_won, games_tie, last_played, banned"
82101 " FROM users WHERE user_id=?;" , [str (user_id )])
@@ -88,18 +107,21 @@ def get_user(self, user_id):
88107
89108 def is_user_banned (self , user_id ):
90109 """Checks if a user was banned by the admin of the bot from using it"""
91- user = self .get_user (user_id )
92- return user is not None and user [8 ] == 1
110+ # user = self.get_user(user_id)
111+ # return user is not None and user[8] == 1
112+ return int (user_id ) in self ._banned_users
93113
94114 def ban_user (self , user_id ):
95115 """Bans a user from using a the bot"""
96116 self .cursor .execute ("UPDATE users SET banned=1 WHERE user_id=?;" , [str (user_id )])
97117 self .connection .commit ()
118+ self ._banned_users .add (int (user_id ))
98119
99120 def unban_user (self , user_id ):
100121 """Unbans a user from using a the bot"""
101122 self .cursor .execute ("UPDATE users SET banned=0 WHERE user_id=?;" , [str (user_id )])
102123 self .connection .commit ()
124+ self ._banned_users .remove (int (user_id ))
103125
104126 def get_recent_players (self ):
105127 one_day_in_secs = 60 * 60 * 24
0 commit comments