Skip to content

Commit 8bb77b0

Browse files
committed
MOD: Live client to start session on iteration
1 parent 76f6749 commit 8bb77b0

File tree

4 files changed

+98
-86
lines changed

4 files changed

+98
-86
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#### Enhancements
66
- Added `symbology_map` property to `Live` client
77
- Changed `Live.add_callback` and `Live.add_stream` to accept an exception callback
8+
- Changed `Live.__iter__()` and `Live.__aiter__()` to send the session start message if the session is connected but not started
89
- Upgraded `databento-dbn` to 0.7.1
910
- Removed `Encoding`, `Compression`, `Schema`, and `SType` enums as they are now exposed by `databento-dbn`
1011

databento/live/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,15 @@ async def __anext__(self) -> DBNRecord:
128128
def __iter__(self) -> Live:
129129
logger.debug("starting iteration")
130130
self._dbn_queue._enabled.set()
131+
if not self._session.is_started() and self.is_connected():
132+
self.start()
131133
return self
132134

133135
def __next__(self) -> DBNRecord:
134136
if self._dbn_queue is None:
135137
raise ValueError("iteration has not started")
136138

137-
while not self._session.is_disconnected() or self._dbn_queue._qsize() > 0:
139+
while not self._session.is_disconnected() or self._dbn_queue.qsize() > 0:
138140
try:
139141
record = self._dbn_queue.get(block=False)
140142
except queue.Empty:

databento/live/session.py

Lines changed: 88 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def __init__(
258258
port: int = DEFAULT_REMOTE_PORT,
259259
ts_out: bool = False,
260260
) -> None:
261+
self._lock = threading.RLock()
261262
self._loop = loop
262263
self._ts_out = ts_out
263264
self._protocol_factory = protocol_factory
@@ -277,13 +278,14 @@ def is_authenticated(self) -> bool:
277278
bool
278279
279280
"""
280-
if self._protocol is None:
281-
return False
282-
try:
283-
self._protocol.authenticated.result()
284-
except (asyncio.InvalidStateError, asyncio.CancelledError, BentoError):
285-
return False
286-
return True
281+
with self._lock:
282+
if self._protocol is None:
283+
return False
284+
try:
285+
self._protocol.authenticated.result()
286+
except (asyncio.InvalidStateError, asyncio.CancelledError, BentoError):
287+
return False
288+
return True
287289

288290
def is_disconnected(self) -> bool:
289291
"""
@@ -294,9 +296,10 @@ def is_disconnected(self) -> bool:
294296
bool
295297
296298
"""
297-
if self._protocol is None:
298-
return True
299-
return self._protocol.disconnected.done()
299+
with self._lock:
300+
if self._protocol is None:
301+
return True
302+
return self._protocol.disconnected.done()
300303

301304
def is_reading(self) -> bool:
302305
"""
@@ -307,9 +310,10 @@ def is_reading(self) -> bool:
307310
bool
308311
309312
"""
310-
if self._transport is None:
311-
return False
312-
return self._transport.is_reading()
313+
with self._lock:
314+
if self._transport is None:
315+
return False
316+
return self._transport.is_reading()
313317

314318
def is_started(self) -> bool:
315319
"""
@@ -320,9 +324,10 @@ def is_started(self) -> bool:
320324
bool
321325
322326
"""
323-
if self._protocol is None:
324-
return False
325-
return self._protocol.started.is_set()
327+
with self._lock:
328+
if self._protocol is None:
329+
return False
330+
return self._protocol.started.is_set()
326331

327332
@property
328333
def metadata(self) -> databento_dbn.Metadata | None:
@@ -334,9 +339,10 @@ def metadata(self) -> databento_dbn.Metadata | None:
334339
databento_dbn.Metadata
335340
336341
"""
337-
if self._protocol is None:
338-
return None
339-
return self._protocol._metadata.data
342+
with self._lock:
343+
if self._protocol is None:
344+
return None
345+
return self._protocol._metadata.data
340346

341347
def abort(self) -> None:
342348
"""
@@ -347,20 +353,22 @@ def abort(self) -> None:
347353
Session.close
348354
349355
"""
350-
if self._transport is None:
351-
return
352-
self._transport.abort()
353-
self._protocol = None
356+
with self._lock:
357+
if self._transport is None:
358+
return
359+
self._transport.abort()
360+
self._protocol = None
354361

355362
def close(self) -> None:
356363
"""
357364
Close the current connection.
358365
"""
359-
if self._transport is None:
360-
return
361-
if self._transport.can_write_eof():
362-
self._loop.call_soon_threadsafe(self._transport.write_eof)
363-
self._loop.call_soon_threadsafe(self._transport.close)
366+
with self._lock:
367+
if self._transport is None:
368+
return
369+
if self._transport.can_write_eof():
370+
self._loop.call_soon_threadsafe(self._transport.write_eof)
371+
self._loop.call_soon_threadsafe(self._transport.close)
364372

