Skip to content
This repository was archived by the owner on Oct 2, 2024. It is now read-only.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions ircbot/ircbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@
MAX_CLIENT_MSG = 435


def synchronize(method):
"""Decorator to wrap a method in a lock-acquiring context manager"""
@functools.wraps(method)
def new_method(self, *args, **kwargs):
with self.lock:
return method(self, *args, **kwargs)
return new_method


class Listener(NamedTuple):
pattern: Pattern
fn: FunctionType
Expand Down Expand Up @@ -135,6 +144,9 @@ def __init__(
self.plugins: Dict[str, ModuleType] = {}
self.extra_channels: Set[str] = set() # plugins can add stuff here

# As we use threads, we should ensure that we use them safely
self.lock = threading.RLock()

# Register plugins before joining the server.
self.register_plugins()

Expand All @@ -146,6 +158,7 @@ def __init__(
connect_factory=factory,
)

@synchronize
def register_plugins(self):
for importer, mod_name, _ in pkgutil.iter_modules(['ircbot/plugin']):
mod = importer.find_module(mod_name).load_module(mod_name)
Expand All @@ -154,6 +167,7 @@ def register_plugins(self):
if register is not None:
register(self)

@synchronize
def handle_error(self, error_message):
# for debugging purposes
print(error_message)
Expand All @@ -162,6 +176,7 @@ def handle_error(self, error_message):
if not TESTING:
send_problem_report(error_message)

@synchronize
def listen(
self,
pattern,
Expand All @@ -183,13 +198,15 @@ def listen(
),
)

@synchronize
def on_welcome(self, conn, _):
conn.privmsg('NickServ', f'identify {self.nickserv_password}')

# Join the "main" IRC channels.
for channel in IRC_CHANNELS_OPER | IRC_CHANNELS_ANNOUNCE | self.extra_channels:
conn.join(channel)

@synchronize
def on_pubmsg(self, conn, event):
if event.target in self.channels:
is_oper = False
Expand Down Expand Up @@ -286,10 +303,12 @@ def respond(raw_text, ping=True):
if raw_text[0] != '!':
self.recent_messages[event.target].appendleft((user, raw_text))

@synchronize
def on_currenttopic(self, connection, event):
channel, topic = event.arguments
self.topics[channel] = topic

@synchronize
def on_topic(self, connection, event):
topic, = event.arguments
self.topics[event.target] = topic
Expand All @@ -299,6 +318,7 @@ def on_invite(self, connection, event):
import ircbot.plugin.channels
return ircbot.plugin.channels.on_invite(self, connection, event)

@synchronize
def add_thread(self, func):
def thread_func():
try:
Expand Down Expand Up @@ -330,6 +350,7 @@ def thread_func():
thread = threading.Thread(target=thread_func, daemon=True)
thread.start()

@synchronize
def bump_topic(self):
for channel, topic in self.topics.items():
def plusone(m):
Expand All @@ -339,6 +360,7 @@ def plusone(m):
if topic != new_topic:
self.connection.topic(channel, new_topic=new_topic)

@synchronize
def say(self, channel, message):
# Find the length of the full message
msg_len = len(f'PRIVMSG {channel} :{message}\r\n'.encode('utf-8'))
Expand Down