@@ -60,8 +60,11 @@ def index_mappings(self):
6060 raise NotImplementedError
6161
6262 @abc .abstractmethod
63- def build_elastic_actions (self , messages_chunk : messages .MessagesChunk ) -> typing .Iterable [tuple [int , dict ]]:
64- # yield (message_target_id, elastic_action) pairs
63+ def build_elastic_actions (
64+ self ,
65+ messages_chunk : messages .MessagesChunk ,
66+ ) -> typing .Iterable [tuple [int , dict | typing .Iterable [dict ]]]:
67+ # yield (message_target_id, [elastic_action, ...]) pairs
6568 raise NotImplementedError
6669
6770 def before_chunk (
@@ -148,10 +151,17 @@ def pls_handle_messages_chunk(self, messages_chunk):
148151 _indexname = _response_body ['_index' ]
149152 _is_done = _ok or (_op_type == 'delete' and _status == 404 )
150153 if _is_done :
151- _action_tracker .action_done (_indexname , _docid )
154+ _finished_message_id = _action_tracker .action_done (_indexname , _docid )
155+ if _finished_message_id is not None :
156+ yield messages .IndexMessageResponse (
157+ is_done = True ,
158+ index_message = messages .IndexMessage (messages_chunk .message_type , _finished_message_id ),
159+ status_code = HTTPStatus .OK .value ,
160+ error_text = None ,
161+ )
162+ _action_tracker .forget_message (_finished_message_id )
152163 else :
153164 _action_tracker .action_errored (_indexname , _docid )
154- # yield error responses immediately
155165 yield messages .IndexMessageResponse (
156166 is_done = False ,
157167 index_message = messages .IndexMessage (
@@ -161,16 +171,14 @@ def pls_handle_messages_chunk(self, messages_chunk):
161171 status_code = _status ,
162172 error_text = str (_response_body ),
163173 )
164- self .after_chunk (messages_chunk , _indexnames )
165- # yield successes after the whole chunk completes
166- # (since one message may involve several actions)
167- for _messageid in _action_tracker .all_done_messages ():
174+ for _message_id in _action_tracker .remaining_done_messages ():
168175 yield messages .IndexMessageResponse (
169176 is_done = True ,
170- index_message = messages .IndexMessage (messages_chunk .message_type , _messageid ),
177+ index_message = messages .IndexMessage (messages_chunk .message_type , _message_id ),
171178 status_code = HTTPStatus .OK .value ,
172179 error_text = None ,
173180 )
181+ self .after_chunk (messages_chunk , _indexnames )
174182
175183 # abstract method from IndexStrategy
176184 def pls_make_default_for_searching (self , specific_index : IndexStrategy .SpecificIndex ):
@@ -202,14 +210,18 @@ def _alias_for_keeping_live(self):
202210 def _elastic_actions_with_index (self , messages_chunk , indexnames , action_tracker : _ActionTracker ):
203211 if not indexnames :
204212 raise ValueError ('cannot index to no indexes' )
205- for _message_target_id , _elastic_action in self .build_elastic_actions (messages_chunk ):
206- _docid = _elastic_action ['_id' ]
207- for _indexname in indexnames :
208- action_tracker .add_action (_message_target_id , _indexname , _docid )
209- yield {
210- ** _elastic_action ,
211- '_index' : _indexname ,
212- }
213+ for _message_target_id , _elastic_actions in self .build_elastic_actions (messages_chunk ):
214+ if isinstance (_elastic_actions , dict ): # allow a single action
215+ _elastic_actions = [_elastic_actions ]
216+ for _elastic_action in _elastic_actions :
217+ _docid = _elastic_action ['_id' ]
218+ for _indexname in indexnames :
219+ action_tracker .add_action (_message_target_id , _indexname , _docid )
220+ yield {
221+ ** _elastic_action ,
222+ '_index' : _indexname ,
223+ }
224+ action_tracker .done_scheduling (_message_target_id )
213225
214226 def _get_indexnames_for_alias (self , alias_name ) -> set [str ]:
215227 try :
@@ -371,24 +383,37 @@ class _ActionTracker:
371383 default_factory = lambda : collections .defaultdict (set ),
372384 )
373385 errored_messageids : set [int ] = dataclasses .field (default_factory = set )
386+ fully_scheduled_messageids : set [int ] = dataclasses .field (default_factory = set )
374387
375388 def add_action (self , message_id : int , index_name : str , doc_id : str ):
376389 self .messageid_by_docid [doc_id ] = message_id
377390 self .actions_by_messageid [message_id ].add ((index_name , doc_id ))
378391
379- def action_done (self , index_name : str , doc_id : str ):
380- _messageid = self .messageid_by_docid [doc_id ]
381- _message_actions = self .actions_by_messageid [_messageid ]
382- _message_actions .discard ((index_name , doc_id ))
392+ def action_done (self , index_name : str , doc_id : str ) -> int | None :
393+ _messageid = self .get_message_id (doc_id )
394+ _remaining_message_actions = self .actions_by_messageid [_messageid ]
395+ _remaining_message_actions .discard ((index_name , doc_id ))
396+ # return the message id only if this was the last action for that message
397+ return (
398+ None
399+ if _remaining_message_actions or (_messageid not in self .fully_scheduled_messageids )
400+ else _messageid
401+ )
383402
384403 def action_errored (self , index_name : str , doc_id : str ):
385404 _messageid = self .messageid_by_docid [doc_id ]
386405 self .errored_messageids .add (_messageid )
387406
407+ def done_scheduling (self , message_id : int ):
408+ self .fully_scheduled_messageids .add (message_id )
409+
410+ def forget_message (self , message_id : int ):
411+ del self .actions_by_messageid [message_id ]
412+
388413 def get_message_id (self , doc_id : str ):
389414 return self .messageid_by_docid [doc_id ]
390415
391- def all_done_messages (self ):
416+ def remaining_done_messages (self ):
392417 for _messageid , _actions in self .actions_by_messageid .items ():
393418 if _messageid not in self .errored_messageids :
394419 assert not _actions
0 commit comments