Skip to content

Commit 04af142

Browse files
committed
add websocket impl
1 parent d4d7076 commit 04af142

File tree

5 files changed

+514
-21
lines changed

5 files changed

+514
-21
lines changed

twitchio/ext/eventsub/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@
2323
"""
2424

2525
from .server import EventSubClient
26+
from .websocket import EventSubWSClient, Websocket
2627
from .models import *

twitchio/ext/eventsub/http.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66

77
if TYPE_CHECKING:
88
from .server import EventSubClient
9+
from .websocket import EventSubWSClient
910
from .models import EventData, Subscription
1011

1112
__all__ = ("EventSubHTTP",)
1213

1314

1415
class EventSubHTTP:
15-
def __init__(self, client: EventSubClient, token: Optional[str]):
16+
def __init__(self, client: Union[EventSubClient, EventSubWSClient], token: Optional[str]):
1617
self._client = client
1718
self._http = client.client._http
1819
self._token = token
1920

20-
async def create_subscription(self, event_type: Tuple[str, int, Type[EventData]], condition: Dict[str, str]):
21+
async def create_webhook_subscription(self, event_type: Tuple[str, int, Type[EventData]], condition: Dict[str, str]):
2122
payload = {
2223
"type": event_type[0],
2324
"version": str(event_type[1]),
@@ -27,6 +28,16 @@ async def create_subscription(self, event_type: Tuple[str, int, Type[EventData]]
2728
route = Route("POST", "eventsub/subscriptions", body=payload, token=self._token)
2829
return await self._http.request(route, paginate=False, force_app_token=True)
2930

31+
async def create_websocket_subscription(self, event_type: Tuple[str, int, Type[EventData]], condition: Dict[str, str], session_id: str, token: str):
32+
payload = {
33+
"type": event_type[0],
34+
"version": str(event_type[1]),
35+
"condition": condition,
36+
"transport": {"method": "websocket", "session_id": session_id},
37+
}
38+
route = Route("POST", "eventsub/subscriptions", body=payload, token=token)
39+
return await self._http.request(route, paginate=False, full_body=True)
40+
3041
async def delete_subscription(self, subscription: Union[str, Subscription]):
3142
if isinstance(subscription, models.Subscription):
3243
return await self._http.request(

twitchio/ext/eventsub/models.py

Lines changed: 122 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import hashlib
55
import logging
66
from enum import Enum
7-
from typing import Dict, TYPE_CHECKING, Optional, Type, Union, Tuple, List
7+
from typing import Dict, TYPE_CHECKING, Optional, Type, Union, Tuple, List, overload
88
from typing_extensions import Literal
99

1010
from aiohttp import web
@@ -13,6 +13,7 @@
1313

1414
if TYPE_CHECKING:
1515
from .server import EventSubClient
16+
from .websocket import EventSubWSClient
1617

1718
try:
1819
import ujson as json
@@ -36,7 +37,7 @@ def __init__(self, **kwargs):
3637

3738

3839
class Subscription:
39-
__slots__ = "id", "status", "type", "version", "cost", "condition", "transport", "created_at"
40+
__slots__ = "id", "status", "type", "version", "cost", "condition", "transport", "transport_method", "created_at"
4041

4142
def __init__(self, data: dict):
4243
self.id: str = data["id"]
@@ -47,8 +48,14 @@ def __init__(self, data: dict):
4748
self.condition: Dict[str, str] = data["condition"]
4849
self.created_at = _parse_datetime(data["created_at"])
4950
self.transport = EmptyObject()
50-
self.transport.method: str = data["transport"]["method"] # noqa
51-
self.transport.callback: str = data["transport"]["callback"] # noqa
51+
self.transport_method: TransportType = getattr(TransportType, data["transport"]["method"])
52+
self.transport.method: str = data["transport"]["method"] # type: ignore
53+
54+
if self.transport_method is TransportType.webhook:
55+
self.transport.callback: str = data["transport"]["callback"] # type: ignore
56+
else:
57+
self.transport.callback: str = "" # type: ignore # compatibility
58+
self.transport.session_id: str = data["transport"]["session_id"] # type: ignore
5259

5360

5461
class Headers:
@@ -82,6 +89,42 @@ def __init__(self, request: web.Request):
8289
self._raw_timestamp = request.headers["Twitch-Eventsub-Message-Timestamp"]
8390

8491

92+
class WebsocketHeaders:
93+
"""
94+
The headers of the inbound Websocket EventSub message
95+
96+
..warning:
97+
98+
This is a BETA feature
99+
100+
Attributes
101+
-----------
102+
message_id: :class:`str`
103+
The unique ID of the message
104+
message_type: :class:`str`
105+
The type of the message coming through
106+
message_retry: :class:`int`
107+
Kept for compatibility with :class:`Headers`
108+
signature: :class:`str`
109+
Kept for compatibility with :class:`Headers`
110+
subscription_type: :class:`str`
111+
The type of the subscription on the inbound message
112+
subscription_version: :class:`str`
113+
The version of the subscription.
114+
timestamp: :class:`datetime.datetime`
115+
The timestamp the message was sent at
116+
"""
117+
def __init__(self, frame: dict):
118+
meta = frame["metadata"]
119+
self.message_id: str = meta["message_id"]
120+
self.timestamp = _parse_datetime(meta["message_timestamp"])
121+
self.message_type: Literal["notification", "revocation", "reconnect", "session_keepalive"] = meta["message_type"]
122+
self.message_retry: int = 0 # don't make breaking changes with the Header class
123+
self.signature: str = ""
124+
self.subscription_type: str = frame["payload"]["subscription"]["type"]
125+
self.subscription_version: str = frame["payload"]["subscription"]["version"]
126+
127+
85128
class BaseEvent:
86129
"""
87130
The base of all the event classes
@@ -94,21 +137,47 @@ class BaseEvent:
94137
The headers received with the message
95138
"""
96139

97-
__slots__ = "_client", "_raw_data", "subscription", "headers"
140+
__slots__ = ("_client", "_raw_data", "subscription", "headers")
141+
142+
@overload
143+
def __init__(self, client: EventSubClient, _data: str, request: web.Request):
144+
...
98145

99-
def __init__(self, client: EventSubClient, data: str, request: web.Request):
146+
@overload
147+
def __init__(self, client: EventSubWSClient, _data: dict, request: None):
148+
...
149+
150+
def __init__(self, client: Union[EventSubClient, EventSubWSClient], _data: Union[str, dict], request: Optional[web.Request]):
100151
self._client = client
101-
self._raw_data = data
102-
_data: dict = _loads(data)
103-
self.subscription = Subscription(_data["subscription"])
104-
self.headers = Headers(request)
105-
self.setup(_data)
152+
self._raw_data = _data
153+
154+
if isinstance(_data, str):
155+
data: dict = _loads(_data)
156+
else:
157+
data = _data
158+
159+
self.headers: Union[Headers, WebsocketHeaders]
160+
self.subscription: Subscription
161+
162+
if request:
163+
data: dict = _loads(_data)
164+
self.headers = Headers(request)
165+
self.subscription = Subscription(data["subscription"])
166+
self.setup(data)
167+
else:
168+
self.headers = WebsocketHeaders(data)
169+
self.subscription = Subscription(data["payload"]["subscription"])
170+
self.setup(data["payload"])
171+
106172

107173
def setup(self, data: dict):
108174
pass
109175

110176
def verify(self):
111-
hmac_message = (self.headers.message_id + self.headers._raw_timestamp + self._raw_data).encode("utf-8")
177+
"""
178+
Only used in webhook transport types. Verifies the message is valid
179+
"""
180+
hmac_message = (self.headers.message_id + self.headers._raw_timestamp + self._raw_data).encode("utf-8") # type: ignore
112181
secret = self._client.secret.encode("utf-8")
113182
digest = hmac.new(secret, msg=hmac_message, digestmod=hashlib.sha256).hexdigest()
114183

@@ -127,6 +196,9 @@ class ChallengeEvent(BaseEvent):
127196
"""
128197
A challenge event.
129198
199+
.. note::
200+
These are only dispatched when using :class:`~twitchio.ext.eventsub.EventSubClient`
201+
130202
Attributes
131203
-----------
132204
challenge: :class`str`
@@ -139,7 +211,7 @@ def setup(self, data: dict):
139211
self.challenge: str = data["challenge"]
140212

