1313# limitations under the License.
1414
1515import collections
16+ import logging
1617import threading
1718
1819try :
1920 from newrelic .core .infinite_tracing_pb2 import AttributeValue
2021except :
2122 AttributeValue = None
2223
24+ _logger = logging .getLogger (__name__ )
2325
24- class StreamBuffer (object ):
2526
27+ class StreamBuffer (object ):
2628 def __init__ (self , maxlen ):
2729 self ._queue = collections .deque (maxlen = maxlen )
2830 self ._notify = self .condition ()
@@ -64,18 +66,46 @@ def stats(self):
6466
6567 return seen , dropped
6668
67- def __next__ (self ):
68- while True :
69- if self ._shutdown :
70- raise StopIteration
69+ def __iter__ (self ):
70+ return StreamBufferIterator (self )
71+
7172
72- try :
73- return self ._queue .popleft ()
74- except IndexError :
75- pass
73+ class StreamBufferIterator (object ):
74+ def __init__ (self , stream_buffer ):
75+ self .stream_buffer = stream_buffer
76+ self ._notify = self .stream_buffer ._notify
77+ self ._shutdown = False
78+ self ._stream = None
7679
77- with self ._notify :
78- if not self ._shutdown and not self ._queue :
80+ def shutdown (self ):
81+ with self ._notify :
82+ self ._shutdown = True
83+ self ._notify .notify_all ()
84+
85+ def stream_closed (self ):
86+ return self ._shutdown or self .stream_buffer ._shutdown or (self ._stream and self ._stream .done ())
87+
88+ def __next__ (self ):
89+ with self ._notify :
90+ while True :
91+ # When a gRPC stream receives a server side disconnect (usually in the form of an OK code)
92+ # the item it is waiting to consume from the iterator will not be sent, and will inevitably
93+ # be lost. To prevent this, StopIteration is raised by shutting down the iterator and
94+ # notifying to allow the thread to exit. Iterators cannot be reused or race conditions may
95+ # occur between iterator shutdown and restart, so a new iterator must be created from the
96+ # streaming buffer.
97+ if self .stream_closed ():
98+ _logger .debug ("gRPC stream is closed. Shutting down and refusing to iterate." )
99+ if not self ._shutdown :
100+ self .shutdown ()
101+ raise StopIteration
102+
103+ try :
104+ return self .stream_buffer ._queue .popleft ()
105+ except IndexError :
106+ pass
107+
108+ if not self .stream_closed () and not self .stream_buffer ._queue :
79109 self ._notify .wait ()
80110
81111 next = __next__
@@ -90,10 +120,8 @@ def __init__(self, *args, **kwargs):
90120 if args :
91121 arg = args [0 ]
92122 if len (args ) > 1 :
93- raise TypeError (
94- "SpanProtoAttrs expected at most 1 argument, got %d" ,
95- len (args ))
96- elif hasattr (arg , 'keys' ):
123+ raise TypeError ("SpanProtoAttrs expected at most 1 argument, got %d" , len (args ))
124+ elif hasattr (arg , "keys" ):
97125 for k in arg :
98126 self [k ] = arg [k ]
99127 else :
@@ -104,8 +132,7 @@ def __init__(self, *args, **kwargs):
104132 self [k ] = kwargs [k ]
105133
106134 def __setitem__ (self , key , value ):
107- super (SpanProtoAttrs , self ).__setitem__ (key ,
108- SpanProtoAttrs .get_attribute_value (value ))
135+ super (SpanProtoAttrs , self ).__setitem__ (key , SpanProtoAttrs .get_attribute_value (value ))
109136
110137 def copy (self ):
111138 copy = SpanProtoAttrs ()
0 commit comments