99from datetime import datetime
1010from datetime import timezone
1111from decimal import Decimal
12+ from functools import partial
1213from typing import Any
1314from typing import AsyncGenerator
1415from typing import AsyncIterator
1920from typing import Dict
2021from typing import Generator
2122from typing import List
23+ from typing import NamedTuple
2224from typing import NoReturn
2325from typing import Optional
2426from typing import Set
2527from typing import Tuple
28+ from typing import Union
2629from typing import cast
2730
2831from pysignalr .client import SignalRClient
5154from dipdup .models import HeadBlockData
5255from dipdup .models import OperationData
5356from dipdup .models import QuoteData
57+ from dipdup .utils import FormattedLogger
5458from dipdup .utils import split_by_chunks
5559from dipdup .utils .watchdog import Watchdog
5660
@@ -258,6 +262,55 @@ async def fetch_big_maps_by_level(self) -> AsyncGenerator[Tuple[int, Tuple[BigMa
258262 yield big_maps [0 ].level , big_maps
259263
260264
265+ MessageData = Union [Dict [str , Any ], List [Dict [str , Any ]]]
266+
267+
268+ class BufferedMessage (NamedTuple ):
269+ type : MessageType
270+ data : MessageData
271+
272+
273+ class MessageBuffer :
274+ """Buffers realtime TzKT messages and yields them in by level."""
275+
276+ def __init__ (self , size : int ) -> None :
277+ self ._logger = logging .getLogger ('dipdup.tzkt' )
278+ self ._size = size
279+ self ._messages : DefaultDict [int , List [BufferedMessage ]] = defaultdict (list )
280+
281+ def add (self , type_ : MessageType , level : int , data : MessageData ) -> None :
282+ """Add a message to the buffer."""
283+ self ._messages [level ].append (BufferedMessage (type_ , data ))
284+
285+ def rollback (self , type_ : MessageType , channel_level : int , message_level : int ) -> bool :
286+ """Drop buffered messages in reversed order while possible, return if successful."""
287+ # NOTE: No action required for this channel
288+ if type_ == MessageType .head :
289+ return True
290+
291+ # NOTE: This rollback does not affect us, so we can safely ignore it
292+ if channel_level <= message_level :
293+ return True
294+
295+ self ._logger .info ('Rollback requested from %s to %s' , type_ .value , channel_level , message_level )
296+ levels = range (channel_level , message_level , - 1 )
297+ for level in levels :
298+ if not self ._messages .pop (level , None ):
299+ self ._logger .info ('Level %s is not buffered, can\' t avoid rollback' , level )
300+ return False
301+
302+ self ._logger .info ('All rolled back levels are buffered, no action required' )
303+ return True
304+
305+ def yield_from (self ) -> Generator [BufferedMessage , None , None ]:
306+ """Yield extensively buffered messages by level"""
307+ buffered_levels = sorted (self ._messages .keys ())
308+ yielded_levels = buffered_levels [: len (buffered_levels ) - self ._size ]
309+ for level in yielded_levels :
310+ for buffered_message in self ._messages .pop (level ):
311+ yield buffered_message
312+
313+
261314class TzktDatasource (IndexDatasource ):
262315 _default_http_config = HTTPConfig (
263316 cache = True ,
@@ -284,8 +337,7 @@ def __init__(
284337 )
285338 self ._logger = logging .getLogger ('dipdup.tzkt' )
286339 self ._watchdog = watchdog
287- self ._buffer_size = buffer_size
288- self ._buffer : DefaultDict [int , List [Tuple [MessageType , Dict [str , Any ]]]] = defaultdict (list )
340+ self ._buffer = MessageBuffer (buffer_size )
289341
290342 self ._ws_client : Optional [SignalRClient ] = None
291343 self ._level : DefaultDict [MessageType , Optional [int ]] = defaultdict (lambda : None )
@@ -294,6 +346,10 @@ def __init__(
294346 def request_limit (self ) -> int :
295347 return cast (int , self ._http_config .batch_size )
296348
349+ def set_logger (self , name : str ) -> None :
350+ super ().set_logger (name )
351+ self ._buffer ._logger = FormattedLogger (self ._buffer ._logger .name , name + ': {}' )
352+
297353 def get_channel_level (self , message_type : MessageType ) -> int :
298354 """Get current level of the channel, or sync level is no messages were received yet."""
299355 channel_level = self ._level [message_type ]
@@ -763,9 +819,9 @@ def _get_ws_client(self) -> SignalRClient:
763819 self ._ws_client .on_close (self ._on_disconnect )
764820 self ._ws_client .on_error (self ._on_error )
765821
766- self ._ws_client .on ('operations' , self ._on_operations_message )
767- self ._ws_client .on ('bigmaps' , self ._on_big_maps_message )
768- self ._ws_client .on ('head' , self ._on_head_message )
822+ self ._ws_client .on ('operations' , partial ( self ._on_message , MessageType . operation ) )
823+ self ._ws_client .on ('bigmaps' , partial ( self ._on_message , MessageType . big_map ) )
824+ self ._ws_client .on ('head' , partial ( self ._on_message , MessageType . head ) )
769825
770826 return self ._ws_client
771827
@@ -802,7 +858,7 @@ async def _on_error(self, message: CompletionMessage) -> NoReturn:
802858 """Raise exception from WS server's error message"""
803859 raise DatasourceError (datasource = self .name , msg = cast (str , message .error ))
804860
805- async def _extract_message_data (self , type_ : MessageType , message : List [Any ]) -> AsyncGenerator [ Dict , None ] :
861+ async def _on_message (self , type_ : MessageType , message : List [Dict [ str , Any ]] ) -> None :
806862 """Parse message received from Websocket, ensure it's correct in the current context and yield data."""
807863 # NOTE: Parse messages and either buffer or yield data
808864 for item in message :
@@ -825,98 +881,59 @@ async def _extract_message_data(self, type_: MessageType, message: List[Any]) ->
825881
826882 # NOTE: Put data messages to buffer by level
827883 if tzkt_type == TzktMessageType .DATA :
828- await self ._process_data_message (type_ , message_level , item ['data' ])
884+ self ._buffer . add (type_ , message_level , item ['data' ])
829885
830886 # NOTE: Try to process rollback automatically, emit if failed
831887 elif tzkt_type == TzktMessageType .REORG :
832- await self ._process_reorg_message (type_ , channel_level , message_level )
888+ if not self ._buffer .rollback (type_ , channel_level , message_level ):
889+ await self .emit_rollback (channel_level , message_level )
833890
834891 else :
835- raise NotImplementedError ('Unknown message type' )
836-
837- # NOTE: Yield extensive data from buffer
838- for item in self ._yield_from_buffer (type_ ):
839- yield item
840-
841- def _yield_from_buffer (self , type_ : MessageType ) -> Generator [Dict , None , None ]:
842- buffered_levels = sorted (self ._buffer .keys ())
843- if len (buffered_levels ) < self ._buffer_size :
844- return
845-
846- yielded_levels = buffered_levels [: len (buffered_levels ) - self ._buffer_size ]
847- for level in yielded_levels :
848- for idx , level_data in enumerate (self ._buffer [level ]):
849- level_message_type , level_message = level_data
850- if level_message_type == type_ :
851- yield level_message
852- self ._buffer [level ].pop (idx )
853-
854- if not self ._buffer [level ]:
855- del self ._buffer [level ]
856-
857- async def _process_data_message (self , type_ : MessageType , message_level : int , message_data : Dict [str , Any ]) -> None :
858- self ._buffer [message_level ].append ((type_ , message_data ))
859-
860- async def _process_reorg_message (self , type_ : MessageType , channel_level : int , message_level : int ) -> None :
861- # NOTE: No action required for this channel
862- if type_ == MessageType .head :
863- return
864-
865- # NOTE: This rollback does not affect us, so we can safely ignore it
866- if channel_level <= message_level :
867- return
868-
869- self ._logger .info ('Rollback requested from %s to %s' , channel_level , message_level )
870-
871- # NOTE: Drop buffered messages in reversed order while possible
872- rolled_back_levels = range (channel_level , message_level , - 1 )
873- for rolled_back_level in rolled_back_levels :
874- if self ._buffer .pop (rolled_back_level , None ):
875- self ._logger .info ('Level %s is buffered' , rolled_back_level )
892+ raise NotImplementedError (f'Unknown message type: { tzkt_type } ' )
893+
894+ # NOTE: Process extensive data from buffer
895+ for buffered_message in self ._buffer .yield_from ():
896+ if buffered_message .type == MessageType .operation :
897+ await self ._process_operations_data (cast (list , buffered_message .data ))
898+ elif buffered_message .type == MessageType .big_map :
899+ await self ._process_big_maps_data (cast (list , buffered_message .data ))
900+ elif buffered_message .type == MessageType .head :
901+ await self ._process_head_data (cast (dict , buffered_message .data ))
876902 else :
877- self ._logger .info (
878- 'Level %s is not buffered, emitting rollback to %s' ,
879- rolled_back_level ,
880- message_level ,
881- )
882- await self .emit_rollback (channel_level , message_level )
883- return
884- else :
885- self ._logger .info ('Rollback is not required, continuing' )
903+ raise NotImplementedError (f'Unknown message type: { buffered_message .type } ' )
886904
887- async def _on_operations_message (self , message : List [Dict [str , Any ]]) -> None :
905+ async def _process_operations_data (self , data : List [Dict [str , Any ]]) -> None :
888906 """Parse and emit raw operations from WS"""
889907 level_operations : DefaultDict [int , Deque [OperationData ]] = defaultdict (deque )
890- async for data in self . _extract_message_data ( MessageType . operation , message ):
891- for operation_json in data :
892- if operation_json ['status' ] != 'applied' :
893- continue
894- operation = self .convert_operation (operation_json )
895- level_operations [operation .level ].append (operation )
908+
909+ for operation_json in data :
910+ if operation_json ['status' ] != 'applied' :
911+ continue
912+ operation = self .convert_operation (operation_json )
913+ level_operations [operation .level ].append (operation )
896914
897915 for _level , operations in level_operations .items ():
898916 await self .emit_operations (tuple (operations ))
899917
900- async def _on_big_maps_message (self , message : List [Dict [str , Any ]]) -> None :
918+ async def _process_big_maps_data (self , data : List [Dict [str , Any ]]) -> None :
901919 """Parse and emit raw big map diffs from WS"""
902920 level_big_maps : DefaultDict [int , Deque [BigMapData ]] = defaultdict (deque )
903- async for data in self . _extract_message_data ( MessageType . big_map , message ):
904- big_maps : Deque [BigMapData ] = deque ()
905- for big_map_json in data :
906- big_map = self .convert_big_map (big_map_json )
907- level_big_maps [big_map .level ].append (big_map )
921+
922+ big_maps : Deque [BigMapData ] = deque ()
923+ for big_map_json in data :
924+ big_map = self .convert_big_map (big_map_json )
925+ level_big_maps [big_map .level ].append (big_map )
908926
909927 for _level , big_maps in level_big_maps .items ():
910928 await self .emit_big_maps (tuple (big_maps ))
911929
912- async def _on_head_message (self , message : List [ Dict [str , Any ] ]) -> None :
930+ async def _process_head_data (self , data : Dict [str , Any ]) -> None :
913931 """Parse and emit raw head block from WS"""
914- async for data in self ._extract_message_data (MessageType .head , message ):
915- if self ._watchdog :
916- self ._watchdog .reset ()
932+ if self ._watchdog :
933+ self ._watchdog .reset ()
917934
918- block = self .convert_head_block (data )
919- await self .emit_head (block )
935+ block = self .convert_head_block (data )
936+ await self .emit_head (block )
920937
921938 @classmethod
922939 def convert_operation (cls , operation_json : Dict [str , Any ], type_ : Optional [str ] = None ) -> OperationData :
0 commit comments