@@ -115,78 +115,17 @@ def factory() -> _SessionProtocol:
115115 if not Live ._thread .is_alive ():
116116 Live ._thread .start ()
117117
118- def __aiter__ (self ) -> Live :
118+ def __aiter__ (self ) -> LiveIterator :
119119 return iter (self )
120120
121- async def __anext__ (self ) -> DBNRecord :
122- if not self ._dbn_queue .is_enabled ():
123- raise ValueError ("iteration has not started" )
124-
125- loop = asyncio .get_running_loop ()
126-
127- try :
128- return self ._dbn_queue .get_nowait ()
129- except queue .Empty :
130- while True :
131- try :
132- return await loop .run_in_executor (
133- None ,
134- self ._dbn_queue .get ,
135- True ,
136- 0.1 ,
137- )
138- except queue .Empty :
139- if self ._session .is_disconnected ():
140- break
141- finally :
142- if not self ._dbn_queue .is_full () and not self ._session .is_reading ():
143- logger .debug (
144- "resuming reading with %d pending records" ,
145- self ._dbn_queue .qsize (),
146- )
147- self ._session .resume_reading ()
148-
149- self ._dbn_queue .disable ()
150- await self .wait_for_close ()
151- logger .debug ("completed async iteration" )
152- raise StopAsyncIteration
153-
154- def __iter__ (self ) -> Live :
121+ def __iter__ (self ) -> LiveIterator :
155122 logger .debug ("starting iteration" )
156123 if self ._session .is_started ():
157124 logger .error ("iteration started after session has started" )
158125 raise ValueError (
159126 "Cannot start iteration after streaming has started, records may be missed. Don't call `Live.start` before iterating." ,
160127 )
161- elif self .is_connected ():
162- self .start ()
163- self ._dbn_queue ._enabled .set ()
164- return self
165-
166- def __next__ (self ) -> DBNRecord :
167- if not self ._dbn_queue .is_enabled ():
168- raise ValueError ("iteration has not started" )
169-
170- while True :
171- try :
172- record = self ._dbn_queue .get (timeout = 0.1 )
173- except queue .Empty :
174- if self ._session .is_disconnected ():
175- break
176- else :
177- return record
178- finally :
179- if not self ._dbn_queue .is_full () and not self ._session .is_reading ():
180- logger .debug (
181- "resuming reading with %d pending records" ,
182- self ._dbn_queue .qsize (),
183- )
184- self ._session .resume_reading ()
185-
186- self ._dbn_queue .disable ()
187- self .block_for_close ()
188- logger .debug ("completed iteration" )
189- raise StopIteration
128+ return LiveIterator (self )
190129
191130 def __repr__ (self ) -> str :
192131 name = self .__class__ .__name__
@@ -661,3 +600,93 @@ def _map_symbol(self, record: DBNRecord) -> None:
661600 instrument_id = record .instrument_id
662601 self ._symbology_map [instrument_id ] = record .stype_out_symbol
663602 logger .info ("added symbology mapping %s to %d" , out_symbol , instrument_id )
603+
604+
605+ class LiveIterator :
606+ """
607+ Iterator class for the `Live` client. Automatically starts the client when
608+ created and will stop it when destroyed. This provides context-manager-like
609+ behavior to for loops.
610+
611+ Parameters
612+ ----------
613+ client : Live
614+ The Live client that spawned this LiveIterator.
615+
616+ """
617+
618+ def __init__ (self , client : Live ):
619+ client ._dbn_queue ._enabled .set ()
620+ client .start ()
621+ self ._client = client
622+
623+ @property
624+ def client (self ) -> Live :
625+ return self ._client
626+
627+ def __iter__ (self ) -> LiveIterator :
628+ return self
629+
630+ def __del__ (self ) -> None :
631+ if self .client .is_connected ():
632+ self .client .stop ()
633+ self .client .block_for_close ()
634+ logger .debug ("iteration aborted" )
635+
636+ async def __anext__ (self ) -> DBNRecord :
637+ if not self .client ._dbn_queue .is_enabled ():
638+ raise ValueError ("iteration has not started" )
639+
640+ loop = asyncio .get_running_loop ()
641+
642+ try :
643+ return self .client ._dbn_queue .get_nowait ()
644+ except queue .Empty :
645+ while True :
646+ try :
647+ return await loop .run_in_executor (
648+ None ,
649+ self .client ._dbn_queue .get ,
650+ True ,
651+ 0.1 ,
652+ )
653+ except queue .Empty :
654+ if self .client ._session .is_disconnected ():
655+ break
656+ finally :
657+ if not self .client ._dbn_queue .is_full () and not self .client ._session .is_reading ():
658+ logger .debug (
659+ "resuming reading with %d pending records" ,
660+ self .client ._dbn_queue .qsize (),
661+ )
662+ self .client ._session .resume_reading ()
663+
664+ self .client ._dbn_queue .disable ()
665+ await self .client .wait_for_close ()
666+ logger .debug ("async iteration completed" )
667+ raise StopAsyncIteration
668+
669+ def __next__ (self ) -> DBNRecord :
670+ if not self .client ._dbn_queue .is_enabled ():
671+ raise ValueError ("iteration has not started" )
672+
673+ while True :
674+ try :
675+ record = self .client ._dbn_queue .get (timeout = 0.1 )
676+ except queue .Empty :
677+ if self .client ._session .is_disconnected ():
678+ break
679+ else :
680+ return record
681+ finally :
682+ if not self .client ._dbn_queue .is_full () and not self .client ._session .is_reading ():
683+ logger .debug (
684+ "resuming reading with %d pending records" ,
685+ self .client ._dbn_queue .qsize (),
686+ )
687+ self .client ._session .resume_reading ()
688+
689+ self .client ._dbn_queue .disable ()
690+ self .client .block_for_close ()
691+ logger .debug ("iteration completed" )
692+ raise StopIteration
0 commit comments