@@ -258,6 +258,7 @@ def __init__(
258258 port : int = DEFAULT_REMOTE_PORT ,
259259 ts_out : bool = False ,
260260 ) -> None :
261+ self ._lock = threading .RLock ()
261262 self ._loop = loop
262263 self ._ts_out = ts_out
263264 self ._protocol_factory = protocol_factory
@@ -277,13 +278,14 @@ def is_authenticated(self) -> bool:
277278 bool
278279
279280 """
280- if self ._protocol is None :
281- return False
282- try :
283- self ._protocol .authenticated .result ()
284- except (asyncio .InvalidStateError , asyncio .CancelledError , BentoError ):
285- return False
286- return True
281+ with self ._lock :
282+ if self ._protocol is None :
283+ return False
284+ try :
285+ self ._protocol .authenticated .result ()
286+ except (asyncio .InvalidStateError , asyncio .CancelledError , BentoError ):
287+ return False
288+ return True
287289
288290 def is_disconnected (self ) -> bool :
289291 """
@@ -294,9 +296,10 @@ def is_disconnected(self) -> bool:
294296 bool
295297
296298 """
297- if self ._protocol is None :
298- return True
299- return self ._protocol .disconnected .done ()
299+ with self ._lock :
300+ if self ._protocol is None :
301+ return True
302+ return self ._protocol .disconnected .done ()
300303
301304 def is_reading (self ) -> bool :
302305 """
@@ -307,9 +310,10 @@ def is_reading(self) -> bool:
307310 bool
308311
309312 """
310- if self ._transport is None :
311- return False
312- return self ._transport .is_reading ()
313+ with self ._lock :
314+ if self ._transport is None :
315+ return False
316+ return self ._transport .is_reading ()
313317
314318 def is_started (self ) -> bool :
315319 """
@@ -320,9 +324,10 @@ def is_started(self) -> bool:
320324 bool
321325
322326 """
323- if self ._protocol is None :
324- return False
325- return self ._protocol .started .is_set ()
327+ with self ._lock :
328+ if self ._protocol is None :
329+ return False
330+ return self ._protocol .started .is_set ()
326331
327332 @property
328333 def metadata (self ) -> databento_dbn .Metadata | None :
@@ -334,9 +339,10 @@ def metadata(self) -> databento_dbn.Metadata | None:
334339 databento_dbn.Metadata
335340
336341 """
337- if self ._protocol is None :
338- return None
339- return self ._protocol ._metadata .data
342+ with self ._lock :
343+ if self ._protocol is None :
344+ return None
345+ return self ._protocol ._metadata .data
340346
341347 def abort (self ) -> None :
342348 """
@@ -347,20 +353,22 @@ def abort(self) -> None:
347353 Session.close
348354
349355 """
350- if self ._transport is None :
351- return
352- self ._transport .abort ()
353- self ._protocol = None
356+ with self ._lock :
357+ if self ._transport is None :
358+ return
359+ self ._transport .abort ()
360+ self ._protocol = None
354361
355362 def close (self ) -> None :
356363 """
357364 Close the current connection.
358365 """
359- if self ._transport is None :
360- return
361- if self ._transport .can_write_eof ():
362- self ._loop .call_soon_threadsafe (self ._transport .write_eof )
363- self ._loop .call_soon_threadsafe (self ._transport .close )
366+ with self ._lock :
367+ if self ._transport is None :
368+ return
369+ if self ._transport .can_write_eof ():
370+ self ._loop .call_soon_threadsafe (self ._transport .write_eof )
371+ self ._loop .call_soon_threadsafe (self ._transport .close )
364372
365373 def subscribe (
366374 self ,
@@ -389,27 +397,29 @@ def subscribe(
389397 within 24 hours.
390398
391399 """
392- if self ._protocol is None :
393- self ._connect (
394- dataset = dataset ,
395- port = self ._port ,
396- loop = self ._loop ,
397- )
400+ with self ._lock :
401+ if self ._protocol is None :
402+ self ._connect (
403+ dataset = dataset ,
404+ port = self ._port ,
405+ loop = self ._loop ,
406+ )
398407
399- self ._protocol .subscribe (
400- schema = schema ,
401- symbols = symbols ,
402- stype_in = stype_in ,
403- start = start ,
404- )
408+ self ._protocol .subscribe (
409+ schema = schema ,
410+ symbols = symbols ,
411+ stype_in = stype_in ,
412+ start = start ,
413+ )
405414
406415 def resume_reading (self ) -> None :
407416 """
408417 Resume reading from the connection.
409418 """
410- if self ._transport is None :
411- return
412- self ._loop .call_soon_threadsafe (self ._transport .resume_reading )
419+ with self ._lock :
420+ if self ._transport is None :
421+ return
422+ self ._loop .call_soon_threadsafe (self ._transport .resume_reading )
413423
414424 def start (self ) -> None :
415425 """
@@ -421,9 +431,10 @@ def start(self) -> None:
421431 If there is no connection.
422432
423433 """
424- if self ._protocol is None :
425- raise ValueError ("session is not connected" )
426- self ._protocol .start ()
434+ with self ._lock :
435+ if self ._protocol is None :
436+ raise ValueError ("session is not connected" )
437+ self ._protocol .start ()
427438
428439 async def wait_for_close (self ) -> None :
429440 """
@@ -433,44 +444,52 @@ async def wait_for_close(self) -> None:
433444 if self ._protocol is None :
434445 return
435446
447+ await self ._protocol .authenticated
436448 await self ._protocol .disconnected
437- disconnect_exc = self ._protocol .disconnected .exception ()
438-
439449 await self ._protocol .wait_for_processing ()
440- self ._protocol = self ._transport = None
441450
442- if disconnect_exc is not None :
443- raise BentoError (disconnect_exc )
451+ try :
452+ self ._protocol .authenticated .result ()
453+ except Exception as exc :
454+ raise BentoError (exc )
455+
456+ try :
457+ self ._protocol .disconnected .result ()
458+ except Exception as exc :
459+ raise BentoError (exc )
460+
461+ self ._protocol = self ._transport = None
444462
445463 def _connect (
446464 self ,
447465 dataset : Dataset | str ,
448466 port : int ,
449467 loop : asyncio .AbstractEventLoop ,
450468 ) -> None :
451- if self ._user_gateway is None :
452- subdomain = dataset .lower ().replace ("." , "-" )
453- gateway = f"{ subdomain } .lsg.databento.com"
454- logger .debug ("using default gateway for dataset %s" , dataset )
455- else :
456- gateway = self ._user_gateway
457- logger .debug ("using user specified gateway: %s" , gateway )
469+ with self ._lock :
470+ if not self .is_disconnected ():
471+ return
472+ if self ._user_gateway is None :
473+ subdomain = dataset .lower ().replace ("." , "-" )
474+ gateway = f"{ subdomain } .lsg.databento.com"
475+ logger .debug ("using default gateway for dataset %s" , dataset )
476+ else :
477+ gateway = self ._user_gateway
478+ logger .debug ("using user specified gateway: %s" , gateway )
458479
459- asyncio .run_coroutine_threadsafe (
460- coro = self ._connect_task (
461- gateway = gateway ,
462- port = port ,
463- ),
464- loop = loop ,
465- ).result ()
480+ self . _transport , self . _protocol = asyncio .run_coroutine_threadsafe (
481+ coro = self ._connect_task (
482+ gateway = gateway ,
483+ port = port ,
484+ ),
485+ loop = loop ,
486+ ).result ()
466487
467488 async def _connect_task (
468489 self ,
469490 gateway : str ,
470491 port : int ,
471- ) -> None :
472- if not self .is_disconnected ():
473- return
492+ ) -> tuple [asyncio .Transport , _SessionProtocol ]:
474493 logger .info ("connecting to remote gateway" )
475494 try :
476495 transport , protocol = await asyncio .wait_for (
@@ -514,5 +533,4 @@ async def _connect_task(
514533 "authentication with remote gateway completed" ,
515534 )
516535
517- self ._transport = transport
518- self ._protocol = protocol
536+ return transport , protocol
0 commit comments