@@ -413,8 +413,7 @@ async def store_state_deltas_for_batched(
413413 prev_group : int ,
414414 ) -> List [Tuple [EventBase , UnpersistedEventContext ]]:
415415 """Generate and store state deltas for a group of events and contexts created to be
416- batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c)
417- and must be state events.
416+ batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
418417
419418 Args:
420419 events_and_context: the events to generate and store a state groups for
@@ -449,31 +448,32 @@ def insert_deltas_group_txn(
449448 % (prev_group ,)
450449 )
451450
452- num_state_groups = len (events_and_context )
451+ num_state_groups = 0
452+ for event , _ in events_and_context :
453+ if event .is_state ():
454+ num_state_groups += 1
453455
454456 state_groups = self ._state_group_seq_gen .get_next_mult_txn (
455457 txn , num_state_groups
456458 )
457459
460+ sg_before = prev_group
458461 for index , (event , context ) in enumerate (events_and_context ):
459- context .state_group_after_event = state_groups [index ]
460- # The first prev_group will be the last persisted state group, which is passed in
461- # else it will be the group most recently assigned
462- if index > 0 :
463- context .prev_group_for_state_group_after_event = state_groups [
464- index - 1
465- ]
466- context .state_group_before_event = state_groups [index - 1 ]
467- else :
468- context .prev_group_for_state_group_after_event = prev_group
469- context .state_group_before_event = prev_group
470- context .delta_ids_to_state_group_after_event = {
462+ if not event .is_state ():
463+ context .state_group_after_event = sg_before
464+ context .state_group_before_event = sg_before
465+ pass
466+
467+ sg_after = state_groups [index ]
468+ context .state_group_after_event = sg_after
469+ context .state_group_before_event = sg_before
470+ context .delta_ids_to_state_group_before_event = {
471471 (event .type , event .state_key ): event .event_id
472472 }
473473 context .state_delta_due_to_event = {
474474 (event .type , event .state_key ): event .event_id
475475 }
476- index += 1
476+ sg_before = sg_after
477477
478478 self .db_pool .simple_insert_many_txn (
479479 txn ,
@@ -492,29 +492,35 @@ def insert_deltas_group_txn(
492492 values = [
493493 (
494494 context .state_group_after_event ,
495- context .prev_group_for_state_group_after_event ,
495+ context .state_group_before_event ,
496496 )
497497 for _ , context in events_and_context
498498 ],
499499 )
500500
501+ values = []
501502 for _ , context in events_and_context :
502- assert context .delta_ids_to_state_group_after_event is not None
503- self . db_pool . simple_insert_many_txn (
504- txn ,
505- table = "state_groups_state" ,
506- keys = ( "state_group" , "room_id" , "type" , "state_key" , "event_id" ),
507- values = [
503+ assert context .delta_ids_to_state_group_before_event is not None
504+ for (
505+ key ,
506+ state_id ,
507+ ) in context . delta_ids_to_state_group_before_event . items ():
508+ values . append (
508509 (
509510 context .state_group_after_event ,
510511 room_id ,
511512 key [0 ],
512513 key [1 ],
513514 state_id ,
514515 )
515- for key , state_id in context .delta_ids_to_state_group_after_event .items ()
516- ],
517- )
516+ )
517+
518+ self .db_pool .simple_insert_many_txn (
519+ txn ,
520+ table = "state_groups_state" ,
521+ keys = ("state_group" , "room_id" , "type" , "state_key" , "event_id" ),
522+ values = values ,
523+ )
518524 return events_and_context
519525
520526 return await self .db_pool .runInteraction (
0 commit comments