8
8
Boolean ,
9
9
Column ,
10
10
DateTime ,
11
+ Integer ,
11
12
PrimaryKeyConstraint ,
12
13
String ,
13
14
Table ,
14
15
and_ ,
15
16
not_ ,
17
+ update ,
16
18
)
17
19
from sqlalchemy .sql import select
18
20
19
21
from cloudbot import hook
20
22
from cloudbot .event import EventType
23
+ from cloudbot .hook import Priority
21
24
from cloudbot .util import database , timeformat , web
22
25
from cloudbot .util .formatting import gen_markdown_table
23
26
24
- table = Table (
25
- "tells" ,
26
- database .metadata ,
27
- Column ("connection" , String ),
28
- Column ("sender" , String ),
29
- Column ("target" , String ),
30
- Column ("message" , String ),
31
- Column ("is_read" , Boolean ),
32
- Column ("time_sent" , DateTime ),
33
- Column ("time_read" , DateTime ),
34
- )
27
+
28
+ class TellMessage (database .Base ):
29
+ __tablename__ = "tell_messages"
30
+
31
+ msg_id = Column (Integer , primary_key = True , autoincrement = True )
32
+ conn = Column (String , index = True )
33
+ sender = Column (String )
34
+ target = Column (String , index = True )
35
+ message = Column (String )
36
+ is_read = Column (Boolean , default = False , index = True )
37
+ time_sent = Column (DateTime )
38
+ time_read = Column (DateTime )
39
+
40
+ def format_for_message (self ):
41
+ reltime = timeformat .time_since (self .time_sent )
42
+ return f"{ self .sender } sent you a message { reltime } ago: { self .message } "
43
+
44
+ def mark_read (self , now = None ):
45
+ if now is None :
46
+ now = datetime .now ()
47
+
48
+ self .is_read = True
49
+ self .time_read = now
50
+
35
51
36
52
disable_table = Table (
37
53
"tell_ignores" ,
60
76
tell_cache : List [Tuple [str , str ]] = []
61
77
62
78
79
+ @hook .on_start (priority = Priority .HIGHEST )
80
+ def migrate_tables (db ):
81
+ inspector = sa .inspect (db .bind )
82
+ if not inspector .has_table ("tells" ):
83
+ return
84
+
85
+ table = sa .Table (
86
+ "tells" ,
87
+ database .metadata ,
88
+ autoload_with = db .bind ,
89
+ )
90
+
91
+ if (
92
+ inspector .has_table (TellMessage .__tablename__ )
93
+ and db .query (TellMessage ).count () > 0
94
+ ):
95
+ raise Exception (
96
+ f"Can't migrate table { table .name } to { TellMessage .__tablename__ } , destination already exists"
97
+ )
98
+
99
+ data = [dict (row ) for row in db .execute (table .select ())]
100
+ for item in data :
101
+ item ["conn" ] = item .pop ("connection" )
102
+
103
+ db .bulk_insert_mappings (TellMessage , data , return_defaults = True )
104
+ db .commit ()
105
+
106
+ table .drop (db .bind )
107
+
108
+
63
109
@hook .on_start ()
64
110
def load_cache (db ):
65
111
new_cache = []
66
- for row in db .execute (table .select ().where (not_ (table .c .is_read ))):
67
- conn = row ["connection" ]
68
- target = row ["target" ]
112
+ for conn , target in db .execute (
113
+ select (
114
+ [TellMessage .conn , TellMessage .target ], not_ (TellMessage .is_read )
115
+ )
116
+ ):
69
117
new_cache .append ((conn , target ))
70
118
71
119
tell_cache .clear ()
@@ -183,48 +231,35 @@ def list_ignores(conn, nick):
183
231
yield mask
184
232
185
233
186
- def get_unread (db , server , target ):
234
+ def get_unread (db , server , target ) -> List [ TellMessage ] :
187
235
query = (
188
- select ([ table . c . sender , table . c . message , table . c . time_sent ] )
189
- .where (table . c . connection == server . lower ( ))
190
- .where (table . c . target == target . lower () )
191
- .where (not_ ( table . c . is_read ))
192
- .order_by (table . c .time_sent )
236
+ select (TellMessage )
237
+ .where (not_ ( TellMessage . is_read ))
238
+ .where (TellMessage . conn == server )
239
+ .where (TellMessage . target == target . lower ( ))
240
+ .order_by (TellMessage .time_sent )
193
241
)
194
- return db .execute (query ).fetchall ()
242
+
243
+ return db .execute (query ).scalars ().all ()
195
244
196
245
197
246
def count_unread (db , server , target ):
198
247
query = (
199
- select ([sa .func .count ()])
200
- .select_from (table )
201
- .where (table .c .connection == server .lower ())
202
- .where (table .c .target == target .lower ())
203
- .where (not_ (table .c .is_read ))
248
+ select (sa .func .count (TellMessage .msg_id ))
249
+ .where (TellMessage .conn == server .lower ())
250
+ .where (TellMessage .target == target .lower ())
251
+ .where (not_ (TellMessage .is_read ))
204
252
)
205
253
206
254
return db .execute (query ).fetchone ()[0 ]
207
255
208
256
209
257
def read_all_tells (db , server , target ):
210
258
query = (
211
- table .update ()
212
- .where (table .c .connection == server .lower ())
213
- .where (table .c .target == target .lower ())
214
- .where (not_ (table .c .is_read ))
215
- .values (is_read = True )
216
- )
217
- db .execute (query )
218
- db .commit ()
219
- load_cache (db )
220
-
221
-
222
- def read_tell (db , server , target , message ):
223
- query = (
224
- table .update ()
225
- .where (table .c .connection == server .lower ())
226
- .where (table .c .target == target .lower ())
227
- .where (table .c .message == message )
259
+ update (TellMessage )
260
+ .where (TellMessage .conn == server .lower ())
261
+ .where (TellMessage .target == target .lower ())
262
+ .where (TellMessage .is_read .is_ (False ))
228
263
.values (is_read = True )
229
264
)
230
265
db .execute (query )
@@ -233,15 +268,14 @@ def read_tell(db, server, target, message):
233
268
234
269
235
270
def add_tell (db , server , sender , target , message ):
236
- query = table . insert (). values (
237
- connection = server .lower (),
271
+ new_tell = TellMessage (
272
+ conn = server .lower (),
238
273
sender = sender .lower (),
239
274
target = target .lower (),
240
275
message = message ,
241
- is_read = False ,
242
- time_sent = datetime .today (),
276
+ time_sent = datetime .now (),
243
277
)
244
- db .execute ( query )
278
+ db .add ( new_tell )
245
279
db .commit ()
246
280
load_cache (db )
247
281
@@ -267,20 +301,18 @@ def tellinput(conn, db, nick, notice, content):
267
301
if not tells :
268
302
return
269
303
270
- user_from , message , time_sent = tells [0 ]
271
- reltime = timeformat .time_since (time_sent )
272
- reply = "{} sent you a message {} ago: {}" .format (
273
- user_from , reltime , message
274
- )
304
+ first_tell = tells [0 ]
305
+ reply = first_tell .format_for_message ()
275
306
276
307
if len (tells ) > 1 :
277
308
reply += " (+{} more, {}showtells to view)" .format (
278
309
len (tells ) - 1 , conn .config ["command_prefix" ][0 ]
279
310
)
280
311
281
- read_tell (db , conn .name , nick , message )
282
312
notice (reply )
283
313
314
+ first_tell .mark_read ()
315
+
284
316
285
317
@hook .command (autohelp = False )
286
318
def showtells (nick , notice , db , conn ):
@@ -293,9 +325,7 @@ def showtells(nick, notice, db, conn):
293
325
return
294
326
295
327
for tell in tells :
296
- sender , message , time_sent = tell
297
- past = timeformat .time_since (time_sent )
298
- notice ("{} sent you a message {} ago: {}" .format (sender , past , message ))
328
+ notice (tell .format_for_message ())
299
329
300
330
read_all_tells (db , conn .name , nick )
301
331
0 commit comments