Skip to content

Commit e69909c

Browse files
committed
Speed up AutoClient cleanup and add force_scale keyword parameter.
1 parent 20fe33b commit e69909c

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

twitchio/client.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3209,7 +3209,10 @@ async def main() -> None:
32093209
force_subscribe: bool
32103210
An optional :class:`bool` which when ``True`` will force attempt to subscribe to the subscriptions provided in the
32113211
``subscriptions`` parameter, regardless of whether a new conduit was created or not. Defaults to ``False``.
3212-
3212+
force_scale: bool
3213+
An optional :class:`bool` which when ``True`` will force the :class:`~twitchio.Conduit` associated with the
3214+
AutoClient/Bot to scale up/down to the provided amount of shards in the ``shard_ids`` parameter if provided. If the
3215+
``shard_ids`` parameter is not passed, this parameter has no effect. Defaults to ``False``.
32133216
"""
32143217

32153218
# NOTE:
@@ -3228,8 +3231,10 @@ def __init__(
32283231
**kwargs: Unpack[AutoClientOptions],
32293232
) -> None:
32303233
self._shard_ids: list[int] = kwargs.pop("shard_ids", [])
3234+
self._original_shards = self._shard_ids
32313235
self._conduit_id: str | bool | None = kwargs.pop("conduit_id", MISSING)
32323236
self._force_sub: bool = kwargs.pop("force_subscribe", False)
3237+
self._force_scale: bool = kwargs.pop("force_scale", False)
32333238
self._subbed: bool = False
32343239

32353240
if self._conduit_id is MISSING or self._conduit_id is None:
@@ -3358,6 +3363,10 @@ async def _setup(self) -> None:
33583363
# TODO: Maybe log currernt conduit info?
33593364
raise MissingConduit("No conduit could be found with the provided ID or a new one can not be created.")
33603365

3366+
if self._force_scale and self._original_shards:
3367+
logger.info("Scaling %r to %d shards.", len(self._original_shards))
3368+
await self._conduit_info.update_shard_count(len(self._original_shards), assign_transports=False)
3369+
33613370
await self._associate_shards(self._shard_ids)
33623371
if self._force_sub and not self._subbed:
33633372
await self.multi_subscribe(self._initial_subs)
@@ -3438,7 +3447,12 @@ async def _process_batched(self, batched: list[Websocket]) -> None:
34383447
await self._conduit_info._update_shards(payloads)
34393448
self._conduit_info._sockets.update({str(socket._shard_id): socket for socket in batched})
34403449

3441-
logger.info("Associated shards with %r successfully.", self._conduit_info)
3450+
logger.info(
3451+
"Associated shards with %r successfully. Shards: %d / %d (connected / Conduit total).",
3452+
self._conduit_info,
3453+
len(self._conduit_info.websockets),
3454+
self._conduit_info.shard_count,
3455+
)
34423456

34433457
async def _associate_shards(self, shard_ids: list[int]) -> None:
34443458
await self._associate_lock.acquire()
@@ -3626,14 +3640,25 @@ async def multi_subscribe(
36263640
)
36273641
return task
36283642

3643+
async def _close_sockets(self) -> None:
3644+
socks = self._conduit_info._sockets.values()
3645+
logger.info("Attempting to close %d associated Conduit Websockets.", len(socks))
3646+
3647+
tasks: list[asyncio.Task[None]] = [asyncio.create_task(s.close()) for s in socks]
3648+
await asyncio.wait(tasks)
3649+
3650+
logger.info("Successfully closed %d Conduit Websockets on %r.", len(socks), self)
3651+
36293652
async def close(self, **options: Any) -> None:
36303653
if self._closing:
36313654
return
36323655

36333656
self._closing = True
36343657

3635-
for socket in self._conduit_info.websockets.values():
3636-
await socket.close()
3658+
try:
3659+
await self._close_sockets()
3660+
except Exception as e:
3661+
logger.warning("An error occurred during the cleanup of Conduit Websockets: %s", e)
36373662

36383663
self._conduit_info._sockets.clear()
36393664

twitchio/types_/options.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class AutoClientOptions(ClientOptions, total=False):
5353
max_per_shard: int
5454
subscriptions: list[SubscriptionPayload]
5555
force_subscribe: bool
56+
force_scale: bool
5657

5758

5859
WaitPredicateT = Callable[..., Coroutine[Any, Any, bool]]

0 commit comments

Comments
 (0)