Skip to content

Commit ac23d31

Browse files
committed
FIX: Fix subscribe in Live callback
1 parent bed27e9 commit ac23d31

File tree

4 files changed

+108
-62
lines changed

4 files changed

+108
-62
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#### Enhancements
66
- Increase `Live` session connection & authentication timeouts
77

8+
#### Bug fixes
9+
- Fixed an issue where calling `Live.subscribe` from a `Live` client callback would cause a deadlock
10+
811
## 0.31.0 - 2024-03-05
912

1013
#### Enhancements

databento/live/session.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -434,15 +434,12 @@ def subscribe(
434434
loop=self._loop,
435435
)
436436

437-
asyncio.run_coroutine_threadsafe(
438-
self._subscribe_task(
437+
self._protocol.subscribe(
439438
schema=schema,
440439
symbols=symbols,
441440
stype_in=stype_in,
442441
start=start,
443-
),
444-
loop=self._loop,
445-
).result()
442+
)
446443

447444
def resume_reading(self) -> None:
448445
"""
@@ -565,21 +562,3 @@ async def _connect_task(
565562
)
566563

567564
return transport, protocol
568-
569-
async def _subscribe_task(
570-
self,
571-
schema: Schema | str,
572-
symbols: Iterable[str | int] | str | int = ALL_SYMBOLS,
573-
stype_in: SType | str = SType.RAW_SYMBOL,
574-
start: str | int | None = None,
575-
) -> None:
576-
with self._lock:
577-
if self._protocol is None:
578-
return
579-
580-
self._protocol.subscribe(
581-
schema=schema,
582-
symbols=symbols,
583-
stype_in=stype_in,
584-
start=start,
585-
)

tests/mock_live_server.py

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from functools import singledispatchmethod
1717
from io import BytesIO
1818
from os import PathLike
19+
from typing import Any
1920
from typing import Callable
2021
from typing import NewType
2122
from typing import TypeVar
@@ -32,9 +33,7 @@
3233
from databento.live.gateway import SessionStart
3334
from databento.live.gateway import SubscriptionRequest
3435
from databento.live.gateway import parse_gateway_message
35-
from databento_dbn import Metadata
3636
from databento_dbn import Schema
37-
from databento_dbn import SType
3837

3938

4039
LIVE_SERVER_VERSION: str = "1.0.0"
@@ -100,7 +99,8 @@ def __init__(
10099
self._version: str = version
101100
self._is_authenticated: bool = False
102101
self._is_streaming: bool = False
103-
self._repeater_tasks: set[asyncio.Task[None]] = set()
102+
self._subscriptions: list[SubscriptionRequest] = []
103+
self._tasks: set[asyncio.Task[None]] = set()
104104

105105
self._dbn_path = pathlib.Path(dbn_path)
106106
self._user_api_keys = user_api_keys
@@ -155,6 +155,18 @@ def is_streaming(self) -> bool:
155155
"""
156156
return self._is_streaming
157157

158+
@property
159+
def dataset_path(self) -> pathlib.Path:
160+
"""
161+
The path to the DBN files for serving.
162+
163+
Returns
164+
-------
165+
Path
166+
167+
"""
168+
return self._dbn_path / (self._dataset or "")
169+
158170
@property
159171
def mode(self) -> MockLiveMode:
160172
"""
@@ -205,6 +217,18 @@ def session_id(self) -> str:
205217
"""
206218
return str(hash(self))
207219

