2727import signal
2828from typing import Any
2929
30- from . import log , payload
30+ from hikari_clusters .base_client import BaseClient
31+ from hikari_clusters .info_classes import BrainInfo
32+
33+ from . import payload
3134from .events import EventGroup
3235from .ipc_client import IpcClient
3336from .ipc_server import IpcServer
34- from .task_manager import TaskManager
3537
3638__all__ = ("Brain" ,)
3739
38- LOG = log .Logger ("Brain" )
39-
4040
41- class Brain :
41+ class Brain ( BaseClient ) :
4242 """The brain of the bot.
4343
4444 Allows for comunication between clusters and servers,
@@ -73,29 +73,30 @@ def __init__(
7373 shards_per_cluster : int ,
7474 certificate_path : pathlib .Path | str | None = None ,
7575 ) -> None :
76- self .tasks = TaskManager (LOG )
76+ certificate_path = (
77+ pathlib .Path (certificate_path )
78+ if isinstance (certificate_path , str )
79+ else certificate_path
80+ )
81+
82+ super ().__init__ (
83+ IpcClient .get_uri (host , port , certificate_path is not None ),
84+ token ,
85+ True ,
86+ certificate_path ,
87+ )
7788
7889 self .total_servers = total_servers
7990 self .cluster_per_server = clusters_per_server
8091 self .shards_per_cluster = shards_per_cluster
8192
82- if isinstance (certificate_path , str ):
83- certificate_path = pathlib .Path (certificate_path )
84-
8593 self .server = IpcServer (
8694 host , port , token , certificate_path = certificate_path
8795 )
88- self .ipc = IpcClient (
89- IpcClient .get_uri (host , port , certificate_path is not None ),
90- token ,
91- LOG ,
92- certificate_path = certificate_path ,
93- cmd_kwargs = {"brain" : self },
94- event_kwargs = {"brain" : self },
95- )
96- self .ipc .events .include (_E )
9796
98- self .stop_future : asyncio .Future [None ] | None = None
97+ self .ipc .commands .cmd_kwargs ["brain" ] = self
98+ self .ipc .events .event_kwargs ["brain" ] = self
99+ self .ipc .events .include (_E )
99100
100101 self ._waiting_for : tuple [int , int ] | None = None
101102
@@ -128,7 +129,7 @@ def waiting_for(self) -> tuple[int, int] | None:
128129 if self ._waiting_for is not None :
129130 server_uid , smallest_shard = self ._waiting_for
130131 if (
131- server_uid not in self .ipc .server_uids
132+ server_uid not in self .ipc .servers
132133 or smallest_shard in self .ipc .all_shards ()
133134 ):
134135 # `server_uid not in self.ipc.server_uids`
@@ -148,6 +149,11 @@ def waiting_for(self) -> tuple[int, int] | None:
148149 def waiting_for (self , value : tuple [int , int ] | None ) -> None :
149150 self ._waiting_for = value
150151
152+ def get_info (self ) -> BrainInfo :
153+ # <<<docstring from superclass>>>
154+ assert self .ipc .uid
155+ return BrainInfo (uid = self .ipc .uid )
156+
151157 def run (self ) -> None :
152158 """Run the brain, wait for the brain to stop, then cleanup."""
153159
@@ -161,42 +167,26 @@ def sigstop(*args: Any, **kwargs: Any) -> None:
161167 loop .run_until_complete (self .close ())
162168
163169 async def start (self ) -> None :
164- """Start the brain.
165-
166- Returns as soon as all tasks have started.
167- """
170+ # <<<docstring from superclass>>>
168171 self .stop_future = asyncio .Future ()
169- self .tasks .create_task (self ._send_brain_uid_loop ())
170- self .tasks .create_task (self ._main_loop ())
171- await self .server .start ()
172- await self .ipc .start ()
173-
174- async def join (self ) -> None :
175- """Wait for the brain to stop."""
176172
177- assert self .stop_future
178- await self .stop_future
173+ await self .server .start ()
174+ await super ().start ()
175+ self .tasks .create_task (self ._main_loop ())
179176
180177 async def close (self ) -> None :
181- """Shut the brain down."""
178+ # <<<docstring from superclass>>>
179+ self .ipc .stop ()
180+ await self .ipc .close ()
182181
183182 self .server .stop ()
184183 await self .server .close ()
185184
186- self .ipc .stop ()
187- await self .ipc .close ()
188-
189185 self .tasks .cancel_all ()
190186 await self .tasks .wait_for_all ()
191187
192- def stop (self ) -> None :
193- """Tell the brain to stop."""
194-
195- assert self .stop_future
196- self .stop_future .set_result (None )
197-
198188 def _get_next_cluster_to_launch (self ) -> tuple [int , list [int ]] | None :
199- if len (self .ipc .server_uids ) == 0 :
189+ if len (self .ipc .servers ) == 0 :
200190 return None
201191
202192 if not all (c .ready for c in self .ipc .clusters .values ()):
@@ -219,14 +209,6 @@ def _get_next_cluster_to_launch(self) -> tuple[int, list[int]] | None:
219209
220210 return s .uid , list (shards_to_launch )[: self .shards_per_cluster ]
221211
222- async def _send_brain_uid_loop (self ) -> None :
223- while True :
224- await self .ipc .wait_until_ready ()
225- await self .ipc .send_event (
226- self .ipc .client_uids , "set_brain_uid" , {"uid" : self .ipc .uid }
227- )
228- await asyncio .sleep (1 )
229-
230212 async def _main_loop (self ) -> None :
231213 await self .ipc .wait_until_ready ()
232214 while True :
@@ -257,5 +239,13 @@ async def brain_stop(pl: payload.EVENT, brain: Brain) -> None:
257239
258240@_E .add ("shutdown" )
259241async def shutdown (pl : payload .EVENT , brain : Brain ) -> None :
260- await brain .ipc .send_event (brain .ipc .server_uids , "server_stop" )
242+ await brain .ipc .send_event (brain .ipc .servers . keys () , "server_stop" )
261243 brain .stop ()
244+
245+
246+ @_E .add ("cluster_died" )
247+ async def cluster_died (pl : payload .EVENT , brain : Brain ) -> None :
248+ assert pl .data .data is not None
249+ shard_id = pl .data .data ["smallest_shard_id" ]
250+ if brain ._waiting_for is not None and brain ._waiting_for [1 ] == shard_id :
251+ brain .waiting_for = None
0 commit comments