Skip to content

Commit 39af8fb

Browse files
committed
Various changes to AutoClient to support shard_ids in place of shard_count
1 parent 3f3424c commit 39af8fb

File tree

1 file changed

+67
-86
lines changed

1 file changed

+67
-86
lines changed

twitchio/client.py

Lines changed: 67 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from .authentication import ManagedHTTPClient, Scopes, UserTokenPayload
3535
from .eventsub.enums import SubscriptionType
36-
from .eventsub.websockets import Websocket
36+
from .eventsub.websockets import Websocket, WebsocketClosed
3737
from .exceptions import HTTPException, MissingConduit
3838
from .http import HTTPAsyncIterator
3939
from .models.bits import Cheermote, ExtensionTransaction
@@ -2709,7 +2709,7 @@ async def update_shard_count(self, shard_count: int, /, *, assign_transports: bo
27092709
.. warning::
27102710
27112711
Caution should be used when scaling multi-process solutions that connect to the same Conduit. In this particular
2712-
case it is more likely appropriate to close down and restart each instance and adjust the ``shard_count`` parameter
2712+
case it is more likely appropriate to close down and restart each instance and adjust the ``shard_ids`` parameter
27132713
on :class:`~twitchio.AutoClient` accordingly. However this method can be called with ``assign_transports=False``
27142714
first to allow each instance to easily adjust when restarted.
27152715
@@ -2751,12 +2751,15 @@ async def update_shard_count(self, shard_count: int, /, *, assign_transports: bo
27512751
assert self.conduit
27522752
assert self.shard_count
27532753

2754+
end = max(self._client._shard_ids) + 1
2755+
shard_ids = list(range(end, end + (shard_count - len(self._client._shard_ids))))
2756+
27542757
if assign_transports:
27552758
if self.shard_count > len(self.websockets):
2756-
await self._client._associate_shards()
2759+
await self._client._associate_shards(shard_ids)
27572760
elif self.shard_count < len(self.websockets):
27582761
remove = len(self.websockets) - self.shard_count
2759-
await self._disassociate_shards(remove)
2762+
await self._disassociate_shards(self.shard_count, remove)
27602763

27612764
return self
27622765

@@ -2822,17 +2825,17 @@ def fetch_shards(self, *, status: ShardStatus | None = None) -> HTTPAsyncIterato
28222825

28232826
return self._conduit.fetch_shards(status=status)
28242827

2825-
async def _disassociate_shards(self, remove_count: int, /) -> None:
2826-
sockets = list(self._sockets.values())
2828+
async def _disassociate_shards(self, start: int, remove_count: int, /) -> None:
28272829
closed = 0
28282830

2829-
for _ in range(remove_count):
2830-
socket = sockets.pop()
2831-
self._sockets.pop(str(socket._shard_id), None)
2831+
for n in range(start, start + remove_count):
2832+
socket = self._sockets.pop(str(n), None)
2833+
if not socket:
2834+
continue
28322835

28332836
try:
28342837
closed += 1
2835-
await socket.close()
2838+
await socket.close(reassociate=False)
28362839
except Exception as e:
28372840
logger.debug("Ignoring exception in close of %r. It is likely a non-issue: %s", socket, e)
28382841

@@ -2844,6 +2847,8 @@ async def _disassociate_shards(self, remove_count: int, /) -> None:
28442847
len(self.websockets),
28452848
)
28462849

2850+
self._client._shard_ids = sorted([int(k) for k in self._sockets])
2851+
28472852
async def _update_shards(self, shards: list[ShardUpdateRequest]) -> Self:
28482853
# TODO?
28492854
if not self._conduit:
@@ -2942,7 +2947,7 @@ class AutoClient(Client):
29422947
* Your application will associate with that conduit and assign transports; existing subscriptions remain.
29432948
* If no Conduit exists:
29442949
* Your application will create and associate itself with the new Conduit, assigning transports and subscribing.
2945-
* By default the amount of shards will be equal ``len(subscriptions) / max_per_shard`` or ``2`` whichever is greater **or** ``shard_count`` if passed.
2950+
* By default the amount of shards will be equal to ``len(subscriptions) / max_per_shard`` or ``2``, whichever is greater, **or** ``len(shard_ids)`` if passed.
29462951
* Most applications will only require ``2`` or ``3`` shards, with the second shard mostly existing as a fallback.
29472952
29482953
In both scenarios above your application can restart at anytime without re-subscribing, however take note of the
@@ -2989,8 +2994,8 @@ async def main() -> None:
29892994
29902995
* Your application will connect to the provided Conduit-ID.
29912996
* If the conduit does not exist or could not be found, an error will be raised on start-up.
2992-
* Your application will assign the amount of transports equal to the shard count the Conduit returns from Twitch **or** the amount passed to ``shard_count``.
2993-
* ``shard_count`` is not a required parameter, however see below for more info.
2997+
* Your application will assign the amount of transports equal to the shard count the Conduit returns from Twitch **or** to ``len(shard_ids)`` if ``shard_ids`` is passed.
2998+
* ``shard_ids`` is not a required parameter, however see below for more info.
29942999
29953000
This case allows greater control over which Conduit your application connects to which is mostly only useful if your
29963001
setup includes multiple Conduits and/or multiple instances of :class:`~twitchio.AutoClient` or
@@ -3000,10 +3005,9 @@ async def main() -> None:
30003005
realistically connect to a different conduit each.
30013006
30023007
When the latter is true, your application can be set up to effectively distribute shards equally across instances by
3003-
connecting to the same Conduit. In this scenario you will also need to pass the ``shard_count`` parameter to each
3004-
instance, dividing the total amount of shards required by the amount of instances that will run. You will also need to
3005-
pass ``False`` to the ``update_conduit_shards`` keyword-only parameter, as this will update the underlying Conduit
3006-
to only have a max of ``shard_count``; which will not be useful when distributing shards.
3008+
connecting to the same Conduit. In this scenario you will also need to pass the ``shard_ids`` parameter to each
3009+
instance, making sure each instance is assigned shards. E.g. ``shard_ids=[0, 1, 2]`` and ``shard_ids=[3, 4, 5]`` for a
3010+
``2`` instance setup on a Conduit which contains ``6`` shards.
30073011
30083012
As this case is more involved and requires much more attention; e.g. multiple instances need to be started correctly and
30093013
connections to multiple Conduits need to be properly configured it usually wouldn't be advised for small to medium sized
@@ -3015,7 +3019,8 @@ async def main() -> None:
30153019
**An example of case 2 (multiple instances on one conduit):**
30163020
30173021
In this scenario your conduit should have ``6`` shards associated with it already. You should do this before starting
3018-
multiple instances. The example below would simply be run twice, assigning ``3`` shards to each instance.
3022+
multiple instances. The example below would simply be run twice, assigning ``3`` shards to each instance, e.g. on the
3023+
first instance ``shard_ids=[0, 1, 2]`` and the second instance ``shard_ids=[3, 4, 5]``.
30193024
30203025
You can prepare for this scenario by calling :meth:`twitchio.Conduit.update` on the appropriate Conduit or by calling
30213026
:meth:`~twitchio.ConduitInfo.update_shard_count` with ``6`` shards, and keeping note of the Conduit-ID, before starting
@@ -3031,8 +3036,7 @@ def __init__(self, *args, **kwargs) -> None:
30313036
**kwargs,
30323037
subscriptions=subs,
30333038
conduit_id="CONDUIT_1_ID",
3034-
shard_count=3,
3035-
update_conduit_shards=False,
3039+
shard_ids=[...],
30363040
)
30373041
30383042
async def event_message(self, payload: twitchio.ChatMessage) -> None:
@@ -3065,30 +3069,24 @@ async def main() -> None:
30653069
ownership of, or ``None`` to connect to an existing Conduit or create one if none exist. You can also pass ``True``
30663070
to this parameter, however see above for more details and examples on how this parameter affects your application.
30673071
Defaults to ``None``, which is the most common use case.
3068-
shard_count: int
3069-
An optional :class:`int` which sets the amount of shards or transports that should be associated with the Conduit.
3070-
If the :class:`~twitchio.AutoClient` creates a new Conduit this will be what is sent to Twitch if passed. When this
3071-
parameter is passed and ``update_conduit_shards`` is ``True`` (Default), when your Client starts a request will be
3072-
made to Twitch to update the Conduit Shard Count on the API. If this parameter is passed and ``update_conduit_shards``
3073-
is ``False``, the Client will not update the total Shard Count on the API, however will connect ``shard_count``
3074-
amount of websocket transports. Together with ``update_conduit_shards=False``, this parameter can be used to equally
3075-
distribute shards accross multiple instances/processes on a single Conduit.
3072+
shard_ids: list[int]
3073+
An optional :class:`list` of :class:`int` which sets the shard IDs this instance of :class:`~twitchio.AutoClient`
3074+
will assign transports for. If the :class:`~twitchio.AutoClient` creates a new Conduit the length of this parameter
3075+
will be the ``shard_count`` which the Conduit will be created with. This parameter can be used to equally distribute
3076+
shards across multiple instances/processes on a single Conduit. You should make sure the list of ``shard_ids`` is
3077+
consecutive and in order. E.g. ``list(range(3))`` or ``[0, 1, 2]``. Shard IDs are ``0``-indexed.
30763078
max_per_shard: int
30773079
An optional parameter which allows the Client to automatically determine the amount of shards you may require based on
30783080
the amount of ``subscriptions`` passed. The default value is ``1000`` and the algorithm used to determine shards is
30793081
simply: ``len(subscriptions) / max_per_shard`` or ``2`` whichever is greater. Note this parameter has no effect when
3080-
``shard_count`` is explicitly passed.
3082+
``shard_ids`` is explicitly passed.
30813083
subscriptions: list[twitchio.eventsub.SubscriptionPayload]
30823084
A list of any combination of EventSub subscriptions (all of which inherit from
30833085
:class:`~twitchio.eventsub.SubscriptionPayload`) the Client should attempt to subscribe to when required. The
30843086
:class:`~twitchio.AutoClient` will only attempt to subscribe to these subscriptions when it creates a new Conduit. If
30853087
your Client connects to an existing Conduit either by passing ``conduit_id`` or automatically, this parameter has no
30863088
effect. In cases where you need to update an existing Conduit with new subscriptions see:
30873089
:meth:`~twitchio.AutoClient.multi_subscribe`.
3088-
update_conduit_shards: bool
3089-
Optional parameter which when ``True`` will hard update the Conduit Shard Count on the Twitch API.
3090-
This parameter should be set to ``False`` when you are using ``2`` or more instances/processes of
3091-
:class:`~twitchio.AutoClient`. Defaults to ``True``.
30923090
"""
30933091

30943092
# NOTE:
@@ -3106,7 +3104,9 @@ def __init__(
31063104
bot_id: str | None = None,
31073105
**kwargs: Unpack[AutoClientOptions],
31083106
) -> None:
3107+
self._shard_ids: list[int] = kwargs.pop("shard_ids", [])
31093108
self._conduit_id: str | bool | None = kwargs.pop("conduit_id", MISSING)
3109+
31103110
if self._conduit_id is MISSING or self._conduit_id is None:
31113111
logger.warning(
31123112
'No "conduit_id" was passed. If a single conduit exists we will try to take ownership. '
@@ -3122,38 +3122,14 @@ def __init__(
31223122
elif not isinstance(self._conduit_id, str):
31233123
raise TypeError('The parameter "conduit_id" must be either a str, True or not provided.')
31243124

3125-
self._shard_count = kwargs.pop("shard_count", None)
31263125
self._max_per_shard = kwargs.pop("max_per_shard", 1000)
31273126
self._initial_subs: list[SubscriptionPayload] = kwargs.pop("subscriptions", [])
3128-
self._update_conduit_shards: bool = kwargs.pop("update_conduit_shards", True)
31293127

3130-
if self._max_per_shard < 1000 and self._shard_count is not None:
3128+
if self._max_per_shard < 1000 and not self._shard_ids:
31313129
logger.warning('It is recommended that the "max_per_shard" parameter should not be set below 1000.')
31323130

3133-
if self._shard_count is not None:
3134-
if self._shard_count <= 0:
3135-
raise ValueError('"shard_count" cannot be lower than "1".')
3136-
if self._shard_count >= 20_000:
3137-
raise ValueError('"shard_count" has a maximum value of "20_000".')
3138-
3139-
if self._shard_count is None:
3140-
logger.warning(
3141-
(
3142-
'The "shard_count" parameter for %r was not provided. '
3143-
"A best attempt will be made to determine the correct shard count based on the max_per_shard of %d."
3144-
),
3145-
self,
3146-
self._max_per_shard,
3147-
)
3148-
3149-
# Try and determine best count based on provided subscriptions with leeway...
3150-
# Defaults to 2 shards if no subscriptions are passed...
3151-
total_subs = len(self._initial_subs)
3152-
self._shard_count = clamp(math.ceil(total_subs / self._max_per_shard) + 1, 2, 20_000)
3153-
self._conduit_info: ConduitInfo = ConduitInfo(self)
3154-
3155-
# So we don't mess around in the event if we are closing everying...
3156-
self._closing: bool = False
3131+
self._conduit_info: ConduitInfo = ConduitInfo(self)
3132+
self._closing: bool = False
31573133

31583134
super().__init__(client_id=client_id, client_secret=client_secret, bot_id=bot_id, **kwargs)
31593135

@@ -3182,6 +3158,9 @@ async def _setup(self) -> None:
31823158
logger.info("Conduit found: %r. Attempting to take ownership of this conduit.", conduits[0])
31833159
self._conduit_info._conduit = conduits[0]
31843160

3161+
if not self._shard_ids:
3162+
self._shard_ids = list(range(self._conduit_info.shard_count)) # type: ignore
3163+
31853164
elif self._conduit_id is True:
31863165
logger.info("Attempting to create and take ownership of a new Conduit.")
31873166
await self._generate_new_conduit()
@@ -3193,33 +3172,35 @@ async def _setup(self) -> None:
31933172
if conduit.id == self._conduit_id:
31943173
logger.info('Conduit with the provided ID: "%s" found. Attempting to take ownership.', self._conduit_id)
31953174
self._conduit_info._conduit = conduit
3196-
self._shard_count = self._conduit_info.shard_count
3175+
3176+
if not self._shard_ids:
3177+
self._shard_ids = list(range(self._conduit_info.shard_count)) # type: ignore
31973178

31983179
if not self._conduit_info.conduit:
31993180
# TODO: Maybe log current conduit info?
32003181
raise MissingConduit("No conduit could be found with the provided ID or a new one can not be created.")
32013182

3202-
if self._conduit_info.conduit.shard_count != self._shard_count and self._update_conduit_shards:
3203-
logger.info(
3204-
'%r "shard_count" differs to provided "shard_count". Attempting to update the Conduit to %d shards.',
3205-
self._conduit_info,
3206-
self._shard_count,
3207-
)
3208-
3209-
assert self._shard_count is not None
3210-
await self._conduit_info.update_shard_count(self._shard_count, assign_transports=False)
3211-
3212-
await self._associate_shards()
3183+
await self._associate_shards(self._shard_ids)
32133184
await self.setup_hook()
32143185

3215-
async def _websocket_closed(self, payload: ...) -> ...:
3216-
# Since conduit websockets don't have any reconnect logic we should attempt to see if the shard should
3217-
# still exist and if a new websocket(s) are needed we can reconnect here instead...
3186+
async def _websocket_closed(self, payload: WebsocketClosed) -> None:
32183187
if self._closing:
32193188
return
32203189

3190+
if not payload.socket._shard_id:
3191+
return
3192+
3193+
if not payload.reassociate:
3194+
self._conduit_info._sockets.pop(payload.socket._shard_id, None)
3195+
return
3196+
3197+
if payload.socket._shard_id not in self._conduit_info._sockets:
3198+
return
3199+
3200+
await self._associate_shards(shard_ids=[int(payload.socket._shard_id)])
3201+
32213202
async def _connect_and_welcome(self, websocket: Websocket) -> bool:
3222-
await websocket.connect(fail_once=True)
3203+
await websocket.connect(fail_once=False)
32233204

32243205
async def welcome_predicate(payload: WebsocketWelcome) -> bool:
32253206
nonlocal websocket
@@ -3272,18 +3253,12 @@ async def _process_batched(self, batched: list[Websocket]) -> None:
32723253

32733254
logger.info("Associated shards with %r successfully.", self._conduit_info)
32743255

3275-
async def _associate_shards(self) -> None:
3256+
async def _associate_shards(self, shard_ids: list[int]) -> None:
32763257
assert self._conduit_info.conduit
32773258

32783259
batched: list[Websocket] = []
32793260

3280-
start = max([int(s) for s in self._conduit_info._sockets] + [0])
3281-
start = start + 1 if start != 0 else start
3282-
3283-
end = start + (self._conduit_info.conduit.shard_count - len(self._conduit_info.websockets))
3284-
range_ = range(start, end)
3285-
3286-
for i, n in enumerate(range_):
3261+
for i, n in enumerate(shard_ids):
32873262
if i % 10 == 0 and i != 0:
32883263
await self._process_batched(batched)
32893264
batched.clear()
@@ -3294,11 +3269,17 @@ async def _associate_shards(self) -> None:
32943269
if batched:
32953270
await self._process_batched(batched)
32963271

3272+
self._shard_ids = sorted([int(k) for k in self._conduit_info._sockets])
3273+
32973274
async def _generate_new_conduit(self) -> Conduit:
3298-
assert self._shard_count
3275+
if not self._shard_ids:
3276+
# Try and determine best count based on provided subscriptions with leeway...
3277+
# Defaults to 2 shards if no subscriptions are passed...
3278+
total_subs = len(self._initial_subs)
3279+
self._shard_ids = list(range(clamp(math.ceil(total_subs / self._max_per_shard) + 1, 2, 20_000)))
32993280

33003281
try:
3301-
new = await self.create_conduit(self._shard_count)
3282+
new = await self.create_conduit(len(self._shard_ids))
33023283
except HTTPException as e:
33033284
if e.status == 429:
33043285
# TODO: Maybe log current conduit info?

0 commit comments

Comments
 (0)