@@ -72,10 +72,12 @@ def __init__(
7272 self ,
7373 scope_version : str ,
7474 scope_version_train : Optional [str ] = None ,
75+ max_sequence_len : int = 1000 ,
7576 ** kwargs ,
7677 ):
7778 self .scope_version : str = scope_version
7879 self .scope_version_train : str = scope_version_train
80+ self .max_sequence_len : int = max_sequence_len
7981
8082 super (_SCOPeDataExtractor , self ).__init__ (** kwargs )
8183
@@ -195,21 +197,93 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
195197 """
196198 print ("Extracting class hierarchy..." )
197199 df_scope = self ._get_scope_data ()
200+ pdb_chain_df = self ._parse_pdb_sequence_file ()
201+ pdb_id_set = set (pdb_chain_df ["pdb_id" ]) # Search time complexity - O(1)
202+
203+ # Initialize sets and dictionaries for storing edges and attributes
204+ parent_node_edges , node_child_edges = set (), set ()
205+ node_attrs = {}
206+ px_level_nodes = set ()
207+ sequence_nodes = dict ()
208+ px_to_seq_edges = set ()
209+ required_graph_nodes = set ()
210+
211+ # Create a lookup dictionary for PDB chain sequences
212+ lookup_dict = (
213+ pdb_chain_df .groupby ("pdb_id" )[["chain_id" , "sequence" ]]
214+ .apply (lambda x : dict (zip (x ["chain_id" ], x ["sequence" ])))
215+ .to_dict ()
216+ )
198217
199- g = nx .DiGraph ()
218+ def add_sequence_nodes_edges (chain_sequence , px_sun_id ):
219+ """Adds sequence nodes and edges connecting px-level nodes to sequence nodes."""
220+ if chain_sequence not in sequence_nodes :
221+ sequence_nodes [chain_sequence ] = f"seq_{ len (sequence_nodes )} "
222+ px_to_seq_edges .add ((px_sun_id , sequence_nodes [chain_sequence ]))
223+
224+ # Step 1: Build the graph structure and store node attributes
225+ for row in df_scope .itertuples (index = False ):
226+ if row .level == "px" :
227+
228+ pdb_id , chain_id = row .sid [1 :5 ], row .sid [5 ]
200229
201- egdes = []
202- for _ , row in df_scope . iterrows ():
203- g . add_node ( row [ "sunid" ], ** { "sid" : row [ "sid" ], "level" : row [ "level" ]})
204- if row [ "parent_sunid" ] != - 1 :
205- egdes . append (( row [ "parent_sunid" ], row [ " sunid" ]) )
230+ if pdb_id not in pdb_id_set or chain_id == "_" :
231+ # Don't add domain level nodes that don't have pdb_id in pdb_sequences.txt file
232+ # Also chain_id with "_" which corresponds to no chain
233+ continue
234+ px_level_nodes . add ( row . sunid )
206235
207- for children_id in row ["children_sunids" ]:
208- egdes .append ((row ["sunid" ], children_id ))
236+ # Add edges between px-level nodes and sequence nodes
237+ if chain_id != "." :
238+ if chain_id not in lookup_dict [pdb_id ]:
239+ continue
240+ add_sequence_nodes_edges (lookup_dict [pdb_id ][chain_id ], row .sunid )
241+ else :
242+ # If chain_id is '.', connect all chains of this PDB ID
243+ for chain , chain_sequence in lookup_dict [pdb_id ].items ():
244+ add_sequence_nodes_edges (chain_sequence , row .sunid )
245+ else :
246+ required_graph_nodes .add (row .sunid )
209247
210- g . add_edges_from ( egdes )
248+ node_attrs [ row . sunid ] = { "sid" : row . sid , "level" : row . level }
211249
212- print ("Computing transitive closure" )
250+ if row .parent_sunid != - 1 :
251+ parent_node_edges .add ((row .parent_sunid , row .sunid ))
252+
253+ for child_id in row .children_sunids :
254+ node_child_edges .add ((row .sunid , child_id ))
255+
256+ del df_scope , pdb_chain_df , pdb_id_set
257+
258+ g = nx .DiGraph ()
259+ g .add_nodes_from (node_attrs .items ())
260+ # Note - `add_edges` internally create a node, if a node doesn't exist already
261+ g .add_edges_from ({(p , c ) for p , c in parent_node_edges if p in node_attrs })
262+ g .add_edges_from ({(p , c ) for p , c in node_child_edges if c in node_attrs })
263+
264+ seq_nodes = set (sequence_nodes .values ())
265+ g .add_nodes_from ([(seq_id , {"level" : "sequence" }) for seq_id in seq_nodes ])
266+ g .add_edges_from (
267+ {
268+ (px_node , seq_node )
269+ for px_node , seq_node in px_to_seq_edges
270+ if px_node in node_attrs and seq_node in seq_nodes
271+ }
272+ )
273+
274+ # Step 2: Count sequence successors for required graph nodes only
275+ for node in required_graph_nodes :
276+ num_seq_successors = sum (
277+ g .nodes [child ]["level" ] == "sequence"
278+ for child in nx .descendants (g , node )
279+ )
280+ g .nodes [node ]["num_seq_successors" ] = num_seq_successors
281+
282+ # Step 3: Remove nodes which are not required before computing transitive closure for better efficiency
283+ g .remove_nodes_from (px_level_nodes | seq_nodes )
284+
285+ print ("Computing Transitive Closure........." )
286+ # Transitive closure is not needed in `select_classes` method but is required in _SCOPeOverXPartial
213287 return nx .transitive_closure_dag (g )
214288
215289 def _get_scope_data (self ) -> pd .DataFrame :
@@ -388,7 +462,8 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
388462
389463 encoded_target_columns = []
390464 for level in hierarchy_levels :
391- encoded_target_columns .extend (lvl_to_target_cols_mapping [level ])
465+ if level in lvl_to_target_cols_mapping :
466+ encoded_target_columns .extend (lvl_to_target_cols_mapping [level ])
392467
393468 print (
394469 f"{ len (encoded_target_columns )} labels has been selected for specified threshold, "
@@ -471,12 +546,12 @@ def _parse_pdb_sequence_file(self) -> pd.DataFrame:
471546 for record in SeqIO .parse (
472547 os .path .join (self .scope_root_dir , self .raw_file_names_dict ["PDB" ]), "fasta"
473548 ):
549+
550+ if not record .seq or len (record .seq ) > self .max_sequence_len :
551+ continue
552+
474553 pdb_id , chain = record .id .split ("_" )
475- sequence = (
476- re .sub (f"[^{ valid_amino_acids } ]" , "X" , str (record .seq ))
477- if record .seq
478- else ""
479- )
554+ sequence = re .sub (f"[^{ valid_amino_acids } ]" , "X" , str (record .seq ))
480555
481556 # Store as a dictionary entry (list of dicts -> DataFrame later)
482557 records .append (
@@ -777,12 +852,15 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]
777852 """
778853 selected_sunids_for_level = {}
779854 for node , attr_dict in g .nodes (data = True ):
780- if g .out_degree (node ) >= self .THRESHOLD :
855+ if attr_dict ["level" ] in {"root" , "px" , "sequence" }:
856+ # Skip nodes with level "root", "px", or "sequence"
857+ continue
858+
859+ # Check if the number of "sequence"-level successors meets or exceeds the threshold
860+ if g .nodes [node ]["num_seq_successors" ] >= self .THRESHOLD :
781861 selected_sunids_for_level .setdefault (attr_dict ["level" ], []).append (
782862 node
783863 )
784- # Remove root node, as it will True for all instances
785- selected_sunids_for_level .pop ("root" , None )
786864 return selected_sunids_for_level
787865
788866
@@ -876,7 +954,8 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial):
876954
877955
878956if __name__ == "__main__" :
879- scope = SCOPeOver2000 (scope_version = "2.08" )
957+ scope = SCOPeOver50 (scope_version = "2.08" )
958+
880959 # g = scope._extract_class_hierarchy("dummy/path")
881960 # # Save graph
882961 # import pickle
0 commit comments