@@ -3209,7 +3209,10 @@ async def main() -> None:
3209
3209
force_subscribe: bool
3210
3210
An optional :class:`bool` which when ``True`` will force attempt to subscribe to the subscriptions provided in the
3211
3211
``subscriptions`` parameter, regardless of whether a new conduit was created or not. Defaults to ``False``.
3212
-
3212
+ force_scale: bool
3213
+ An optional :class:`bool` which when ``True`` will force the :class:`~twitchio.Conduit` associated with the
3214
+ AutoClient/Bot to scale up/down to the provided amount of shards in the ``shard_ids`` parameter if provided. If the
3215
+ ``shard_ids`` parameter is not passed, this parameter has no effect. Defaults to ``False``.
3213
3216
"""
3214
3217
3215
3218
# NOTE:
@@ -3228,8 +3231,10 @@ def __init__(
3228
3231
** kwargs : Unpack [AutoClientOptions ],
3229
3232
) -> None :
3230
3233
self ._shard_ids : list [int ] = kwargs .pop ("shard_ids" , [])
3234
+ self ._original_shards = self ._shard_ids
3231
3235
self ._conduit_id : str | bool | None = kwargs .pop ("conduit_id" , MISSING )
3232
3236
self ._force_sub : bool = kwargs .pop ("force_subscribe" , False )
3237
+ self ._force_scale : bool = kwargs .pop ("force_scale" , False )
3233
3238
self ._subbed : bool = False
3234
3239
3235
3240
if self ._conduit_id is MISSING or self ._conduit_id is None :
@@ -3358,6 +3363,10 @@ async def _setup(self) -> None:
3358
3363
# TODO: Maybe log currernt conduit info?
3359
3364
raise MissingConduit ("No conduit could be found with the provided ID or a new one can not be created." )
3360
3365
3366
+ if self ._force_scale and self ._original_shards :
3367
+ logger .info ("Scaling %r to %d shards." , len (self ._original_shards ))
3368
+ await self ._conduit_info .update_shard_count (len (self ._original_shards ), assign_transports = False )
3369
+
3361
3370
await self ._associate_shards (self ._shard_ids )
3362
3371
if self ._force_sub and not self ._subbed :
3363
3372
await self .multi_subscribe (self ._initial_subs )
@@ -3438,7 +3447,12 @@ async def _process_batched(self, batched: list[Websocket]) -> None:
3438
3447
await self ._conduit_info ._update_shards (payloads )
3439
3448
self ._conduit_info ._sockets .update ({str (socket ._shard_id ): socket for socket in batched })
3440
3449
3441
- logger .info ("Associated shards with %r successfully." , self ._conduit_info )
3450
+ logger .info (
3451
+ "Associated shards with %r successfully. Shards: %d / %d (connected / Conduit total)." ,
3452
+ self ._conduit_info ,
3453
+ len (self ._conduit_info .websockets ),
3454
+ self ._conduit_info .shard_count ,
3455
+ )
3442
3456
3443
3457
async def _associate_shards (self , shard_ids : list [int ]) -> None :
3444
3458
await self ._associate_lock .acquire ()
@@ -3626,14 +3640,25 @@ async def multi_subscribe(
3626
3640
)
3627
3641
return task
3628
3642
3643
+ async def _close_sockets (self ) -> None :
3644
+ socks = self ._conduit_info ._sockets .values ()
3645
+ logger .info ("Attempting to close %d associated Conduit Websockets." , len (socks ))
3646
+
3647
+ tasks : list [asyncio .Task [None ]] = [asyncio .create_task (s .close ()) for s in socks ]
3648
+ await asyncio .wait (tasks )
3649
+
3650
+ logger .info ("Successfully closed %d Conduit Websockets on %r." , len (socks ), self )
3651
+
3629
3652
async def close (self , ** options : Any ) -> None :
3630
3653
if self ._closing :
3631
3654
return
3632
3655
3633
3656
self ._closing = True
3634
3657
3635
- for socket in self ._conduit_info .websockets .values ():
3636
- await socket .close ()
3658
+ try :
3659
+ await self ._close_sockets ()
3660
+ except Exception as e :
3661
+ logger .warning ("An error occurred during the cleanup of Conduit Websockets: %s" , e )
3637
3662
3638
3663
self ._conduit_info ._sockets .clear ()
3639
3664
0 commit comments