@@ -33,18 +33,15 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC):
3333 **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule.
3434 """
3535
36- _GO_DATA_INIT = "GO"
37- _SWISS_DATA_INIT = "SWISS"
38-
3936 # -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset`
4037 # "swiss_id" at row index 0
4138 # "accession" at row index 1
4239 # "go_ids" at row index 2
4340 # "sequence" at row index 3
4441 # labels starting from row index 4
4542 _ID_IDX : int = 0
46- _DATA_REPRESENTATION_IDX : int = 3 # here `sequence` column
47- _LABELS_START_IDX : int = 4
43+ _DATA_REPRESENTATION_IDX : int = 2 # here `sequence` column
44+ _LABELS_START_IDX : int = 3
4845
4946 _SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt"
5047 _PDB_SEQUENCE_DATA_URL = (
@@ -65,14 +62,8 @@ def __init__(
6562 self ,
6663 scope_version : float ,
6764 scope_version_train : Optional [float ] = None ,
68- scope_hierarchy_level : str = "cl" ,
6965 ** kwargs ,
7066 ):
71-
72- assert (
73- scope_hierarchy_level in self .SCOPE_HIERARCHY .keys ()
74- ), f"level can contain only one of the following values { self .SCOPE_HIERARCHY .keys ()} "
75- self .scope_hierarchy_level = scope_hierarchy_level
7667 self .scope_version : float = scope_version
7768 self .scope_version_train : float = scope_version_train
7869
@@ -83,7 +74,6 @@ def __init__(
8374 # This is to get the data from respective directory related to "scope_version_train"
8475 _init_kwargs = kwargs
8576 _init_kwargs ["scope_version" ] = self .scope_version_train
86- _init_kwargs ["scope_hierarchy_level" ] = self .scope_hierarchy_level
8777 self ._scope_version_train_obj = self .__class__ (
8878 ** _init_kwargs ,
8979 )
@@ -275,37 +265,50 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
275265 sids = nx .get_node_attributes (graph , "sid" )
276266 levels = nx .get_node_attributes (graph , "level" )
277267
278- sun_ids = []
268+ sun_ids = {}
279269 sids_list = []
280270
281271 selected_sids_dict = self .select_classes (graph )
282272
283273 for sun_id , level in levels .items ():
284- if level == self . scope_hierarchy_level and sun_id in selected_sids_dict :
285- sun_ids .append (sun_id )
274+ if sun_id in selected_sids_dict :
275+ sun_ids .setdefault ( level , []). append (sun_id )
286276 sids_list .append (sids .get (sun_id ))
287277
278+ # Remove root node, as it will True for all instances
279+ sun_ids .pop ("root" , None )
280+
288281 # data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list))
289282 df_cla = self ._get_classification_data ()
290- target_col_name = self . SCOPE_HIERARCHY [ self . scope_hierarchy_level ]
291- df_cla = df_cla [ df_cla [ target_col_name ]. isin ( sun_ids )]
292- df_cla = df_cla [[ "sid" , target_col_name ] ]
283+
284+ for level , selected_sun_ids in sun_ids . items ():
285+ df_cla = df_cla [df_cla [ self . SCOPE_HIERARCHY [ level ]]. isin ( selected_sun_ids ) ]
293286
294287 assert (
295288 len (df_cla ) > 1
296289 ), "dataframe should have more than one instance for `pd.get_dummies` to work as expected"
297290 df_encoded = pd .get_dummies (
298- df_cla , columns = [target_col_name ], drop_first = False , sparse = True
291+ df_cla ,
292+ columns = list (self .SCOPE_HIERARCHY .values ()),
293+ drop_first = False ,
294+ sparse = True ,
299295 )
300296
301297 pdb_chain_seq_mapping = self ._parse_pdb_sequence_file ()
302298
303- sequence_hierarchy_df = pd .DataFrame (
304- columns = list (df_encoded .columns ) + ["sids" ]
305- )
299+ encoded_target_cols = {}
300+ for col in self .SCOPE_HIERARCHY .values ():
301+ encoded_target_cols [col ] = [
302+ t_col for t_col in df_encoded .columns if t_col .startswith (col )
303+ ]
304+
305+ encoded_target_columns = []
306+ for level in self .SCOPE_HIERARCHY .values ():
307+ encoded_target_columns .extend (encoded_target_cols [level ])
308+
309+ sequence_hierarchy_df = pd .DataFrame (columns = ["sids" ] + encoded_target_columns )
306310
307311 for _ , row in df_encoded .iterrows ():
308- assert sum (row .iloc [1 :].tolist ()) == 1
309312 sid = row ["sid" ]
310313 # SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple)
311314 # + domain specifier ('_' if not needed))
@@ -320,19 +323,22 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
320323 chain_sequence = pdb_to_chain_mapping .get (chain_id , None )
321324 if chain_sequence :
322325 self ._update_or_add_sequence (
323- chain_sequence , row , sequence_hierarchy_df
326+ chain_sequence , row , sequence_hierarchy_df , encoded_target_cols
324327 )
325328
326329 else :
327330 # Add nodes and edges for chains in the mapping
328331 for chain , chain_sequence in pdb_to_chain_mapping .items ():
329332 self ._update_or_add_sequence (
330- chain_sequence , row , sequence_hierarchy_df
333+ chain_sequence , row , sequence_hierarchy_df , encoded_target_cols
331334 )
332335
333336 sequence_hierarchy_df .drop (columns = ["sid" ], axis = 1 , inplace = True )
337+ sequence_hierarchy_df .reset_index (inplace = True )
338+ sequence_hierarchy_df ["id" ] = range (1 , len (sequence_hierarchy_df ) + 1 )
339+
334340 sequence_hierarchy_df = sequence_hierarchy_df [
335- ["sids" ] + [ col for col in sequence_hierarchy_df . columns if col != "sids" ]
341+ ["id" , "sids" , "sequence" ] + encoded_target_columns
336342 ]
337343
338344 # This filters the DataFrame to include only the rows where at least one value in the row from 5th column
@@ -355,19 +361,24 @@ def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
355361 return pdb_chain_seq_mapping
356362
357363 @staticmethod
358- def _update_or_add_sequence (sequence , row , sequence_hierarchy_df ):
359- # Check if sequence already exists as an index
360- # Slice the series starting from column 2
361- sliced_data = row .iloc [1 :] # Slice starting from the second column (index 1)
362-
363- # Get the column name with the True value
364- true_column = sliced_data .idxmax () if sliced_data .any () else None
365-
364+ def _update_or_add_sequence (
365+ sequence , row , sequence_hierarchy_df , encoded_col_names
366+ ):
366367 if sequence in sequence_hierarchy_df .index :
367368 # Update encoded columns only if they are True
368- if row [true_column ] is True :
369+ for col in encoded_col_names :
370+ assert (
371+ sum (row [encoded_col_names [col ]].tolist ()) == 1
372+ ), "A instance can belong to only one hierarchy level"
373+ sliced_data = row [
374+ encoded_col_names [col ]
375+ ] # Slice starting from the second column (index 1)
376+ # Get the column name with the True value
377+ true_column = sliced_data .idxmax () if sliced_data .any () else None
369378 sequence_hierarchy_df .loc [sequence , true_column ] = True
370- sequence_hierarchy_df .loc [sequence , "sids" ].append (row ["sid" ])
379+
380+ sequence_hierarchy_df .loc [sequence , "sids" ].append (row ["sid" ])
381+
371382 else :
372383 # Add new row with sequence as the index and hierarchy data
373384 new_row = row
@@ -446,7 +457,7 @@ def raw_file_names_dict(self) -> dict:
446457
447458class SCOPE (_SCOPeDataExtractor ):
448459 READER = ProteinDataReader
449- THRESHOLD = 1
460+ THRESHOLD = 10000
450461
451462 @property
452463 def _name (self ) -> str :
0 commit comments