Skip to content

Commit 19471a7

Browse files
committed
Fix .tell message collisions
This adds unique IDs to each tell message, to allow for more "proper" ORM access to each message, for marking them read and such
1 parent 3e2c1fb commit 19471a7

File tree

2 files changed

+159
-60
lines changed

2 files changed

+159
-60
lines changed

plugins/tell.py

Lines changed: 85 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,46 @@
88
Boolean,
99
Column,
1010
DateTime,
11+
Integer,
1112
PrimaryKeyConstraint,
1213
String,
1314
Table,
1415
and_,
1516
not_,
17+
update,
1618
)
1719
from sqlalchemy.sql import select
1820

1921
from cloudbot import hook
2022
from cloudbot.event import EventType
23+
from cloudbot.hook import Priority
2124
from cloudbot.util import database, timeformat, web
2225
from cloudbot.util.formatting import gen_markdown_table
2326

24-
table = Table(
25-
"tells",
26-
database.metadata,
27-
Column("connection", String),
28-
Column("sender", String),
29-
Column("target", String),
30-
Column("message", String),
31-
Column("is_read", Boolean),
32-
Column("time_sent", DateTime),
33-
Column("time_read", DateTime),
34-
)
27+
28+
class TellMessage(database.Base):
29+
__tablename__ = "tell_messages"
30+
31+
msg_id = Column(Integer, primary_key=True, autoincrement=True)
32+
conn = Column(String, index=True)
33+
sender = Column(String)
34+
target = Column(String, index=True)
35+
message = Column(String)
36+
is_read = Column(Boolean, default=False, index=True)
37+
time_sent = Column(DateTime)
38+
time_read = Column(DateTime)
39+
40+
def format_for_message(self):
41+
reltime = timeformat.time_since(self.time_sent)
42+
return f"{self.sender} sent you a message {reltime} ago: {self.message}"
43+
44+
def mark_read(self, now=None):
45+
if now is None:
46+
now = datetime.now()
47+
48+
self.is_read = True
49+
self.time_read = now
50+
3551

