Skip to content

Commit fa7ad46

Browse files
authored
V2 (#49)
* refactorings * fix logging * remove uneeded kwargs from ipc_client * handle disconnects gracefully * update docstrings * missed some docstrings * don't close loop when connection fails
1 parent 3c43fc8 commit fa7ad46

File tree

12 files changed

+334
-372
lines changed

12 files changed

+334
-372
lines changed

examples/broadcast_exec/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def on_message(self, event: GuildMessageCreateEvent) -> None:
4444
return
4545
if event.content.startswith("!exec"):
4646
await self.cluster.ipc.send_command(
47-
self.cluster.ipc.cluster_uids,
47+
self.cluster.ipc.clusters,
4848
"exec_code",
4949
{"code": event.content[6:]},
5050
)

hikari_clusters/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,19 @@
2828

2929
from importlib.metadata import version
3030

31+
from hikari.internal.ux import init_logging
32+
3133
from . import close_codes, commands, events, exceptions, payload
34+
from .base_client import BaseClient
3235
from .brain import Brain
3336
from .cluster import Cluster, ClusterLauncher
34-
from .info_classes import ClusterInfo, ServerInfo
37+
from .info_classes import BaseInfo, BrainInfo, ClusterInfo, ServerInfo
3538
from .ipc_client import IpcClient
3639
from .ipc_server import IpcServer
3740
from .server import Server
3841

42+
init_logging("INFO", True, False)
43+
3944
__version__ = version(__name__)
4045

4146
__all__ = (
@@ -45,8 +50,11 @@
4550
"Cluster",
4651
"ClusterLauncher",
4752
"Server",
53+
"BaseClient",
4854
"ClusterInfo",
4955
"ServerInfo",
56+
"BrainInfo",
57+
"BaseInfo",
5058
"payload",
5159
"events",
5260
"commands",

hikari_clusters/base_client.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import logging
5+
import pathlib
6+
7+
from websockets.exceptions import ConnectionClosed
8+
9+
from .info_classes import BaseInfo
10+
from .ipc_client import IpcClient
11+
from .task_manager import TaskManager
12+
13+
_LOG = logging.getLogger(__name__)
14+
15+
16+
class BaseClient:
17+
"""The base client, which contains an IpcClient.
18+
19+
Parameters
20+
----------
21+
ipc_uri : str
22+
The URI of the brain.
23+
token : str
24+
The token for the IPC server.
25+
reconnect : bool
26+
Whether to automatically reconnect if the connection
27+
is lost. Defaults to True.
28+
certificate_path : pathlib.Path | str | None
29+
The path to your certificate, which allos for secure
30+
connection over the IPC. Defaults to None.
31+
"""
32+
33+
def __init__(
34+
self,
35+
ipc_uri: str,
36+
token: str,
37+
reconnect: bool = True,
38+
certificate_path: pathlib.Path | str | None = None,
39+
):
40+
if isinstance(certificate_path, str):
41+
certificate_path = pathlib.Path(certificate_path)
42+
43+
self.tasks = TaskManager()
44+
self.ipc = IpcClient(
45+
uri=ipc_uri,
46+
token=token,
47+
reconnect=reconnect,
48+
certificate_path=certificate_path,
49+
)
50+
51+
self.stop_future: asyncio.Future[None] | None = None
52+
53+
def get_info(self) -> BaseInfo:
54+
"""Get the info class for this client.
55+
56+
Returns:
57+
BaseInfo: The info class.
58+
"""
59+
60+
raise NotImplementedError
61+
62+
async def start(self) -> None:
63+
"""Start the client.
64+
65+
Connects to the IPC server and begins sending out this clients
66+
info.
67+
"""
68+
69+
if self.stop_future is None:
70+
self.stop_future = asyncio.Future()
71+
72+
await self.ipc.start()
73+
74+
self.tasks.create_task(self._broadcast_info_loop())
75+
76+
async def join(self) -> None:
77+
"""Wait until the client begins exiting."""
78+
79+
assert self.stop_future and self.ipc.stop_future
80+
81+
await asyncio.wait(
82+
[self.stop_future, self.ipc.stop_future],
83+
return_when=asyncio.FIRST_COMPLETED,
84+
)
85+
86+
async def close(self) -> None:
87+
"""Shut down the client."""
88+
89+
self.ipc.stop()
90+
await self.ipc.close()
91+
92+
self.tasks.cancel_all()
93+
await self.tasks.wait_for_all()
94+
95+
def stop(self) -> None:
96+
"""Tell the client to stop."""
97+
98+
assert self.stop_future
99+
self.stop_future.set_result(None)
100+
101+
async def _broadcast_info_loop(self) -> None:
102+
while True:
103+
await self.ipc.wait_until_ready()
104+
assert self.ipc.uid
105+
try:
106+
await self.ipc.send_event(
107+
self.ipc.client_uids,
108+
"set_info_class",
109+
self.get_info().asdict(),
110+
)
111+
except ConnectionClosed:
112+
_LOG.error("Failed to send client info.", exc_info=True)
113+
await asyncio.sleep(1)

hikari_clusters/brain.py

Lines changed: 43 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,18 @@
2727
import signal
2828
from typing import Any
2929

30-
from . import log, payload
30+
from hikari_clusters.base_client import BaseClient
31+
from hikari_clusters.info_classes import BrainInfo
32+
33+
from . import payload
3134
from .events import EventGroup
3235
from .ipc_client import IpcClient
3336
from .ipc_server import IpcServer
34-
from .task_manager import TaskManager
3537

3638
__all__ = ("Brain",)
3739

38-
LOG = log.Logger("Brain")
39-
4040

41-
class Brain:
41+
class Brain(BaseClient):
4242
"""The brain of the bot.
4343
4444
Allows for comunication between clusters and servers,
@@ -73,29 +73,30 @@ def __init__(
7373
shards_per_cluster: int,
7474
certificate_path: pathlib.Path | str | None = None,
7575
) -> None:
76-
self.tasks = TaskManager(LOG)
76+
certificate_path = (
77+
pathlib.Path(certificate_path)
78+
if isinstance(certificate_path, str)
79+
else certificate_path
80+
)
81+
82+
super().__init__(
83+
IpcClient.get_uri(host, port, certificate_path is not None),
84+
token,
85+
True,
86+
certificate_path,
87+
)
7788

7889
self.total_servers = total_servers
7990
self.cluster_per_server = clusters_per_server
8091
self.shards_per_cluster = shards_per_cluster
8192

82-
if isinstance(certificate_path, str):
83-
certificate_path = pathlib.Path(certificate_path)
84-
8593
self.server = IpcServer(
8694
host, port, token, certificate_path=certificate_path
8795
)
88-
self.ipc = IpcClient(
89-
IpcClient.get_uri(host, port, certificate_path is not None),
90-
token,
91-
LOG,
92-
certificate_path=certificate_path,
93-
cmd_kwargs={"brain": self},
94-
event_kwargs={"brain": self},
95-
)
96-
self.ipc.events.include(_E)
9796

98-
self.stop_future: asyncio.Future[None] | None = None
97+
self.ipc.commands.cmd_kwargs["brain"] = self
98+
self.ipc.events.event_kwargs["brain"] = self
99+
self.ipc.events.include(_E)
99100

100101
self._waiting_for: tuple[int, int] | None = None
101102

@@ -128,7 +129,7 @@ def waiting_for(self) -> tuple[int, int] | None:
128129
if self._waiting_for is not None:
129130
server_uid, smallest_shard = self._waiting_for
130131
if (
131-
server_uid not in self.ipc.server_uids
132+
server_uid not in self.ipc.servers
132133
or smallest_shard in self.ipc.all_shards()
133134
):
134135
# `server_uid not in self.ipc.server_uids`
@@ -148,6 +149,11 @@ def waiting_for(self) -> tuple[int, int] | None:
148149
def waiting_for(self, value: tuple[int, int] | None) -> None:
149150
self._waiting_for = value
150151

152+
def get_info(self) -> BrainInfo:
153+
# <<<docstring from superclass>>>
154+
assert self.ipc.uid
155+
return BrainInfo(uid=self.ipc.uid)
156+
151157
def run(self) -> None:
152158
"""Run the brain, wait for the brain to stop, then cleanup."""
153159

@@ -161,42 +167,26 @@ def sigstop(*args: Any, **kwargs: Any) -> None:
161167
loop.run_until_complete(self.close())
162168

163169
async def start(self) -> None:
164-
"""Start the brain.
165-
166-
Returns as soon as all tasks have started.
167-
"""
170+
# <<<docstring from superclass>>>
168171
self.stop_future = asyncio.Future()
169-
self.tasks.create_task(self._send_brain_uid_loop())
170-
self.tasks.create_task(self._main_loop())
171-
await self.server.start()
172-
await self.ipc.start()
173-
174-
async def join(self) -> None:
175-
"""Wait for the brain to stop."""
176172

177-
assert self.stop_future
178-
await self.stop_future
173+
await self.server.start()
174+
await super().start()
175+
self.tasks.create_task(self._main_loop())
179176

180177
async def close(self) -> None:
181-
"""Shut the brain down."""
178+
# <<<docstring from superclass>>>
179+
self.ipc.stop()
180+
await self.ipc.close()
182181

183182
self.server.stop()
184183
await self.server.close()
185184

186-
self.ipc.stop()
187-
await self.ipc.close()
188-
189185
self.tasks.cancel_all()
190186
await self.tasks.wait_for_all()
191187

192-
def stop(self) -> None:
193-
"""Tell the brain to stop."""
194-
195-
assert self.stop_future
196-
self.stop_future.set_result(None)
197-
198188
def _get_next_cluster_to_launch(self) -> tuple[int, list[int]] | None:
199-
if len(self.ipc.server_uids) == 0:
189+
if len(self.ipc.servers) == 0:
200190
return None
201191

202192
if not all(c.ready for c in self.ipc.clusters.values()):
@@ -219,14 +209,6 @@ def _get_next_cluster_to_launch(self) -> tuple[int, list[int]] | None:
219209

220210
return s.uid, list(shards_to_launch)[: self.shards_per_cluster]
221211

222-
async def _send_brain_uid_loop(self) -> None:
223-
while True:
224-
await self.ipc.wait_until_ready()
225-
await self.ipc.send_event(
226-
self.ipc.client_uids, "set_brain_uid", {"uid": self.ipc.uid}
227-
)
228-
await asyncio.sleep(1)
229-
230212
async def _main_loop(self) -> None:
231213
await self.ipc.wait_until_ready()
232214
while True:
@@ -257,5 +239,13 @@ async def brain_stop(pl: payload.EVENT, brain: Brain) -> None:
257239

258240
@_E.add("shutdown")
259241
async def shutdown(pl: payload.EVENT, brain: Brain) -> None:
260-
await brain.ipc.send_event(brain.ipc.server_uids, "server_stop")
242+
await brain.ipc.send_event(brain.ipc.servers.keys(), "server_stop")
261243
brain.stop()
244+
245+
246+
@_E.add("cluster_died")
247+
async def cluster_died(pl: payload.EVENT, brain: Brain) -> None:
248+
assert pl.data.data is not None
249+
shard_id = pl.data.data["smallest_shard_id"]
250+
if brain._waiting_for is not None and brain._waiting_for[1] == shard_id:
251+
brain.waiting_for = None

0 commit comments

Comments
 (0)