33import os
44import sqlite3
55from time import time
6+
67from util import Cache
78
89
@@ -11,29 +12,35 @@ class Database(object):
1112
1213 _instance = None
1314 _initialized = False
15+ _banned_users = set ()
1416
1517 def __new__ (cls ):
1618 if not Database ._instance :
1719 Database ._instance = super (Database , cls ).__new__ (cls )
1820 return Database ._instance
1921
2022 def __init__ (self ):
21- if not self ._initialized :
22- database_path = os .path .join (self .dir_path , "users.db" )
23- self .logger = logging .getLogger (__name__ )
23+ if self ._initialized :
24+ return
25+
26+ database_path = os .path .join (self .dir_path , "users.db" )
27+ self .logger = logging .getLogger (__name__ )
28+
29+ if not os .path .exists (database_path ):
30+ self .logger .debug ("File '{}' does not exist! Trying to create one." .format (database_path ))
31+ try :
32+ self .create_database (database_path )
33+ except Exception :
34+ self .logger .error ("An error has occurred while creating the database!" )
2435
25- if not os .path .exists (database_path ):
26- self .logger .debug ("File '{}' does not exist! Trying to create one." .format (database_path ))
27- try :
28- self .create_database (database_path )
29- except Exception :
30- self .logger .error ("An error has occurred while creating the database!" )
36+ self .connection = sqlite3 .connect (database_path )
37+ self .connection .row_factory = sqlite3 .Row
38+ self .connection .text_factory = lambda x : str (x , 'utf-8' , "ignore" )
39+ self .cursor = self .connection .cursor ()
3140
32- self .connection = sqlite3 .connect (database_path )
33- self .connection .text_factory = lambda x : str (x , 'utf-8' , "ignore" )
34- self .cursor = self .connection .cursor ()
41+ self .load_banned_users ()
3542
36- self ._initialized = True
43+ self ._initialized = True
3744
3845 @staticmethod
3946 def create_database (database_path ):
@@ -63,6 +70,7 @@ def create_database(database_path):
6370 "'games_won' INTEGER DEFAULT 0,"
6471 "'games_tie' INTEGER DEFAULT 0,"
6572 "'last_played' INTEGER DEFAULT 0,"
73+ "'banned' INTEGER DEFAULT 0,"
6674 "PRIMARY KEY('user_id'));" )
6775
6876 cursor .execute ("CREATE TABLE IF NOT EXISTS 'chats'"
@@ -72,15 +80,49 @@ def create_database(database_path):
7280 connection .commit ()
7381 connection .close ()
7482
83+ def load_banned_users (self ):
84+ """Loads all banned users from the database into a list"""
85+ self .cursor .execute ("SELECT user_id FROM users WHERE banned=1;" )
86+ result = self .cursor .fetchall ()
87+
88+ if not result :
89+ return
90+
91+ for row in result :
92+ print (int (row ["user_id" ]))
93+ self ._banned_users .add (int (row ["user_id" ]))
94+
95+ def get_banned_users (self ):
96+ """Returns a list of all banned user_ids"""
97+ return self ._banned_users
98+
7599 def get_user (self , user_id ):
76- self .cursor .execute ("SELECT user_id, first_name, last_name, username, games_played, games_won, games_tie, last_played"
100+ self .cursor .execute ("SELECT user_id, first_name, last_name, username, games_played, games_won, games_tie, last_played, banned "
77101 " FROM users WHERE user_id=?;" , [str (user_id )])
78102
79103 result = self .cursor .fetchone ()
80104 if not result or len (result ) == 0 :
81105 return None
82106 return result
83107
108+ def is_user_banned (self , user_id ):
109+ """Checks if a user was banned by the admin of the bot from using it"""
110+ # user = self.get_user(user_id)
111+ # return user is not None and user[8] == 1
112+ return int (user_id ) in self ._banned_users
113+
114+ def ban_user (self , user_id ):
115+ """Bans a user from using a the bot"""
116+ self .cursor .execute ("UPDATE users SET banned=1 WHERE user_id=?;" , [str (user_id )])
117+ self .connection .commit ()
118+ self ._banned_users .add (int (user_id ))
119+
120+ def unban_user (self , user_id ):
121+ """Unbans a user from using a the bot"""
122+ self .cursor .execute ("UPDATE users SET banned=0 WHERE user_id=?;" , [str (user_id )])
123+ self .connection .commit ()
124+ self ._banned_users .remove (int (user_id ))
125+
84126 def get_recent_players (self ):
85127 one_day_in_secs = 60 * 60 * 24
86128 current_time = int (time ())
@@ -141,7 +183,7 @@ def add_user(self, user_id, lang_id, first_name, last_name, username):
141183
142184 def _add_user (self , user_id , lang_id , first_name , last_name , username ):
143185 try :
144- self .cursor .execute ("INSERT INTO users VALUES (?, ?, ?, ?, 0, 0, 0, 0);" , [str (user_id ), first_name , last_name , username ])
186+ self .cursor .execute ("INSERT INTO users VALUES (?, ?, ?, ?, 0, 0, 0, 0, 0 );" , [str (user_id ), first_name , last_name , username ])
145187 self .cursor .execute ("INSERT INTO chats VALUES (?, ?);" , [str (user_id ), lang_id ])
146188 self .connection .commit ()
147189 except sqlite3 .IntegrityError :
0 commit comments