141213
def verify(self):
142-
hmac_message = (self.headers.message_id + self.headers._raw_timestamp + self._raw_data).encode("utf-8")
214+
hmac_message = (self.headers.message_id + self.headers._raw_timestamp + self._raw_data).encode("utf-8") # type: ignore
143215
secret = self._client.secret.encode("utf-8")
144216
digest = hmac.new(secret, msg=hmac_message, digestmod=hashlib.sha256).hexdigest()
145217

@@ -150,6 +222,39 @@ def verify(self):
150222
return web.Response(status=200, text=self.challenge)
151223

152224

225+
class ReconnectEvent(BaseEvent):
226+
"""
227+
A reconnect event. Called by twitch when the websocket needs to be disconnected for maintenance or other reasons
228+
229+
.. note::
230+
These are only dispatched when using :class:`~twitchio.ext.eventsub.EventSubWSClient
231+
232+
Attributes
233+
-----------
234+
reconnect_url: :class:`str`
235+
The URL to reconnect to
236+
connected_at: :class:`~datetime.datetime`
237+
When the original websocket connected
238+
"""
239+
240+
__slots__ = ("reconnect_url", "connected_at")
241+
242+
def setup(self, data: dict):
243+
self.reconnect_url: str = data["session"]["reconnect_url"]
244+
self.connected_at: datetime.datetime = _parse_datetime(data["session"]["connected_at"])
245+
246+
247+
class KeepAliveEvent(BaseEvent):
248+
"""
249+
A keep-alive event. Called by twitch when no message has been sent for more than ``keepalive_timeout``
250+
251+
.. note::
252+
These are only dispatched when using :class:`~twitchio.ext.eventsub.EventSubWSClient
253+
254+
"""
255+
pass
256+
257+
153258
class NotificationEvent(BaseEvent):
154259
"""
155260
A notification event
@@ -1410,3 +1515,7 @@ class _SubscriptionTypes(metaclass=_SubTypesMeta):
14101515

14111516

14121517
SubscriptionTypes = _SubscriptionTypes()
1518+
1519+
class TransportType(Enum):
1520+
webhook = "webhook"
1521+
websocket = "websocket"

twitchio/ext/eventsub/server.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async def subscribe_user_updated(self, user: Union[PartialUser, str, int]):
6868
user = user.id
6969

7070
user = str(user)
71-
return await self._http.create_subscription(models.SubscriptionTypes.user_update, {"user_id": user})
71+
return await self._http.create_webhook_subscription(models.SubscriptionTypes.user_update, {"user_id": user})
7272

7373
async def subscribe_channel_raid(
7474
self, from_broadcaster: Union[PartialUser, str, int] = None, to_broadcaster: Union[PartialUser, str, int] = None
@@ -87,7 +87,7 @@ async def subscribe_channel_raid(
8787
broadcaster = broadcaster.id
8888

8989
broadcaster = str(broadcaster)
90-
return await self._http.create_subscription(models.SubscriptionTypes.raid, {who: broadcaster})
90+
return await self._http.create_webhook_subscription(models.SubscriptionTypes.raid, {who: broadcaster})
9191

9292
async def _subscribe_channel_points_reward(
9393
self, event, broadcaster: Union[PartialUser, str, int], reward_id: str = None
@@ -100,7 +100,7 @@ async def _subscribe_channel_points_reward(
100100
if reward_id:
101101
data["reward_id"] = reward_id
102102

103-
return await self._http.create_subscription(event, data)
103+
return await self._http.create_webhook_subscription(event, data)
104104

105105
async def _subscribe_with_broadcaster(
106106
self, event: Tuple[str, int, Type[models._DataType]], broadcaster: Union[PartialUser, str, int]
@@ -109,7 +109,7 @@ async def _subscribe_with_broadcaster(
109109
broadcaster = broadcaster.id
110110

111111
broadcaster = str(broadcaster)
112-
return await self._http.create_subscription(event, {"broadcaster_user_id": broadcaster})
112+
return await self._http.create_webhook_subscription(event, {"broadcaster_user_id": broadcaster})
113113

114114
def subscribe_channel_bans(self, broadcaster: Union[PartialUser, str, int]):
115115
return self._subscribe_with_broadcaster(models.SubscriptionTypes.ban, broadcaster)
@@ -215,12 +215,12 @@ def subscribe_channel_prediction_end(self, broadcaster: Union[PartialUser, str,
215215
return self._subscribe_with_broadcaster(models.SubscriptionTypes.prediction_end, broadcaster)
216216

217217
async def subscribe_user_authorization_granted(self):
218-
return await self._http.create_subscription(
218+
return await self._http.create_webhook_subscription(
219219
models.SubscriptionTypes.user_authorization_grant, {"client_id": self.client._http.client_id}
220220
)
221221

222222
async def subscribe_user_authorization_revoked(self):
223-
return await self._http.create_subscription(
223+
return await self._http.create_webhook_subscription(
224224
models.SubscriptionTypes.user_authorization_revoke, {"client_id": self.client._http.client_id}
225225
)
226226

0 commit comments

Comments
 (0)