Skip to content

Commit 35d131e

Browse files
committed
ADD: Add subscription ID to clients
1 parent 317d7f3 commit 35d131e

File tree

5 files changed

+24
-4
lines changed

5 files changed

+24
-4
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Changelog
22

3+
## 0.52.0 - TBD
4+
5+
#### Enhancements
6+
- Added new optional `id` field to `SubcriptionRequest` class which will be used for improved error messages
7+
38
## 0.51.0 - 2025-04-08
49

510
#### Enhancements

databento/live/gateway.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class SubscriptionRequest(GatewayControl):
132132
symbols: str
133133
start: int | None = None
134134
snapshot: int = 0
135+
id: int | None = None
135136

136137

137138
@dataclasses.dataclass

databento/live/protocol.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def subscribe(
284284
stype_in: SType | str = SType.RAW_SYMBOL,
285285
start: str | int | None = None,
286286
snapshot: bool = False,
287+
subscription_id: int | None = None,
287288
) -> list[SubscriptionRequest]:
288289
"""
289290
Send a SubscriptionRequest to the gateway. Returns a list of all
@@ -302,6 +303,8 @@ def subscribe(
302303
within 24 hours.
303304
snapshot: bool, default to 'False'
304305
Request subscription with snapshot. The `start` parameter must be `None`.
306+
subscription_id : int, optional
307+
A numerical identifier to associate with this subscription.
305308
306309
Returns
307310
-------
@@ -329,6 +332,7 @@ def subscribe(
329332
symbols=batch_str,
330333
start=optional_datetime_to_unix_nanoseconds(start),
331334
snapshot=int(snapshot),
335+
id=subscription_id,
332336
)
333337
subscriptions.append(message)
334338

databento/live/session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ def __init__(
332332
self._transport: asyncio.Transport | None = None
333333
self._session_id: str | None = None
334334

335+
self._subscription_counter = 0
335336
self._subscriptions: list[SubscriptionRequest] = []
336337
self._reconnect_policy = ReconnectPolicy(reconnect_policy)
337338
self._reconnect_task: asyncio.Task[None] | None = None
@@ -499,13 +500,15 @@ def subscribe(
499500
if self._protocol is None:
500501
self._connect(dataset=dataset)
501502

503+
self._subscription_counter += 1
502504
self._subscriptions.extend(
503505
self._protocol.subscribe(
504506
schema=schema,
505507
symbols=symbols,
506508
stype_in=stype_in,
507509
start=start,
508510
snapshot=snapshot,
511+
subscription_id=self._subscription_counter,
509512
),
510513
)
511514

@@ -672,6 +675,7 @@ async def _reconnect(self) -> None:
672675
stype_in=sub.stype_in,
673676
snapshot=bool(sub.snapshot),
674677
start=None,
678+
subscription_id=sub.id,
675679
)
676680

677681
if should_restart:

tests/test_live_gateway_messages.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,15 +299,15 @@ def test_serialize_session_start(
299299
"line, expected",
300300
[
301301
pytest.param(
302-
"schema=trades|" "stype_in=instrument_id|" "symbols=1,2,3\n",
303-
("trades", "instrument_id", "1,2,3", None),
302+
"schema=trades|" "stype_in=instrument_id|" "symbols=1,2,3|" "id=23\n",
303+
("trades", "instrument_id", "1,2,3", None, "23"),
304304
),
305305
pytest.param(
306306
"schema=trades|"
307307
"stype_in=instrument_id|"
308308
"symbols=1,2,3|"
309309
"start=1671717080706865759\n",
310-
("trades", "instrument_id", "1,2,3", "1671717080706865759"),
310+
("trades", "instrument_id", "1,2,3", "1671717080706865759", None),
311311
),
312312
pytest.param(
313313
"schema=trades|" "stype_in=instrument_id|" "symbols=1,2,3",
@@ -336,6 +336,7 @@ def test_parse_subscription_request(
336336
msg.stype_in,
337337
msg.symbols,
338338
msg.start,
339+
msg.id,
339340
) == expected
340341
else:
341342
with pytest.raises(expected):
@@ -374,8 +375,13 @@ def test_parse_subscription_request(
374375
symbols="1234,5678,90",
375376
start=None,
376377
snapshot=1,
378+
id=5,
377379
),
378-
b"schema=mbo|" b"stype_in=instrument_id|" b"symbols=1234,5678,90|" b"snapshot=1\n",
380+
b"schema=mbo|"
381+
b"stype_in=instrument_id|"
382+
b"symbols=1234,5678,90|"
383+
b"snapshot=1|"
384+
b"id=5\n",
379385
),
380386
],
381387
)

0 commit comments

Comments
 (0)