@@ -877,43 +877,7 @@ def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, sp
877877 return bedrock_attrs
878878
879879
880- class EventStreamWrapper (ObjectProxy ):
881- def __iter__ (self ):
882- g = GeneratorProxy (self .__wrapped__ .__iter__ ())
883- g ._nr_ft = getattr (self , "_nr_ft" , None )
884- g ._nr_bedrock_attrs = getattr (self , "_nr_bedrock_attrs" , {})
885- g ._nr_model_extractor = getattr (self , "_nr_model_extractor" , NULL_EXTRACTOR )
886- g ._nr_is_converse = getattr (self , "_nr_is_converse" , False )
887- return g
888-
889-
890- class GeneratorProxy (ObjectProxy ):
891- def __init__ (self , wrapped ):
892- super ().__init__ (wrapped )
893-
894- def __iter__ (self ):
895- return self
896-
897- def __next__ (self ):
898- transaction = current_transaction ()
899- if not transaction :
900- return self .__wrapped__ .__next__ ()
901-
902- return_val = None
903- try :
904- return_val = self .__wrapped__ .__next__ ()
905- self .record_stream_chunk (return_val , transaction )
906- except StopIteration :
907- self .record_events_on_stop_iteration (transaction )
908- raise
909- except Exception as exc :
910- self .record_error (transaction , exc )
911- raise
912- return return_val
913-
914- def close (self ):
915- return super ().close ()
916-
880+ class BedrockRecordEventMixin :
917881 def record_events_on_stop_iteration (self , transaction ):
918882 if hasattr (self , "_nr_ft" ):
919883 bedrock_attrs = getattr (self , "_nr_bedrock_attrs" , {})
@@ -1002,6 +966,44 @@ def converse_record_stream_chunk(self, event, transaction):
1002966 # self.record_events_on_stop_iteration(transaction)
1003967
1004968
969+ class EventStreamWrapper (ObjectProxy ):
970+ def __iter__ (self ):
971+ g = GeneratorProxy (self .__wrapped__ .__iter__ ())
972+ g ._nr_ft = getattr (self , "_nr_ft" , None )
973+ g ._nr_bedrock_attrs = getattr (self , "_nr_bedrock_attrs" , {})
974+ g ._nr_model_extractor = getattr (self , "_nr_model_extractor" , NULL_EXTRACTOR )
975+ g ._nr_is_converse = getattr (self , "_nr_is_converse" , False )
976+ return g
977+
978+
979+ class GeneratorProxy (BedrockRecordEventMixin , ObjectProxy ):
980+ def __init__ (self , wrapped ):
981+ super ().__init__ (wrapped )
982+
983+ def __iter__ (self ):
984+ return self
985+
986+ def __next__ (self ):
987+ transaction = current_transaction ()
988+ if not transaction :
989+ return self .__wrapped__ .__next__ ()
990+
991+ return_val = None
992+ try :
993+ return_val = self .__wrapped__ .__next__ ()
994+ self .record_stream_chunk (return_val , transaction )
995+ except StopIteration :
996+ self .record_events_on_stop_iteration (transaction )
997+ raise
998+ except Exception as exc :
999+ self .record_error (transaction , exc )
1000+ raise
1001+ return return_val
1002+
1003+ def close (self ):
1004+ return super ().close ()
1005+
1006+
10051007class AsyncEventStreamWrapper (ObjectProxy ):
10061008 def __aiter__ (self ):
10071009 g = AsyncGeneratorProxy (self .__wrapped__ .__aiter__ ())
@@ -1012,13 +1014,7 @@ def __aiter__(self):
10121014 return g
10131015
10141016
1015- class AsyncGeneratorProxy (ObjectProxy ):
1016- # Import these methods from the synchronous GeneratorProxy
1017- # Avoid direct inheritance so we don't implement both __iter__ and __aiter__
1018- record_stream_chunk = GeneratorProxy .record_stream_chunk
1019- record_events_on_stop_iteration = GeneratorProxy .record_events_on_stop_iteration
1020- record_error = GeneratorProxy .record_error
1021-
1017+ class AsyncGeneratorProxy (BedrockRecordEventMixin , ObjectProxy ):
10221018 def __aiter__ (self ):
10231019 return self
10241020
0 commit comments