365373
def subscribe(
366374
self,
@@ -389,27 +397,29 @@ def subscribe(
389397
within 24 hours.
390398
391399
"""
392-
if self._protocol is None:
393-
self._connect(
394-
dataset=dataset,
395-
port=self._port,
396-
loop=self._loop,
397-
)
400+
with self._lock:
401+
if self._protocol is None:
402+
self._connect(
403+
dataset=dataset,
404+
port=self._port,
405+
loop=self._loop,
406+
)
398407

399-
self._protocol.subscribe(
400-
schema=schema,
401-
symbols=symbols,
402-
stype_in=stype_in,
403-
start=start,
404-
)
408+
self._protocol.subscribe(
409+
schema=schema,
410+
symbols=symbols,
411+
stype_in=stype_in,
412+
start=start,
413+
)
405414

406415
def resume_reading(self) -> None:
407416
"""
408417
Resume reading from the connection.
409418
"""
410-
if self._transport is None:
411-
return
412-
self._loop.call_soon_threadsafe(self._transport.resume_reading)
419+
with self._lock:
420+
if self._transport is None:
421+
return
422+
self._loop.call_soon_threadsafe(self._transport.resume_reading)
413423

414424
def start(self) -> None:
415425
"""
@@ -421,9 +431,10 @@ def start(self) -> None:
421431
If there is no connection.
422432
423433
"""
424-
if self._protocol is None:
425-
raise ValueError("session is not connected")
426-
self._protocol.start()
434+
with self._lock:
435+
if self._protocol is None:
436+
raise ValueError("session is not connected")
437+
self._protocol.start()
427438

428439
async def wait_for_close(self) -> None:
429440
"""
@@ -433,44 +444,52 @@ async def wait_for_close(self) -> None:
433444
if self._protocol is None:
434445
return
435446

447+
await self._protocol.authenticated
436448
await self._protocol.disconnected
437-
disconnect_exc = self._protocol.disconnected.exception()
438-
439449
await self._protocol.wait_for_processing()
440-
self._protocol = self._transport = None
441450

442-
if disconnect_exc is not None:
443-
raise BentoError(disconnect_exc)
451+
try:
452+
self._protocol.authenticated.result()
453+
except Exception as exc:
454+
raise BentoError(exc)
455+
456+
try:
457+
self._protocol.disconnected.result()
458+
except Exception as exc:
459+
raise BentoError(exc)
460+
461+
self._protocol = self._transport = None
444462

445463
def _connect(
446464
self,
447465
dataset: Dataset | str,
448466
port: int,
449467
loop: asyncio.AbstractEventLoop,
450468
) -> None:
451-
if self._user_gateway is None:
452-
subdomain = dataset.lower().replace(".", "-")
453-
gateway = f"{subdomain}.lsg.databento.com"
454-
logger.debug("using default gateway for dataset %s", dataset)
455-
else:
456-
gateway = self._user_gateway
457-
logger.debug("using user specified gateway: %s", gateway)
469+
with self._lock:
470+
if not self.is_disconnected():
471+
return
472+
if self._user_gateway is None:
473+
subdomain = dataset.lower().replace(".", "-")
474+
gateway = f"{subdomain}.lsg.databento.com"
475+
logger.debug("using default gateway for dataset %s", dataset)
476+
else:
477+
gateway = self._user_gateway
478+
logger.debug("using user specified gateway: %s", gateway)
458479

459-
asyncio.run_coroutine_threadsafe(
460-
coro=self._connect_task(
461-
gateway=gateway,
462-
port=port,
463-
),
464-
loop=loop,
465-
).result()
480+
self._transport, self._protocol = asyncio.run_coroutine_threadsafe(
481+
coro=self._connect_task(
482+
gateway=gateway,
483+
port=port,
484+
),
485+
loop=loop,
486+
).result()
466487

467488
async def _connect_task(
468489
self,
469490
gateway: str,
470491
port: int,
471-
) -> None:
472-
if not self.is_disconnected():
473-
return
492+
) -> tuple[asyncio.Transport, _SessionProtocol]:
474493
logger.info("connecting to remote gateway")
475494
try:
476495
transport, protocol = await asyncio.wait_for(
@@ -514,5 +533,4 @@ async def _connect_task(
514533
"authentication with remote gateway completed",
515534
)
516535

517-
self._transport = transport
518-
self._protocol = protocol
536+
return transport, protocol

tests/test_live_client.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,6 @@ async def test_live_async_iteration(
547547
symbols="TEST",
548548
)
549549

550-
live_client.start()
551-
552550
records: list[DBNRecord] = []
553551
async for record in live_client:
554552
records.append(record)
@@ -591,13 +589,12 @@ async def test_live_async_iteration_backpressure(
591589
pause_mock := MagicMock(),
592590
)
593591

594-
live_client.start()
595-
it = live_client.__iter__()
592+
live_it = iter(live_client)
596593
await live_client.wait_for_close()
597594

598-
assert pause_mock.called
595+
pause_mock.assert_called()
599596

600-
records = list(it)
597+
records: list[DBNRecord] = list(live_it)
601598
assert len(records) == 4
602599
assert live_client._dbn_queue.empty()
603600

@@ -632,13 +629,12 @@ async def test_live_async_iteration_dropped(
632629
pause_mock := MagicMock(),
633630
)
634631

635-
live_client.start()
636-
it = live_client.__iter__()
632+
live_it = iter(live_client)
637633
await live_client.wait_for_close()
638634

639-
assert pause_mock.called
635+
pause_mock.assert_called()
640636

641-
records = list(it)
637+
records = list(live_it)
642638
assert len(records) == 1
643639
assert live_client._dbn_queue.empty()
644640

@@ -658,8 +654,6 @@ async def test_live_async_iteration_stop(
658654
symbols="TEST",
659655
)
660656

661-
live_client.start()
662-
663657
records = []
664658
async for record in live_client:
665659
records.append(record)
@@ -683,8 +677,6 @@ def test_live_sync_iteration(
683677
symbols="TEST",
684678
)
685679

686-
live_client.start()
687-
688680
records = []
689681
for record in live_client:
690682
records.append(record)
@@ -918,7 +910,6 @@ async def test_live_iteration_with_reconnect(
918910
assert live_client.is_connected()
919911
assert live_client.dataset == Dataset.GLBX_MDP3
920912

921-
live_client.start()
922913
my_iter = iter(live_client)
923914

924915
await live_client.wait_for_close()

0 commit comments

Comments
 (0)