@@ -198,49 +198,91 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
198198 pdb_chain_df = self ._parse_pdb_sequence_file ()
199199 pdb_id_set = set (pdb_chain_df ["pdb_id" ]) # Search time complexity - O(1)
200200
201- g = nx .DiGraph ()
202-
203- edges = []
201+ # Initialize sets and dictionaries for storing edges and attributes
202+ parent_node_edges , node_child_edges = set (), set ()
204203 node_attrs = {}
205204 px_level_nodes = set ()
205+ sequence_nodes = dict ()
206+ px_to_seq_edges = set ()
207+ required_graph_nodes = set ()
208+
209+ # Create a lookup dictionary for PDB chain sequences
210+ lookup_dict = (
211+ pdb_chain_df .groupby ("pdb_id" )[["chain_id" , "sequence" ]]
212+ .apply (lambda x : dict (zip (x ["chain_id" ], x ["sequence" ])))
213+ .to_dict ()
214+ )
215+
216+ def add_sequence_nodes_edges (chain_sequence , px_sun_id ):
217+ """Adds sequence nodes and edges connecting px-level nodes to sequence nodes."""
218+ if chain_sequence not in sequence_nodes :
219+ sequence_nodes [chain_sequence ] = f"seq_{ len (sequence_nodes )} "
220+ px_to_seq_edges .add ((px_sun_id , sequence_nodes [chain_sequence ]))
206221
207- # Step 1: Build the graph and store attributes
222+ # Step 1: Build the graph structure and store node attributes
208223 for row in df_scope .itertuples (index = False ):
209224 if row .level == "px" :
210- if row .sid [1 :5 ] not in pdb_id_set :
225+
226+ pdb_id , chain_id = row .sid [1 :5 ], row .sid [5 ]
227+
228+ if pdb_id not in pdb_id_set or chain_id == "_" :
211229 # Don't add domain level nodes that don't have pdb_id in pdb_sequences.txt file
230+ # Also chain_id with "_" which corresponds to no chain
212231 continue
213232 px_level_nodes .add (row .sunid )
214233
234+ # Add edges between px-level nodes and sequence nodes
235+ if chain_id != "." :
236+ if chain_id not in lookup_dict [pdb_id ]:
237+ continue
238+ add_sequence_nodes_edges (lookup_dict [pdb_id ][chain_id ], row .sunid )
239+ else :
240+ # If chain_id is '.', connect all chains of this PDB ID
241+ for chain , chain_sequence in lookup_dict [pdb_id ].items ():
242+ add_sequence_nodes_edges (chain_sequence , row .sunid )
243+ else :
244+ required_graph_nodes .add (row .sunid )
245+
215246 node_attrs [row .sunid ] = {"sid" : row .sid , "level" : row .level }
216247
217248 if row .parent_sunid != - 1 :
218- edges . append ((row .parent_sunid , row .sunid ))
249+ parent_node_edges . add ((row .parent_sunid , row .sunid ))
219250
220251 for child_id in row .children_sunids :
221- edges . append ((row .sunid , child_id ))
252+ node_child_edges . add ((row .sunid , child_id ))
222253
223- g .add_nodes_from ((node , attrs ) for node , attrs in node_attrs .items ())
224- g .add_edges_from (edges )
254+ del df_scope , pdb_chain_df , pdb_id_set
225255
226- # Step 2: Compute the transitive closure first
227- print ("Computing transitive closure" )
228- g_tc = nx .transitive_closure_dag (g )
229-
230- print (
231- "Remove node without domain descendants that don't have pdb correspondence"
256+ g = nx .DiGraph ()
257+ g .add_nodes_from (node_attrs .items ())
258+ # Note - `add_edges` internally create a node, if a node doesn't exist already
259+ g .add_edges_from ({(p , c ) for p , c in parent_node_edges if p in node_attrs })
260+ g .add_edges_from ({(p , c ) for p , c in node_child_edges if c in node_attrs })
261+
262+ seq_nodes = set (sequence_nodes .values ())
263+ g .add_nodes_from ([(seq_id , {"level" : "sequence" }) for seq_id in seq_nodes ])
264+ g .add_edges_from (
265+ {
266+ (px_node , seq_node )
267+ for px_node , seq_node in px_to_seq_edges
268+ if px_node in node_attrs and seq_node in seq_nodes
269+ }
232270 )
233- # Step 3: Identify and remove nodes that don’t have a "px" descendant with correspondence to pdb_sequences file
234- nodes_to_remove = set ()
235- for node in g_tc .nodes :
236- if node not in px_level_nodes and not any (
237- desc in px_level_nodes for desc in g_tc .successors (node )
238- ):
239- nodes_to_remove .add (node )
240271
241- g_tc .remove_nodes_from (nodes_to_remove )
272+ # Step 2: Count sequence successors for required graph nodes only
273+ for node in required_graph_nodes :
274+ num_seq_successors = sum (
275+ g .nodes [child ]["level" ] == "sequence"
276+ for child in nx .descendants (g , node )
277+ )
278+ g .nodes [node ]["num_seq_successors" ] = num_seq_successors
242279
243- return g_tc
280+ # Step 3: Remove nodes which are not required before computing transitive closure for better efficiency
281+ g .remove_nodes_from (px_level_nodes | seq_nodes )
282+
283+ print ("Computing Transitive Closure........." )
284+ # Transitive closure is not needed in `select_classes` method but is required in _SCOPeOverXPartial
285+ return nx .transitive_closure_dag (g )
244286
245287 def _get_scope_data (self ) -> pd .DataFrame :
246288 """
@@ -808,12 +850,15 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]
808850 """
809851 selected_sunids_for_level = {}
810852 for node , attr_dict in g .nodes (data = True ):
811- if g .out_degree (node ) >= self .THRESHOLD :
853+ if attr_dict ["level" ] in {"root" , "px" , "sequence" }:
854+ # Skip nodes with level "root", "px", or "sequence"
855+ continue
856+
857+ # Check if the number of "sequence"-level successors meets or exceeds the threshold
858+ if g .nodes [node ]["num_seq_successors" ] >= self .THRESHOLD :
812859 selected_sunids_for_level .setdefault (attr_dict ["level" ], []).append (
813860 node
814861 )
815- # Remove root node, as it will True for all instances
816- selected_sunids_for_level .pop ("root" , None )
817862 return selected_sunids_for_level
818863
819864
0 commit comments