220+
@property
221+
def subscriptions(self) -> tuple[SubscriptionRequest, ...]:
222+
"""
223+
The received subscriptions.
224+
225+
Returns
226+
-------
227+
tuple[SubscriptionRequest, ...]
228+
229+
"""
230+
return tuple(self._subscriptions)
231+
208232
@property
209233
def version(self) -> str:
210234
"""
@@ -353,11 +377,11 @@ def _(self, message: AuthenticationRequest) -> None:
353377
logger.info("received CRAM response: %s", message.auth)
354378
if self.is_authenticated:
355379
logger.error("authentication request sent when already authenticated")
356-
self.__transport.close()
380+
self.__transport.write_eof()
357381
return
358382
if self.is_streaming:
359383
logger.error("authentication request sent while streaming")
360-
self.__transport.close()
384+
self.__transport.write_eof()
361385
return
362386

363387
_, bucket_id = message.auth.split("-")
@@ -400,54 +424,49 @@ def _(self, message: SubscriptionRequest) -> None:
400424
logger.info("received subscription request: %s", str(message).strip())
401425
if not self.is_authenticated:
402426
logger.error("subscription request sent while unauthenticated")
403-
self.__transport.close()
427+
self.__transport.write_eof()
404428

405-
if self.is_streaming:
406-
logger.error("subscription request sent while streaming")
407-
self.__transport.close()
429+
self._subscriptions.append(message)
408430

409-
self._schemas.append(Schema(message.schema))
431+
if self.is_streaming:
432+
self.create_server_task(message)
410433

411434
@handle_client_message.register(SessionStart)
412435
def _(self, message: SessionStart) -> None:
413436
logger.info("received session start request: %s", str(message).strip())
414437
self._is_streaming = True
415438

416-
dataset_path = self._dbn_path / (self._dataset or "")
439+
for sub in self.subscriptions:
440+
self.create_server_task(sub)
441+
442+
def create_server_task(self, message: SubscriptionRequest) -> None:
417443
if self.mode is MockLiveMode.REPLAY:
418-
for schema in self._schemas:
419-
for test_data_path in dataset_path.glob(f"*{schema}.dbn.zst"):
420-
decompressor = zstandard.ZstdDecompressor().stream_reader(
421-
test_data_path.read_bytes(),
422-
)
423-
logger.info(
424-
"streaming %s for %s schema",
425-
test_data_path.name,
426-
schema,
427-
)
428-
self.__transport.write(decompressor.readall())
444+
task = asyncio.create_task(self.replay_task(schema=Schema(message.schema)))
445+
else:
446+
task = asyncio.create_task(self.repeater_task(schema=Schema(message.schema)))
429447

430-
logger.info(
431-
"data streaming for %d schema(s) completed",
432-
len(self._schemas),
433-
)
448+
self._tasks.add(task)
449+
task.add_done_callback(self._tasks.remove)
450+
task.add_done_callback(self.check_done)
434451

452+
def check_done(self, _: Any) -> None:
453+
if not self._tasks:
454+
logger.info("streaming tasks completed")
435455
self.__transport.write_eof()
436-
self.__transport.close()
437456

438-
elif self.mode is MockLiveMode.REPEAT:
439-
metadata = Metadata("UNIT.TEST", 0, SType.RAW_SYMBOL, [], [], [], [])
440-
self.__transport.write(bytes(metadata))
441-
442-
loop = asyncio.get_event_loop()
443-
for schema in self._schemas:
444-
task = loop.create_task(self.repeater(schema))
445-
self._repeater_tasks.add(task)
446-
task.add_done_callback(self._repeater_tasks.remove)
447-
else:
448-
raise ValueError(f"unsupported mode {MockLiveMode.REPEAT}")
457+
async def replay_task(self, schema: Schema) -> None:
458+
for test_data_path in self.dataset_path.glob(f"*{schema}.dbn.zst"):
459+
decompressor = zstandard.ZstdDecompressor().stream_reader(
460+
test_data_path.read_bytes(),
461+
)
462+
logger.info(
463+
"streaming %s for %s schema",
464+
test_data_path.name,
465+
schema,
466+
)
467+
self.__transport.write(decompressor.readall())
449468

450-
async def repeater(self, schema: Schema) -> None:
469+
async def repeater_task(self, schema: Schema) -> None:
451470
struct = SCHEMA_STRUCT_MAP[schema]
452471
repeated = bytes(struct(*[0] * 12)) # for now we only support MBP_1
453472

tests/test_live_client.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,51 @@ async def test_live_subscribe_large_symbol_list(
503503
assert reconstructed == large_symbol_list
504504

505505

506+
async def test_live_subscribe_from_callback(
507+
live_client: client.Live,
508+
mock_live_server: MockLiveServer,
509+
) -> None:
510+
"""
511+
Test that `Live.subscribe` can be called from a callback.
512+
"""
513+
# Arrange
514+
live_client.subscribe(
515+
dataset=Dataset.GLBX_MDP3,
516+
schema=Schema.OHLCV_1H,
517+
stype_in=SType.RAW_SYMBOL,
518+
symbols="TEST0",
519+
)
520+
521+
def cb_sub(_: DBNRecord) -> None:
522+
live_client.subscribe(
523+
dataset=Dataset.GLBX_MDP3,
524+
schema=Schema.MBO,
525+
stype_in=SType.RAW_SYMBOL,
526+
symbols="TEST1",
527+
)
528+
529+
live_client.add_callback(cb_sub)
530+
531+
# Act
532+
first_sub = mock_live_server.get_message_of_type(
533+
gateway.SubscriptionRequest,
534+
timeout=1,
535+
)
536+
537+
live_client.start()
538+
539+
second_sub = mock_live_server.get_message_of_type(
540+
gateway.SubscriptionRequest,
541+
timeout=1,
542+
)
543+
544+
await live_client.wait_for_close()
545+
546+
# Assert
547+
assert first_sub.symbols == "TEST0"
548+
assert second_sub.symbols == "TEST1"
549+
550+
506551
@pytest.mark.usefixtures("mock_live_server")
507552
def test_live_stop(
508553
live_client: client.Live,

0 commit comments

Comments
 (0)