3652
disable_table = Table(
3753
"tell_ignores",
@@ -60,12 +76,42 @@
6076
tell_cache: List[Tuple[str, str]] = []
6177

6278

79+
@hook.on_start(priority=Priority.HIGHEST)
80+
def migrate_tables(db):
81+
table = Table(
82+
"tells",
83+
database.metadata,
84+
Column("connection", String),
85+
Column("sender", String),
86+
Column("target", String),
87+
Column("message", String),
88+
Column("is_read", Boolean),
89+
Column("time_sent", DateTime),
90+
Column("time_read", DateTime),
91+
)
92+
93+
if not table.exists(db.bind):
94+
return
95+
96+
if TellMessage.__table__.exists(db.bin):
97+
raise Exception(
98+
f"Can't migrate table {table.name} to {TellMessage.__table__.name}, destination already exists"
99+
)
100+
101+
data = db.execute(table.select())
102+
db.bulk_insert_mappings(TellMessage, data, return_defaults=True)
103+
db.commit()
104+
table.drop(db.bind)
105+
106+
63107
@hook.on_start()
64108
def load_cache(db):
65109
new_cache = []
66-
for row in db.execute(table.select().where(not_(table.c.is_read))):
67-
conn = row["connection"]
68-
target = row["target"]
110+
for conn, target in db.execute(
111+
select(
112+
[TellMessage.conn, TellMessage.target], not_(TellMessage.is_read)
113+
)
114+
):
69115
new_cache.append((conn, target))
70116

71117
tell_cache.clear()
@@ -183,48 +229,35 @@ def list_ignores(conn, nick):
183229
yield mask
184230

185231

186-
def get_unread(db, server, target):
232+
def get_unread(db, server, target) -> List[TellMessage]:
187233
query = (
188-
select([table.c.sender, table.c.message, table.c.time_sent])
189-
.where(table.c.connection == server.lower())
190-
.where(table.c.target == target.lower())
191-
.where(not_(table.c.is_read))
192-
.order_by(table.c.time_sent)
234+
select(TellMessage)
235+
.where(not_(TellMessage.is_read))
236+
.where(TellMessage.conn == server)
237+
.where(TellMessage.target == target.lower())
238+
.order_by(TellMessage.time_sent)
193239
)
194-
return db.execute(query).fetchall()
240+
241+
return db.execute(query).scalars().all()
195242

196243

197244
def count_unread(db, server, target):
198245
query = (
199-
select([sa.func.count()])
200-
.select_from(table)
201-
.where(table.c.connection == server.lower())
202-
.where(table.c.target == target.lower())
203-
.where(not_(table.c.is_read))
246+
select(sa.func.count(TellMessage.msg_id))
247+
.where(TellMessage.conn == server.lower())
248+
.where(TellMessage.target == target.lower())
249+
.where(not_(TellMessage.is_read))
204250
)
205251

206252
return db.execute(query).fetchone()[0]
207253

208254

209255
def read_all_tells(db, server, target):
210256
query = (
211-
table.update()
212-
.where(table.c.connection == server.lower())
213-
.where(table.c.target == target.lower())
214-
.where(not_(table.c.is_read))
215-
.values(is_read=True)
216-
)
217-
db.execute(query)
218-
db.commit()
219-
load_cache(db)
220-
221-
222-
def read_tell(db, server, target, message):
223-
query = (
224-
table.update()
225-
.where(table.c.connection == server.lower())
226-
.where(table.c.target == target.lower())
227-
.where(table.c.message == message)
257+
update(TellMessage)
258+
.where(TellMessage.conn == server.lower())
259+
.where(TellMessage.target == target.lower())
260+
.where(TellMessage.is_read.is_(False))
228261
.values(is_read=True)
229262
)
230263
db.execute(query)
@@ -233,15 +266,14 @@ def read_tell(db, server, target, message):
233266

234267

235268
def add_tell(db, server, sender, target, message):
236-
query = table.insert().values(
237-
connection=server.lower(),
269+
new_tell = TellMessage(
270+
conn=server.lower(),
238271
sender=sender.lower(),
239272
target=target.lower(),
240273
message=message,
241-
is_read=False,
242-
time_sent=datetime.today(),
274+
time_sent=datetime.now(),
243275
)
244-
db.execute(query)
276+
db.add(new_tell)
245277
db.commit()
246278
load_cache(db)
247279

@@ -267,20 +299,18 @@ def tellinput(conn, db, nick, notice, content):
267299
if not tells:
268300
return
269301

270-
user_from, message, time_sent = tells[0]
271-
reltime = timeformat.time_since(time_sent)
272-
reply = "{} sent you a message {} ago: {}".format(
273-
user_from, reltime, message
274-
)
302+
first_tell = tells[0]
303+
reply = first_tell.format_for_message()
275304

276305
if len(tells) > 1:
277306
reply += " (+{} more, {}showtells to view)".format(
278307
len(tells) - 1, conn.config["command_prefix"][0]
279308
)
280309

281-
read_tell(db, conn.name, nick, message)
282310
notice(reply)
283311

312+
first_tell.mark_read()
313+
284314

285315
@hook.command(autohelp=False)
286316
def showtells(nick, notice, db, conn):
@@ -293,9 +323,7 @@ def showtells(nick, notice, db, conn):
293323
return
294324

295325
for tell in tells:
296-
sender, message, time_sent = tell
297-
past = timeformat.time_since(time_sent)
298-
notice("{} sent you a message {} ago: {}".format(sender, past, message))
326+
notice(tell.format_for_message())
299327

300328
read_all_tells(db, conn.name, nick)
301329

tests/plugin_tests/test_tell.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
def init_tables(mock_db):
1111
db_engine = mock_db.engine
12-
tell.table.create(db_engine)
12+
tell.TellMessage.__table__.create(db_engine)
1313
tell.disable_table.create(db_engine)
1414
tell.ignore_table.create(db_engine)
1515
session = mock_db.session()
@@ -109,8 +109,9 @@ def test_showtells(mock_db, freeze_time):
109109
event = MagicMock()
110110
tell.add_tell(mock_db.session(), server, sender, target, message)
111111

112-
assert mock_db.get_data(tell.table) == [
112+
assert mock_db.get_data(tell.TellMessage.__table__) == [
113113
(
114+
1,
114115
"testconn",
115116
"foo",
116117
"other",
@@ -127,8 +128,9 @@ def test_showtells(mock_db, freeze_time):
127128
assert event.mock_calls == [
128129
call.notice("foo sent you a message 60 seconds ago: bar")
129130
]
130-
assert mock_db.get_data(tell.table) == [
131+
assert mock_db.get_data(tell.TellMessage.__table__) == [
131132
(
133+
1,
132134
"testconn",
133135
"foo",
134136
"other",
@@ -206,6 +208,75 @@ def test_tellinput(mock_db, freeze_time):
206208
]
207209

208210

211+
def test_read_tell_spam(mock_db, freeze_time):
212+
init_tables(mock_db)
213+
db = mock_db.session()
214+
conn = MockConn()
215+
conn.config["command_prefix"] = "."
216+
sender = "foo"
217+
nick = "other"
218+
message = "bar"
219+
message2 = "baraa"
220+
event = MagicMock()
221+
content = "aa"
222+
tell.add_tell(db, conn.name.lower(), sender, nick, message)
223+
freeze_time.tick()
224+
tell.add_tell(db, conn.name.lower(), sender, nick, message)
225+
freeze_time.tick()
226+
tell.add_tell(db, conn.name.lower(), sender, nick, message)
227+
freeze_time.tick()
228+
tell.add_tell(db, conn.name.lower(), sender, nick, message2)
229+
assert mock_db.get_data(tell.TellMessage.__table__) == [
230+
(
231+
1,
232+
"testconn",
233+
sender,
234+
nick,
235+
message,
236+
False,
237+
datetime.datetime(2019, 8, 22, 13, 14, 36),
238+
None,
239+
),
240+
(
241+
2,
242+
"testconn",
243+
sender,
244+
nick,
245+
message,
246+
False,
247+
datetime.datetime(2019, 8, 22, 13, 14, 37),
248+
None,
249+
),
250+
(
251+
3,
252+
"testconn",
253+
sender,
254+
nick,
255+
message,
256+
False,
257+
datetime.datetime(2019, 8, 22, 13, 14, 38),
258+
None,
259+
),
260+
(
261+
4,
262+
"testconn",
263+
sender,
264+
nick,
265+
message2,
266+
False,
267+
datetime.datetime(2019, 8, 22, 13, 14, 39),
268+
None,
269+
),
270+
]
271+
res = tell.tellinput(conn, db, nick, event.notice, content)
272+
assert res is None
273+
assert event.mock_calls == [
274+
call.notice(
275+
"foo sent you a message 3 seconds ago: bar (+3 more, .showtells to view)"
276+
)
277+
]
278+
279+
209280
def test_tellinput_multiple(mock_db, freeze_time):
210281
init_tables(mock_db)
211282
db = mock_db.session()

0 commit comments

Comments
 (0)