@@ -877,43 +877,7 @@ def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, sp
877877    return  bedrock_attrs 
878878
879879
880- class  EventStreamWrapper (ObjectProxy ):
881-     def  __iter__ (self ):
882-         g  =  GeneratorProxy (self .__wrapped__ .__iter__ ())
883-         g ._nr_ft  =  getattr (self , "_nr_ft" , None )
884-         g ._nr_bedrock_attrs  =  getattr (self , "_nr_bedrock_attrs" , {})
885-         g ._nr_model_extractor  =  getattr (self , "_nr_model_extractor" , NULL_EXTRACTOR )
886-         g ._nr_is_converse  =  getattr (self , "_nr_is_converse" , False )
887-         return  g 
888- 
889- 
890- class  GeneratorProxy (ObjectProxy ):
891-     def  __init__ (self , wrapped ):
892-         super ().__init__ (wrapped )
893- 
894-     def  __iter__ (self ):
895-         return  self 
896- 
897-     def  __next__ (self ):
898-         transaction  =  current_transaction ()
899-         if  not  transaction :
900-             return  self .__wrapped__ .__next__ ()
901- 
902-         return_val  =  None 
903-         try :
904-             return_val  =  self .__wrapped__ .__next__ ()
905-             self .record_stream_chunk (return_val , transaction )
906-         except  StopIteration :
907-             self .record_events_on_stop_iteration (transaction )
908-             raise 
909-         except  Exception  as  exc :
910-             self .record_error (transaction , exc )
911-             raise 
912-         return  return_val 
913- 
914-     def  close (self ):
915-         return  super ().close ()
916- 
880+ class  BedrockRecordEventMixin :
917881    def  record_events_on_stop_iteration (self , transaction ):
918882        if  hasattr (self , "_nr_ft" ):
919883            bedrock_attrs  =  getattr (self , "_nr_bedrock_attrs" , {})
@@ -1002,6 +966,44 @@ def converse_record_stream_chunk(self, event, transaction):
1002966        #     self.record_events_on_stop_iteration(transaction) 
1003967
1004968
969+ class  EventStreamWrapper (ObjectProxy ):
970+     def  __iter__ (self ):
971+         g  =  GeneratorProxy (self .__wrapped__ .__iter__ ())
972+         g ._nr_ft  =  getattr (self , "_nr_ft" , None )
973+         g ._nr_bedrock_attrs  =  getattr (self , "_nr_bedrock_attrs" , {})
974+         g ._nr_model_extractor  =  getattr (self , "_nr_model_extractor" , NULL_EXTRACTOR )
975+         g ._nr_is_converse  =  getattr (self , "_nr_is_converse" , False )
976+         return  g 
977+ 
978+ 
979+ class  GeneratorProxy (BedrockRecordEventMixin , ObjectProxy ):
980+     def  __init__ (self , wrapped ):
981+         super ().__init__ (wrapped )
982+ 
983+     def  __iter__ (self ):
984+         return  self 
985+ 
986+     def  __next__ (self ):
987+         transaction  =  current_transaction ()
988+         if  not  transaction :
989+             return  self .__wrapped__ .__next__ ()
990+ 
991+         return_val  =  None 
992+         try :
993+             return_val  =  self .__wrapped__ .__next__ ()
994+             self .record_stream_chunk (return_val , transaction )
995+         except  StopIteration :
996+             self .record_events_on_stop_iteration (transaction )
997+             raise 
998+         except  Exception  as  exc :
999+             self .record_error (transaction , exc )
1000+             raise 
1001+         return  return_val 
1002+ 
1003+     def  close (self ):
1004+         return  super ().close ()
1005+ 
1006+ 
10051007class  AsyncEventStreamWrapper (ObjectProxy ):
10061008    def  __aiter__ (self ):
10071009        g  =  AsyncGeneratorProxy (self .__wrapped__ .__aiter__ ())
@@ -1012,13 +1014,7 @@ def __aiter__(self):
10121014        return  g 
10131015
10141016
1015- class  AsyncGeneratorProxy (ObjectProxy ):
1016-     # Import these methods from the synchronous GeneratorProxy 
1017-     # Avoid direct inheritance so we don't implement both __iter__ and __aiter__ 
1018-     record_stream_chunk  =  GeneratorProxy .record_stream_chunk 
1019-     record_events_on_stop_iteration  =  GeneratorProxy .record_events_on_stop_iteration 
1020-     record_error  =  GeneratorProxy .record_error 
1021- 
1017+ class  AsyncGeneratorProxy (BedrockRecordEventMixin , ObjectProxy ):
10221018    def  __aiter__ (self ):
10231019        return  self 
10241020
0 commit comments