Skip to content

Commit 60fe12c

Browse files
committed
ADD: Add getter for Live subscription requests
1 parent 223523c commit 60fe12c

File tree

5 files changed

+109
-46
lines changed

5 files changed

+109
-46
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
# Changelog
22

3+
## 0.67.0 - TBD
4+
5+
#### Enhancements
6+
- Added a property `Live.subscription_requests` which returns a list of tuples containing every `SubscriptionRequest` for the live session
7+
- Changed the return value of `Live.subscribe()` to `int`, the value of the subscription ID, which can be used to index into the `Live.subscription_requests` property
8+
9+
#### Breaking changes
10+
- Several log messages have been reformatted to improve clarity and reduce redundancy, especially at debug levels
11+
312
## 0.66.0 - 2025-11-18
413

514
#### Enhancements

databento/live/client.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from databento.common.types import RecordCallback
3232
from databento.common.validation import validate_enum
3333
from databento.common.validation import validate_semantic_string
34+
from databento.live.gateway import SubscriptionRequest
3435
from databento.live.session import DEFAULT_REMOTE_PORT
3536
from databento.live.session import LiveSession
3637
from databento.live.session import SessionMetadata
@@ -228,6 +229,38 @@ def session_id(self) -> str | None:
228229
"""
229230
return self._session.session_id
230231

232+
@property
233+
def subscription_requests(
234+
self,
235+
) -> list[tuple[SubscriptionRequest, ...]]:
236+
"""
237+
Return a list of tuples containing every `SubscriptionRequest` message
238+
sent for the session. The list is in order of the subscriptions made
239+
and can be indexed using the value returned by each call to
240+
`Live.subscribe()`.
241+
242+
Subscriptions which contain a large
243+
list of symbols are batched. Because of this, a single `subscription_id` may have
244+
more than one associated `SubscriptionRequest`.
245+
246+
Returns
247+
-------
248+
list[tuple[SubscriptionRequest, ...]]
249+
A list of tuples containing every subscription request.
250+
Each entry in the list corresponds to a single subscription.
251+
252+
Raises
253+
------
254+
IndexError
255+
If the subscription ID is invalid.
256+
257+
See Also
258+
--------
259+
Live.subscribe()
260+
261+
"""
262+
return self._session._subscriptions
263+
231264
@property
232265
def symbology_map(self) -> dict[int, str | int]:
233266
"""
@@ -446,7 +479,7 @@ def subscribe(
446479
stype_in: SType | str = SType.RAW_SYMBOL,
447480
start: pd.Timestamp | datetime | date | str | int | None = None,
448481
snapshot: bool = False,
449-
) -> None:
482+
) -> int:
450483
"""
451484
Add a new subscription to the session.
452485
@@ -476,6 +509,11 @@ def subscribe(
476509
Request subscription with snapshot. The `start` parameter must be `None`.
477510
Only supported with `mbo` schema.
478511
512+
Returns
513+
-------
514+
int
515+
The numeric identifier for this subscription request.
516+
479517
Raises
480518
------
481519
ValueError
@@ -494,7 +532,7 @@ def subscribe(
494532
495533
"""
496534
logger.info(
497-
"subscribing to %s:%s %s start=%s snapshot=%s",
535+
"subscribing to schema=%s stype_in=%s symbols='%s' start=%s snapshot=%s",
498536
schema,
499537
stype_in,
500538
symbols,
@@ -509,7 +547,7 @@ def subscribe(
509547
if snapshot and start is not None:
510548
raise ValueError("Subscription with snapshot expects start=None")
511549

512-
self._session.subscribe(
550+
return self._session.subscribe(
513551
dataset=dataset,
514552
schema=schema,
515553
stype_in=stype_in,

databento/live/protocol.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
from typing import Final
88

99
import databento_dbn
10-
from databento_dbn import DBNError
1110
from databento_dbn import Metadata
1211
from databento_dbn import Schema
1312
from databento_dbn import SType
14-
from databento_dbn import SystemCode
1513
from databento_dbn import VersionUpgradePolicy
1614

1715
from databento.common import cram
@@ -313,13 +311,14 @@ def subscribe(
313311
list[SubscriptionRequest]
314312
315313
"""
316-
logger.info(
317-
"sending subscription to %s:%s %s start=%s snapshot=%s",
314+
logger.debug(
315+
"sending subscription request schema=%s stype_in=%s symbols='%s' start='%s' snapshot=%s id=%s",
318316
schema,
319317
stype_in,
320318
symbols,
321319
start if start is not None else "now",
322320
snapshot,
321+
subscription_id,
323322
)
324323

325324
stype_in_valid = validate_enum(stype_in, SType, "stype_in")
@@ -341,6 +340,12 @@ def subscribe(
341340
)
342341
subscriptions.append(message)
343342

343+
if len(subscriptions) > 1:
344+
logger.debug(
345+
"batched subscription into %d requests id=%s",
346+
len(subscriptions),
347+
subscription_id,
348+
)
344349
self.transport.writelines(map(bytes, subscriptions))
345350
return subscriptions
346351

@@ -374,27 +379,20 @@ def _process_dbn(self, data: bytes) -> None:
374379
continue
375380
if isinstance(record, databento_dbn.ErrorMsg):
376381
logger.error(
377-
"gateway error: %s",
382+
"gateway error code=%s err='%s'",
383+
record.code,
378384
record.err,
379385
)
380386
self._error_msgs.append(record.err)
381387
elif isinstance(record, databento_dbn.SystemMsg):
382388
if record.is_heartbeat():
383389
logger.debug("gateway heartbeat")
384390
else:
385-
try:
386-
msg_code = record.code
387-
except DBNError:
388-
msg_code = None
389-
if msg_code == SystemCode.SLOW_READER_WARNING:
390-
logger.warning(
391-
record.msg,
392-
)
393-
else:
394-
logger.debug(
395-
"gateway message: %s",
396-
record.msg,
397-
)
391+
logger.info(
392+
"system message code=%s msg='%s'",
393+
record.code,
394+
record.msg,
395+
)
398396
self.received_record(record)
399397

400398
def _process_gateway(self, data: bytes) -> None:
@@ -423,34 +421,44 @@ def _handle_gateway_message(self, message: GatewayControl) -> None:
423421

424422
@_handle_gateway_message.register(Greeting)
425423
def _(self, message: Greeting) -> None:
426-
logger.debug("greeting received by remote gateway v%s", message.lsg_version)
424+
logger.debug(
425+
"greeting received by remote gateway version='%s'",
426+
message.lsg_version,
427+
)
427428

428429
@_handle_gateway_message.register(ChallengeRequest)
429430
def _(self, message: ChallengeRequest) -> None:
430-
logger.debug("received CRAM challenge: %s", message.cram)
431+
logger.debug("received CRAM challenge cram='%s'", message.cram)
431432
response = cram.get_challenge_response(message.cram, self.__api_key)
432433
auth_request = AuthenticationRequest(
433434
auth=response,
434435
dataset=self._dataset,
435436
ts_out=str(int(self._ts_out)),
436437
heartbeat_interval_s=self._heartbeat_interval_s,
437438
)
438-
logger.debug("sending CRAM challenge response: %s", str(auth_request).strip())
439+
logger.debug(
440+
"sending CRAM challenge response auth='%s' dataset=%s encoding=%s ts_out=%s heartbeat_interval_s=%s client='%s'",
441+
auth_request.auth,
442+
auth_request.dataset,
443+
auth_request.encoding,
444+
auth_request.ts_out,
445+
auth_request.heartbeat_interval_s,
446+
auth_request.client,
447+
)
439448
self.transport.write(bytes(auth_request))
440449

441450
@_handle_gateway_message.register(AuthenticationResponse)
442451
def _(self, message: AuthenticationResponse) -> None:
443452
if message.success == "0":
444-
logger.error("CRAM authentication failed: %s", message.error)
453+
logger.error("CRAM authentication error: %s", message.error)
445454
self.authenticated.set_exception(
446-
BentoError(f"User authentication failed: {message.error}"),
455+
BentoError(message.error),
447456
)
448457
self.transport.close()
449458
else:
450459
session_id = message.session_id
451460

452461
logger.debug(
453-
"CRAM authenticated session id assigned `%s`",
454-
session_id,
462+
"CRAM authentication successful",
455463
)
456464
self.authenticated.set_result(session_id)

databento/live/session.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import dataclasses
5+
import itertools
56
import logging
67
import queue
78
import struct
@@ -329,8 +330,7 @@ def __init__(
329330
self._transport: asyncio.Transport | None = None
330331
self._session_id: str | None = None
331332

332-
self._subscription_counter = 0
333-
self._subscriptions: list[SubscriptionRequest] = []
333+
self._subscriptions: list[tuple[SubscriptionRequest, ...]] = []
334334
self._reconnect_policy = ReconnectPolicy(reconnect_policy)
335335
self._reconnect_task: asyncio.Task[None] | None = None
336336

@@ -463,7 +463,7 @@ def subscribe(
463463
stype_in: SType | str = SType.RAW_SYMBOL,
464464
start: str | int | None = None,
465465
snapshot: bool = False,
466-
) -> None:
466+
) -> int:
467467
"""
468468
Send a subscription request on the current connection. This will create
469469
a new connection if there is no active connection to the gateway.
@@ -498,17 +498,20 @@ def subscribe(
498498
self._session_id = None
499499
self._connect(dataset=dataset)
500500

501-
self._subscription_counter += 1
502-
self._subscriptions.extend(
503-
self._protocol.subscribe(
504-
schema=schema,
505-
symbols=symbols,
506-
stype_in=stype_in,
507-
start=start,
508-
snapshot=snapshot,
509-
subscription_id=self._subscription_counter,
501+
subscription_id = len(self._subscriptions)
502+
self._subscriptions.append(
503+
tuple(
504+
self._protocol.subscribe(
505+
schema=schema,
506+
symbols=symbols,
507+
stype_in=stype_in,
508+
start=start,
509+
snapshot=snapshot,
510+
subscription_id=subscription_id,
511+
),
510512
),
511513
)
514+
return subscription_id
512515

513516
def terminate(self) -> None:
514517
with self._lock:
@@ -542,7 +545,7 @@ async def wait_for_close(self) -> None:
542545
self._cleanup()
543546

544547
def _cleanup(self) -> None:
545-
logger.debug("cleaning up session_id=%s", self.session_id)
548+
logger.debug("cleaning up session_id='%s'", self.session_id)
546549
self._user_callbacks.clear()
547550
for stream in self._user_streams:
548551
if not stream.is_closed:
@@ -596,7 +599,7 @@ async def _connect_task(
596599
logger.debug("using default gateway for dataset %s", dataset)
597600
else:
598601
gateway = self._user_gateway
599-
logger.debug("using user specified gateway: %s", gateway)
602+
logger.debug("user gateway override gateway='%s'", gateway)
600603

601604
logger.info("connecting to remote gateway")
602605
try:
@@ -638,7 +641,7 @@ async def _connect_task(
638641

639642
self._session_id = session_id
640643
logger.info(
641-
"authenticated session %s",
644+
"authenticated session_id='%s'",
642645
self.session_id,
643646
)
644647

@@ -669,7 +672,7 @@ async def _reconnect(self) -> None:
669672
dataset=self._protocol._dataset,
670673
)
671674

672-
for sub in self._subscriptions:
675+
for sub in itertools.chain(*self._subscriptions):
673676
self._protocol.subscribe(
674677
schema=sub.schema,
675678
symbols=sub.symbols,

tests/test_live_client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def test_live_connection_cram_failure(
163163
)
164164

165165
# Ensure this was an authentication error
166-
exc.match(r"User authentication failed:")
166+
exc.match(r"Authentication failed.")
167167

168168

169169
@pytest.mark.parametrize(
@@ -554,6 +554,8 @@ async def test_live_subscribe(
554554
assert message.symbols == symbols
555555
assert message.start == start
556556
assert message.snapshot == "0"
557+
assert len(live_client.subscription_requests[0]) == 1
558+
assert live_client.subscription_requests[0][0].id == int(message.id)
557559

558560

559561
@pytest.mark.parametrize(
@@ -645,16 +647,19 @@ async def test_live_subscribe_large_symbol_list(
645647
symbols=large_symbol_list,
646648
)
647649

650+
batched = []
648651
reconstructed: list[str] = []
649652
for i in range(8):
650653
message = await mock_live_server.wait_for_message_of_type(
651654
message_type=gateway.SubscriptionRequest,
652655
)
653656
assert int(message.is_last) == int(i == 7)
654657
reconstructed.extend(message.symbols.split(","))
658+
batched.append(message)
655659

656660
# Assert
657661
assert reconstructed == large_symbol_list
662+
assert len(live_client.subscription_requests[0]) == len(batched)
658663

659664

660665
async def test_live_subscribe_from_callback(
@@ -1663,7 +1668,7 @@ async def test_live_connection_reuse_cram_failure(
16631668
)
16641669

16651670
# Ensure this was an authentication error
1666-
exc.match(r"User authentication failed:")
1671+
exc.match(r"Authentication failed.")
16671672

16681673
async with mock_live_server.api_key_context(test_api_key):
16691674
live_client.subscribe(

0 commit comments

Comments
 (0)