Skip to content

Commit e08663a

Browse files
authored
Merge pull request #677 from TotallyNotRobots/optout-tests
Add coverage for db access in optout.py
2 parents b2f6453 + 0d7f473 commit e08663a

File tree

6 files changed

+492
-30
lines changed

6 files changed

+492
-30
lines changed

plugins/core/optout.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections import defaultdict
66
from functools import total_ordering
77
from threading import RLock
8+
from typing import List, MutableMapping, Optional
89

910
from irclib.util.compare import match_mask
1011
from sqlalchemy import (
@@ -33,7 +34,7 @@
3334
PrimaryKeyConstraint("network", "chan", "hook"),
3435
)
3536

36-
optout_cache = DefaultKeyFoldDict(list)
37+
optout_cache: MutableMapping[str, List["OptOut"]] = DefaultKeyFoldDict(list)
3738

3839
cache_lock = RLock()
3940

@@ -47,16 +48,18 @@ def __init__(self, channel, hook_pattern, allow):
4748

4849
def __lt__(self, other):
4950
if isinstance(other, OptOut):
50-
diff = len(self.channel) - len(other.channel)
51-
if diff:
52-
return diff < 0
53-
54-
return len(self.hook) < len(other.hook)
51+
return (self.channel.rstrip("*"), self.hook.rstrip("*")) < (
52+
other.channel.rstrip("*"),
53+
other.hook.rstrip("*"),
54+
)
5555

5656
return NotImplemented
5757

58-
def __str__(self):
59-
return f"{self.channel} {self.hook} {self.allow}"
58+
def __eq__(self, other):
59+
if isinstance(other, OptOut):
60+
return self.channel == other.channel and self.hook == other.hook
61+
62+
return NotImplemented
6063

6164
def __repr__(self):
6265
return "{}({}, {}, {})".format(
@@ -82,12 +85,12 @@ async def check_channel_permissions(event, chan, *perms):
8285
return allowed
8386

8487

85-
def get_conn_optouts(conn_name):
88+
def get_conn_optouts(conn_name) -> List[OptOut]:
8689
with cache_lock:
8790
return optout_cache[conn_name.casefold()]
8891

8992

90-
def get_channel_optouts(conn_name, chan=None):
93+
def get_channel_optouts(conn_name, chan=None) -> List[OptOut]:
9194
with cache_lock:
9295
return [
9396
opt
@@ -96,6 +99,14 @@ def get_channel_optouts(conn_name, chan=None):
9699
]
97100

98101

102+
def get_first_matching_optout(conn_name, chan, hook_name) -> Optional[OptOut]:
103+
for optout in get_conn_optouts(conn_name):
104+
if optout.match(chan, hook_name):
105+
return optout
106+
107+
return None
108+
109+
99110
def format_optout_list(opts):
100111
headers = ("Channel Pattern", "Hook Pattern", "Allowed")
101112
table = [
@@ -186,19 +197,12 @@ def optout_sieve(bot, event, _hook):
186197
return event
187198

188199
hook_name = _hook.plugin.title + "." + _hook.function_name
189-
with cache_lock:
190-
optouts = get_conn_optouts(event.conn.name)
191-
for _optout in optouts:
192-
if _optout.match(event.chan, hook_name):
193-
if not _optout.allow:
194-
if _hook.type == "command":
195-
event.notice(
196-
"Sorry, that command is disabled in this channel."
197-
)
198-
199-
return None
200-
201-
break
200+
_optout = get_first_matching_optout(event.conn.name, event.chan, hook_name)
201+
if _optout and not _optout.allow:
202+
if _hook.type == "command":
203+
event.notice("Sorry, that command is disabled in this channel.")
204+
205+
return None
202206

203207
return event
204208

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ filterwarnings =
1212
error
1313
ignore:pkg_resources is deprecated as an API:DeprecationWarning
1414
ignore:datetime.*:DeprecationWarning:sqlalchemy.*
15+
asyncio_mode = auto

tests/core_tests/test_plugin_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from cloudbot.util import database
1515
from tests.util.mock_module import MockModule
1616

17+
1718
@pytest.fixture()
1819
def mock_bot(mock_bot_factory, event_loop, tmp_path):
1920
tmp_base = tmp_path / "tmp"

tests/plugin_tests/test_chan_log.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def test_format_exception_chain():
77
def _get_data(exc):
88
yield repr(exc)
9-
if hasattr(exc, 'add_note'):
9+
if hasattr(exc, "add_note"):
1010
yield f" add_note = {exc.add_note!r}"
1111

1212
yield f" args = {exc.args!r}"

0 commit comments

Comments
 (0)