@@ -195,22 +195,52 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
195195 """
196196 print ("Extracting class hierarchy..." )
197197 df_scope = self ._get_scope_data ()
198+ pdb_chain_df = self ._parse_pdb_sequence_file ()
199+ pdb_id_set = set (pdb_chain_df ["pdb_id" ]) # Search time complexity - O(1)
198200
199201 g = nx .DiGraph ()
200202
201- egdes = []
202- for _ , row in df_scope .iterrows ():
203- g .add_node (row ["sunid" ], ** {"sid" : row ["sid" ], "level" : row ["level" ]})
204- if row ["parent_sunid" ] != - 1 :
205- egdes .append ((row ["parent_sunid" ], row ["sunid" ]))
203+ edges = []
204+ node_attrs = {}
205+ px_level_nodes = set ()
206+
207+ # Step 1: Build the graph and store attributes
208+ for row in df_scope .itertuples (index = False ):
209+ if row .level == "px" :
210+ if row .sid [1 :5 ] not in pdb_id_set :
211+ # Skip domain-level ("px") nodes whose PDB id does not appear in the pdb_sequences.txt file
212+ continue
213+ px_level_nodes .add (row .sunid )
214+
215+ node_attrs [row .sunid ] = {"sid" : row .sid , "level" : row .level }
216+
217+ if row .parent_sunid != - 1 :
218+ edges .append ((row .parent_sunid , row .sunid ))
206219
207- for children_id in row [ " children_sunids" ] :
208- egdes .append ((row [ " sunid" ], children_id ))
220+ for child_id in row . children_sunids :
221+ edges .append ((row . sunid , child_id ))
209222
210- g .add_edges_from (egdes )
223+ g .add_nodes_from ((node , attrs ) for node , attrs in node_attrs .items ())
224+ g .add_edges_from (edges )
211225
226+ # Step 2: Compute the transitive closure first
212227 print ("Computing transitive closure" )
213- return nx .transitive_closure_dag (g )
228+ g_tc = nx .transitive_closure_dag (g )
229+
230+ print (
231+ "Remove node without domain descendants that don't have pdb correspondence"
232+ )
233+ # Step 3: Identify and remove every node that is neither a retained "px" node nor an ancestor of one (i.e. it has no descendant with a matching entry in the pdb_sequences file)
234+ nodes_to_remove = set ()
235+ for node in g_tc .nodes :
236+ if node not in px_level_nodes and not any (
237+ desc in px_level_nodes for desc in g_tc .successors (node )
238+ ):
239+ nodes_to_remove .add (node )
240+
241+ g_tc .remove_nodes_from (nodes_to_remove )
242+
243+ return g_tc
214244
215245 def _get_scope_data (self ) -> pd .DataFrame :
216246 """
@@ -388,7 +418,8 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
388418
389419 encoded_target_columns = []
390420 for level in hierarchy_levels :
391- encoded_target_columns .extend (lvl_to_target_cols_mapping [level ])
421+ if level in lvl_to_target_cols_mapping :
422+ encoded_target_columns .extend (lvl_to_target_cols_mapping [level ])
392423
393424 print (
394425 f"{ len (encoded_target_columns )} labels has been selected for specified threshold, "
@@ -471,12 +502,12 @@ def _parse_pdb_sequence_file(self) -> pd.DataFrame:
471502 for record in SeqIO .parse (
472503 os .path .join (self .scope_root_dir , self .raw_file_names_dict ["PDB" ]), "fasta"
473504 ):
505+
506+ if not record .seq :
507+ continue
508+
474509 pdb_id , chain = record .id .split ("_" )
475- sequence = (
476- re .sub (f"[^{ valid_amino_acids } ]" , "X" , str (record .seq ))
477- if record .seq
478- else ""
479- )
510+ sequence = re .sub (f"[^{ valid_amino_acids } ]" , "X" , str (record .seq ))
480511
481512 # Store as a dictionary entry (list of dicts -> DataFrame later)
482513 records .append (
@@ -876,7 +907,8 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial):
876907
877908
878909if __name__ == "__main__" :
879- scope = SCOPeOver2000 (scope_version = "2.08" )
910+ scope = SCOPeOver50 (scope_version = "2.08" )
911+
880912 # g = scope._extract_class_hierarchy("dummy/path")
881913 # # Save graph
882914 # import pickle
0 commit comments