1414
1515import  logging 
1616from  typing  import  (
17-     TYPE_CHECKING ,
1817    Collection ,
1918    Dict ,
2019    FrozenSet ,
3231from  synapse .api .constants  import  RelationTypes 
3332from  synapse .events  import  EventBase 
3433from  synapse .storage ._base  import  SQLBaseStore 
35- from  synapse .storage .database  import  (
36-     DatabasePool ,
37-     LoggingDatabaseConnection ,
38-     LoggingTransaction ,
39-     make_in_list_sql_clause ,
40- )
34+ from  synapse .storage .database  import  LoggingTransaction , make_in_list_sql_clause 
4135from  synapse .storage .databases .main .stream  import  generate_pagination_where_clause 
4236from  synapse .storage .engines  import  PostgresEngine 
4337from  synapse .types  import  JsonDict , RoomStreamToken , StreamToken 
4438from  synapse .util .caches .descriptors  import  cached , cachedList 
4539
46- if  TYPE_CHECKING :
47-     from  synapse .server  import  HomeServer 
48- 
4940logger  =  logging .getLogger (__name__ )
5041
5142
@@ -63,16 +54,6 @@ class _RelatedEvent:
6354
6455
6556class  RelationsWorkerStore (SQLBaseStore ):
66-     def  __init__ (
67-         self ,
68-         database : DatabasePool ,
69-         db_conn : LoggingDatabaseConnection ,
70-         hs : "HomeServer" ,
71-     ):
72-         super ().__init__ (database , db_conn , hs )
73- 
74-         self ._msc3440_enabled  =  hs .config .experimental .msc3440_enabled 
75- 
7657    @cached (uncached_args = ("event" ,), tree = True ) 
7758    async  def  get_relations_for_event (
7859        self ,
@@ -497,7 +478,7 @@ def _get_thread_summaries_txn(
497478                        AND parent.room_id = child.room_id 
498479                    WHERE 
499480                        %s 
500-                         AND %s  
481+                         AND relation_type = ?  
501482                    ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC 
502483                """ 
503484            else :
@@ -512,22 +493,16 @@ def _get_thread_summaries_txn(
512493                        AND parent.room_id = child.room_id 
513494                    WHERE 
514495                        %s 
515-                         AND %s  
496+                         AND relation_type = ?  
516497                    ORDER BY child.topological_ordering DESC, child.stream_ordering DESC 
517498                """ 
518499
519500            clause , args  =  make_in_list_sql_clause (
520501                txn .database_engine , "relates_to_id" , event_ids 
521502            )
503+             args .append (RelationTypes .THREAD )
522504
523-             if  self ._msc3440_enabled :
524-                 relations_clause  =  "(relation_type = ? OR relation_type = ?)" 
525-                 args .extend ((RelationTypes .THREAD , RelationTypes .UNSTABLE_THREAD ))
526-             else :
527-                 relations_clause  =  "relation_type = ?" 
528-                 args .append (RelationTypes .THREAD )
529- 
530-             txn .execute (sql  %  (clause , relations_clause ), args )
505+             txn .execute (sql  %  (clause ,), args )
531506            latest_event_ids  =  {}
532507            for  parent_event_id , child_event_id  in  txn :
533508                # Only consider the latest threaded reply (by topological ordering). 
@@ -547,7 +522,7 @@ def _get_thread_summaries_txn(
547522                    AND parent.room_id = child.room_id 
548523                WHERE 
549524                    %s 
550-                     AND %s  
525+                     AND relation_type = ?  
551526                GROUP BY parent.event_id 
552527            """ 
553528
@@ -556,15 +531,9 @@ def _get_thread_summaries_txn(
556531            clause , args  =  make_in_list_sql_clause (
557532                txn .database_engine , "relates_to_id" , latest_event_ids .keys ()
558533            )
534+             args .append (RelationTypes .THREAD )
559535
560-             if  self ._msc3440_enabled :
561-                 relations_clause  =  "(relation_type = ? OR relation_type = ?)" 
562-                 args .extend ((RelationTypes .THREAD , RelationTypes .UNSTABLE_THREAD ))
563-             else :
564-                 relations_clause  =  "relation_type = ?" 
565-                 args .append (RelationTypes .THREAD )
566- 
567-             txn .execute (sql  %  (clause , relations_clause ), args )
536+             txn .execute (sql  %  (clause ,), args )
568537            counts  =  dict (cast (List [Tuple [str , int ]], txn .fetchall ()))
569538
570539            return  counts , latest_event_ids 
@@ -622,7 +591,7 @@ async def get_threaded_messages_per_user(
622591                parent.event_id = relates_to_id 
623592                AND parent.room_id = child.room_id 
624593            WHERE 
625-                 %s  
594+                 relation_type = ?  
626595                AND %s 
627596                AND %s 
628597            GROUP BY parent.event_id, child.sender 
@@ -638,16 +607,9 @@ def _get_threaded_messages_per_user_txn(
638607                txn .database_engine , "relates_to_id" , event_ids 
639608            )
640609
641-             if  self ._msc3440_enabled :
642-                 relations_clause  =  "(relation_type = ? OR relation_type = ?)" 
643-                 relations_args  =  [RelationTypes .THREAD , RelationTypes .UNSTABLE_THREAD ]
644-             else :
645-                 relations_clause  =  "relation_type = ?" 
646-                 relations_args  =  [RelationTypes .THREAD ]
647- 
648610            txn .execute (
649-                 sql  %  (users_sql , events_clause ,  relations_clause ),
650-                 users_args  +  events_args  +  relations_args ,
611+                 sql  %  (users_sql , events_clause ),
612+                 [ RelationTypes . THREAD ]  +  users_args  +  events_args ,
651613            )
652614            return  {(row [0 ], row [1 ]): row [2 ] for  row  in  txn }
653615
@@ -677,7 +639,7 @@ async def get_threads_participated(
677639            user participated in that event's thread, otherwise false. 
678640        """ 
679641
680-         def  _get_thread_summary_txn (txn : LoggingTransaction ) ->  Set [str ]:
642+         def  _get_threads_participated_txn (txn : LoggingTransaction ) ->  Set [str ]:
681643            # Fetch whether the requester has participated or not. 
682644            sql  =  """ 
683645                SELECT DISTINCT relates_to_id 
@@ -688,28 +650,20 @@ def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
688650                    AND parent.room_id = child.room_id 
689651                WHERE 
690652                    %s 
691-                     AND %s  
653+                     AND relation_type = ?  
692654                    AND child.sender = ? 
693655            """ 
694656
695657            clause , args  =  make_in_list_sql_clause (
696658                txn .database_engine , "relates_to_id" , event_ids 
697659            )
660+             args .extend ([RelationTypes .THREAD , user_id ])
698661
699-             if  self ._msc3440_enabled :
700-                 relations_clause  =  "(relation_type = ? OR relation_type = ?)" 
701-                 args .extend ((RelationTypes .THREAD , RelationTypes .UNSTABLE_THREAD ))
702-             else :
703-                 relations_clause  =  "relation_type = ?" 
704-                 args .append (RelationTypes .THREAD )
705- 
706-             args .append (user_id )
707- 
708-             txn .execute (sql  %  (clause , relations_clause ), args )
662+             txn .execute (sql  %  (clause ,), args )
709663            return  {row [0 ] for  row  in  txn .fetchall ()}
710664
711665        participated_threads  =  await  self .db_pool .runInteraction (
712-             "get_thread_summary " , _get_thread_summary_txn 
666+             "get_threads_participated " , _get_threads_participated_txn 
713667        )
714668
715669        return  {event_id : event_id  in  participated_threads  for  event_id  in  event_ids }
0 commit comments