@@ -51,13 +51,28 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC):
5151 "https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz"
5252 )
5353
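+    # Maps SCOPe level abbreviations used in the raw files to human-readable level names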
+    SCOPE_HIERARCHY: Dict[str, str] = {
+        "cl": "class",
+        "cf": "fold",
+        "sf": "superfamily",
+        "fa": "family",
+        "dm": "protein",
+        "sp": "species",
+        "px": "domain",
+    }
+
     def __init__(
         self,
         scope_version: float,
         scope_version_train: Optional[float] = None,
+        scope_hierarchy_level: str = "cl",
         **kwargs,
     ):
 
+        assert (
+            scope_hierarchy_level in self.SCOPE_HIERARCHY
+        ), f"scope_hierarchy_level must be one of {list(self.SCOPE_HIERARCHY.keys())}"
+        self.scope_hierarchy_level = scope_hierarchy_level
         self.scope_version: float = scope_version
         self.scope_version_train: float = scope_version_train
 
@@ -67,7 +82,8 @@ def __init__(
             # Instantiate another same class with "scope_version" as "scope_version_train", if train_version is given
             # This is to get the data from respective directory related to "scope_version_train"
             _init_kwargs = kwargs
-            _init_kwargs["chebi_version"] = self.scope_version_train
+            _init_kwargs["scope_version"] = self.scope_version_train
+            _init_kwargs["scope_hierarchy_level"] = self.scope_hierarchy_level
             self._scope_version_train_obj = self.__class__(
                 **_init_kwargs,
             )
@@ -150,18 +166,40 @@ def _download_scope_raw_data(self) -> str:
         open(scope_path, "wb").write(r.content)
         return "dummy/path"
 
-    def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
-        pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {}
-        for record in SeqIO.parse(
-            os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta"
-        ):
-            pdb_id, chain = record.id.split("_")
-            pdb_chain_seq_mapping.setdefault(pdb_id, {})[chain] = str(record.seq)
-        return pdb_chain_seq_mapping
-
     def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
         print("Extracting class hierarchy...")
+        df_scope = self._get_scope_data()
+
+        g = nx.DiGraph()
+
+        edges = []
+        for _, row in df_scope.iterrows():
+            g.add_node(row["sunid"], **{"sid": row["sid"], "level": row["level"]})
+            if row["parent_sunid"] != -1:
+                edges.append((row["parent_sunid"], row["sunid"]))
+
+            for child_sunid in row["children_sunids"]:
+                edges.append((row["sunid"], child_sunid))
+
+        g.add_edges_from(edges)
+
+        print("Computing transitive closure")
+        return nx.transitive_closure_dag(g)
+
+    def _get_scope_data(self) -> pd.DataFrame:
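+        # Merge the classification (CLA), hierarchy (HIE) and node description (DES) tables on "sunid"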
+        df_cla = self._get_classification_data()
+        df_hie = self._get_hierarchy_data()
+        df_des = self._get_node_description_data()
+        df_hie_with_cla = pd.merge(df_hie, df_cla, how="left", on="sunid")
+        df_all = pd.merge(
+            df_hie_with_cla,
+            df_des.drop(columns=["sid"], axis=1),
+            how="left",
+            on="sunid",
+        )
+        return df_all
 
+    def _get_classification_data(self) -> pd.DataFrame:
         # Load and preprocess CLA file
         df_cla = pd.read_csv(
             os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]),
@@ -175,125 +213,166 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
175213 "description" ,
176214 "sccs" ,
177215 "sunid" ,
178- "ancestor_nodes " ,
216+ "hie_levels " ,
179217 ]
180- df_cla ["sunid" ] = pd .to_numeric (
181- df_cla ["sunid" ], errors = "coerce" , downcast = "integer"
182- )
183- df_cla ["ancestor_nodes" ] = df_cla ["ancestor_nodes" ].apply (
218+
219+ # Convert to dict - {cl:46456, cf:46457, sf:46458, fa:46459, dm:46460, sp:116748, px:113449}
220+ df_cla ["hie_levels" ] = df_cla ["hie_levels" ].apply (
184221 lambda x : {k : int (v ) for k , v in (item .split ("=" ) for item in x .split ("," ))}
185222 )
-        df_cla.set_index("sunid", inplace=True)
 
+        # Split the hie_levels dict into one column per hierarchy level
+        for key in self.SCOPE_HIERARCHY.keys():
+            df_cla[self.SCOPE_HIERARCHY[key]] = df_cla["hie_levels"].apply(
+                lambda x: x[key]
+            )
+
+        df_cla["sunid"] = df_cla["sunid"].astype("int64")
+
+        return df_cla
+
+    def _get_hierarchy_data(self) -> pd.DataFrame:
         # Load and preprocess HIE file
         df_hie = pd.read_csv(
             os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]),
             sep="\t",
             header=None,
             comment="#",
+            low_memory=False,
         )
         df_hie.columns = ["sunid", "parent_sunid", "children_sunids"]
196- df_hie ["sunid" ] = pd .to_numeric (
197- df_hie ["sunid" ], errors = "coerce" , downcast = "integer"
198- )
244+
245+ # if not parent id, then insert -1
199246 df_hie ["parent_sunid" ] = df_hie ["parent_sunid" ].replace ("-" , - 1 ).astype (int )
247+ # convert children ids to list of ids
200248 df_hie ["children_sunids" ] = df_hie ["children_sunids" ].apply (
201249 lambda x : list (map (int , x .split ("," ))) if x != "-" else []
202250 )
203251
204- # Initialize directed graph
205- g = nx .DiGraph ()
252+ # Ensure the 'sunid' column in both DataFrames has the same type
253+ df_hie ["sunid" ] = df_hie ["sunid" ].astype ("int64" )
254+ return df_hie
 
-        # Add nodes and edges efficiently
-        g.add_edges_from(
-            df_hie[df_hie["parent_sunid"] != -1].apply(
-                lambda row: (row["parent_sunid"], row["sunid"]), axis=1
-            )
-        )
-        g.add_edges_from(
-            df_hie.explode("children_sunids")
-            .dropna()
-            .apply(lambda row: (row["sunid"], row["children_sunids"]), axis=1)
+    def _get_node_description_data(self) -> pd.DataFrame:
+        # Load and preprocess DES file
+        df_des = pd.read_csv(
+            os.path.join(self.raw_dir, self.raw_file_names_dict["DES"]),
+            sep="\t",
+            header=None,
+            comment="#",
+            low_memory=False,
         )
+        df_des.columns = ["sunid", "level", "sccs", "sid", "description"]
+        df_des.loc[len(df_des)] = {"sunid": 0, "level": "root"}
 
-        pdb_chain_seq_mapping = self._parse_pdb_sequence_file()
+        # Ensure the 'sunid' column in both DataFrames has the same type
+        df_des["sunid"] = df_des["sunid"].astype("int64")
+        return df_des
 
-        node_to_pdb_id = df_cla["PDB_ID"].to_dict()
+    def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
+        print("Processing graph")
 
-        for node in g.nodes():
-            pdb_id = node_to_pdb_id[node]
-            chain_mapping = pdb_chain_seq_mapping.get(pdb_id, {})
+        sids = nx.get_node_attributes(graph, "sid")
+        levels = nx.get_node_attributes(graph, "level")
 
-            # Add nodes and edges for chains in the mapping
-            for chain, sequence in chain_mapping.items():
-                chain_node = f"{pdb_id}_{chain}"
-                g.add_node(chain_node, sequence=sequence)
-                g.add_edge(node, chain_node)
+        sun_ids = []
+        sids_list = []
 
-        print("Compute transitive closure...")
-        return nx.transitive_closure_dag(g)
+        selected_sids_dict = self.select_classes(graph)
 
-    def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
-        """
-        Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes
-        Swiss-Prot protein data and their associations with Gene Ontology (GO) terms.
-
-        Note:
-            - GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value
-              indicates whether a Swiss-Prot protein is associated with that GO term.
-            - Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins
-              and GO terms.
-
-        Data Format: pd.DataFrame
-            - Column 0 : swiss_id (Identifier for SwissProt protein)
-            - Column 1 : Accession of the protein
-            - Column 2 : GO IDs (associated GO terms)
-            - Column 3 : Sequence of the protein
-            - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the
-              protein is associated with this GO term.
+        for sun_id, level in levels.items():
+            if level == self.scope_hierarchy_level and sun_id in selected_sids_dict:
+                sun_ids.append(sun_id)
+                sids_list.append(sids.get(sun_id))
 
-        Args:
-            g (nx.DiGraph): The class hierarchy graph.
+        # data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list))
+        df_cla = self._get_classification_data()
+        target_col_name = self.SCOPE_HIERARCHY[self.scope_hierarchy_level]
+        df_cla = df_cla[df_cla[target_col_name].isin(sun_ids)]
+        df_cla = df_cla[["sid", target_col_name]]
 
-        Returns:
-            pd.DataFrame: The raw dataset created from the graph.
-        """
-        print(f"Processing graph")
-
-        data_df = self._get_swiss_to_go_mapping()
-        # add ancestors to go ids
-        data_df["go_ids"] = data_df["go_ids"].apply(
-            lambda go_ids: sorted(
-                set(
-                    itertools.chain.from_iterable(
-                        [
-                            [go_id] + list(g.predecessors(go_id))
-                            for go_id in go_ids
-                            if go_id in g.nodes
-                        ]
-                    )
-                )
-            )
+        assert (
+            len(df_cla) > 1
+        ), "dataframe should have more than one instance for `pd.get_dummies` to work as expected"
+        df_encoded = pd.get_dummies(
+            df_cla, columns=[target_col_name], drop_first=False, sparse=True
         )
-        # Initialize the GO term labels/columns to False
-        selected_classes = self.select_classes(g, data_df=data_df)
-        new_label_columns = pd.DataFrame(
-            False, index=data_df.index, columns=selected_classes
+
+        pdb_chain_seq_mapping = self._parse_pdb_sequence_file()
+
+        sequence_hierarchy_df = pd.DataFrame(
+            columns=list(df_encoded.columns) + ["sids"]
         )
-        data_df = pd.concat([data_df, new_label_columns], axis=1)
 
-        # Set True for the corresponding GO IDs in the DataFrame go labels/columns
-        for index, row in data_df.iterrows():
-            for go_id in row["go_ids"]:
-                if go_id in data_df.columns:
-                    data_df.at[index, go_id] = True
+        for _, row in df_encoded.iterrows():
+            assert sum(row.iloc[1:].tolist()) == 1
+            sid = row["sid"]
+            # SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple)
+            # + domain specifier ('_' if not needed))
+            assert len(sid) == 7, "sid should have 7 characters"
+            pdb_id, chain_id = sid[1:5], sid[5]
+
+            pdb_to_chain_mapping = pdb_chain_seq_mapping.get(pdb_id, None)
+            if not pdb_to_chain_mapping:
+                continue
+
+            if chain_id != "_":
+                chain_sequence = pdb_to_chain_mapping.get(chain_id, None)
+                if chain_sequence:
+                    self._update_or_add_sequence(
+                        chain_sequence, row, sequence_hierarchy_df
+                    )
+            else:
+                # No specific chain in the sid: use every chain of this PDB entry
+                for chain, chain_sequence in pdb_to_chain_mapping.items():
+                    self._update_or_add_sequence(
+                        chain_sequence, row, sequence_hierarchy_df
+                    )
+
+        sequence_hierarchy_df.drop(columns=["sid"], axis=1, inplace=True)
+        sequence_hierarchy_df = sequence_hierarchy_df[
+            ["sids"] + [col for col in sequence_hierarchy_df.columns if col != "sids"]
+        ]
 
         # This filters the DataFrame to include only the rows where at least one value in the row from 5th column
         # onwards is True/non-zero.
-        # Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least
-        # one GO term from the set of the GO terms for the model`
-        data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX:].any(axis=1)]
-        return data_df
+        sequence_hierarchy_df = sequence_hierarchy_df[
+            sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX:].any(axis=1)
+        ]
+        return sequence_hierarchy_df
+
+    def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
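+        # pdb_seqres FASTA record ids look like "<pdb_id>_<chain>"; keys are lower-cased so they
+        # match the PDB id and chain extracted from SCOPe sids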
+        pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {}
+        for record in SeqIO.parse(
+            os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta"
+        ):
+            pdb_id, chain = record.id.split("_")
+            if str(record.seq):
+                pdb_chain_seq_mapping.setdefault(pdb_id.lower(), {})[chain.lower()] = (
+                    str(record.seq)
+                )
+        return pdb_chain_seq_mapping
+
+    @staticmethod
+    def _update_or_add_sequence(sequence, row, sequence_hierarchy_df):
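+        # Rows are keyed by sequence: if several domains share a sequence, their label columns are
+        # OR-ed together and all of their sids are collected in the "sids" list column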
+        # Columns after "sid" hold the one-hot encoded hierarchy labels
+        sliced_data = row.iloc[1:]  # Slice starting from the second column (index 1)
+
+        # Get the name of the label column that is set to True
+        true_column = sliced_data.idxmax() if sliced_data.any() else None
+
+        if sequence in sequence_hierarchy_df.index:
+            # Sequence already present: switch on its label column and record the additional sid
+            if row[true_column]:
+                sequence_hierarchy_df.loc[sequence, true_column] = True
+            sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"])
+        else:
+            # Add a new row with the sequence as the index and the hierarchy data
+            new_row = row.copy()
+            new_row["sids"] = [row["sid"]]
+            sequence_hierarchy_df.loc[sequence] = new_row
 
     # ------------------------------ Phase: Setup data -----------------------------------
     def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
@@ -367,15 +446,33 @@ def raw_file_names_dict(self) -> dict:
 
 class SCOPE(_SCOPeDataExtractor):
     READER = ProteinDataReader
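+    # Minimum out-degree a node must have in the (transitively closed) hierarchy graph to be selected as a class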
+    THRESHOLD = 1
 
     @property
     def _name(self) -> str:
         return "test"
 
-    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
-        pass
+    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict:
+        # Filter nodes and create a dictionary mapping each node to its out-degree
+        sun_ids_dict = {
+            node: g.out_degree(node)  # Store the node together with its out-degree
+            for node in g.nodes
+            if g.out_degree(node) >= self.THRESHOLD
+        }
+
+        # Return a dictionary sorted by node id (sunid)
+        sorted_dict = dict(
+            sorted(sun_ids_dict.items(), key=lambda item: item[0], reverse=False)
+        )
+
+        filename = "classes.txt"
+        with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
+            fout.writelines(str(sun_id) + "\n" for sun_id in sorted_dict.keys())
+
+        return sorted_dict
 
 
 if __name__ == "__main__":
     scope = SCOPE(scope_version=2.08)
-    scope._parse_pdb_sequence_file()
+    g = scope._extract_class_hierarchy("d")
+    scope._graph_to_raw_dataset(g)