Skip to content

Commit 431da47

Browse files
committed
scope: include all levels
1 parent f4d1d74 commit 431da47

File tree

1 file changed

+48
-37
lines changed
  • chebai/preprocessing/datasets/scope

1 file changed

+48
-37
lines changed

chebai/preprocessing/datasets/scope/scope.py

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,15 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC):
3333
**kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule.
3434
"""
3535

36-
_GO_DATA_INIT = "GO"
37-
_SWISS_DATA_INIT = "SWISS"
38-
3936
# -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset`)
4037
# "id" at row index 0
4138
# "sids" at row index 1
4239
# "sequence" at row index 2
4340
# labels starting from row index 3
4441
# (one-hot encoded columns, one per selected class of each SCOPe hierarchy level)
4542
_ID_IDX: int = 0
46-
_DATA_REPRESENTATION_IDX: int = 3 # here `sequence` column
47-
_LABELS_START_IDX: int = 4
43+
_DATA_REPRESENTATION_IDX: int = 2 # here `sequence` column
44+
_LABELS_START_IDX: int = 3
4845

4946
_SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt"
5047
_PDB_SEQUENCE_DATA_URL = (
@@ -65,14 +62,8 @@ def __init__(
6562
self,
6663
scope_version: float,
6764
scope_version_train: Optional[float] = None,
68-
scope_hierarchy_level: str = "cl",
6965
**kwargs,
7066
):
71-
72-
assert (
73-
scope_hierarchy_level in self.SCOPE_HIERARCHY.keys()
74-
), f"level can contain only one of the following values {self.SCOPE_HIERARCHY.keys()}"
75-
self.scope_hierarchy_level = scope_hierarchy_level
7667
self.scope_version: float = scope_version
7768
self.scope_version_train: float = scope_version_train
7869

@@ -83,7 +74,6 @@ def __init__(
8374
# This is to get the data from respective directory related to "scope_version_train"
8475
_init_kwargs = kwargs
8576
_init_kwargs["scope_version"] = self.scope_version_train
86-
_init_kwargs["scope_hierarchy_level"] = self.scope_hierarchy_level
8777
self._scope_version_train_obj = self.__class__(
8878
**_init_kwargs,
8979
)
@@ -275,37 +265,50 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
275265
sids = nx.get_node_attributes(graph, "sid")
276266
levels = nx.get_node_attributes(graph, "level")
277267

278-
sun_ids = []
268+
sun_ids = {}
279269
sids_list = []
280270

281271
selected_sids_dict = self.select_classes(graph)
282272

283273
for sun_id, level in levels.items():
284-
if level == self.scope_hierarchy_level and sun_id in selected_sids_dict:
285-
sun_ids.append(sun_id)
274+
if sun_id in selected_sids_dict:
275+
sun_ids.setdefault(level, []).append(sun_id)
286276
sids_list.append(sids.get(sun_id))
287277

278+
# Remove the root node, as it will be True for all instances
279+
sun_ids.pop("root", None)
280+
288281
# data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list))
289282
df_cla = self._get_classification_data()
290-
target_col_name = self.SCOPE_HIERARCHY[self.scope_hierarchy_level]
291-
df_cla = df_cla[df_cla[target_col_name].isin(sun_ids)]
292-
df_cla = df_cla[["sid", target_col_name]]
283+
284+
for level, selected_sun_ids in sun_ids.items():
285+
df_cla = df_cla[df_cla[self.SCOPE_HIERARCHY[level]].isin(selected_sun_ids)]
293286

294287
assert (
295288
len(df_cla) > 1
296289
), "dataframe should have more than one instance for `pd.get_dummies` to work as expected"
297290
df_encoded = pd.get_dummies(
298-
df_cla, columns=[target_col_name], drop_first=False, sparse=True
291+
df_cla,
292+
columns=list(self.SCOPE_HIERARCHY.values()),
293+
drop_first=False,
294+
sparse=True,
299295
)
300296

301297
pdb_chain_seq_mapping = self._parse_pdb_sequence_file()
302298

303-
sequence_hierarchy_df = pd.DataFrame(
304-
columns=list(df_encoded.columns) + ["sids"]
305-
)
299+
encoded_target_cols = {}
300+
for col in self.SCOPE_HIERARCHY.values():
301+
encoded_target_cols[col] = [
302+
t_col for t_col in df_encoded.columns if t_col.startswith(col)
303+
]
304+
305+
encoded_target_columns = []
306+
for level in self.SCOPE_HIERARCHY.values():
307+
encoded_target_columns.extend(encoded_target_cols[level])
308+
309+
sequence_hierarchy_df = pd.DataFrame(columns=["sids"] + encoded_target_columns)
306310

307311
for _, row in df_encoded.iterrows():
308-
assert sum(row.iloc[1:].tolist()) == 1
309312
sid = row["sid"]
310313
# SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple)
311314
# + domain specifier ('_' if not needed))
@@ -320,19 +323,22 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
320323
chain_sequence = pdb_to_chain_mapping.get(chain_id, None)
321324
if chain_sequence:
322325
self._update_or_add_sequence(
323-
chain_sequence, row, sequence_hierarchy_df
326+
chain_sequence, row, sequence_hierarchy_df, encoded_target_cols
324327
)
325328

326329
else:
327330
# Add nodes and edges for chains in the mapping
328331
for chain, chain_sequence in pdb_to_chain_mapping.items():
329332
self._update_or_add_sequence(
330-
chain_sequence, row, sequence_hierarchy_df
333+
chain_sequence, row, sequence_hierarchy_df, encoded_target_cols
331334
)
332335

333336
sequence_hierarchy_df.drop(columns=["sid"], axis=1, inplace=True)
337+
sequence_hierarchy_df.reset_index(inplace=True)
338+
sequence_hierarchy_df["id"] = range(1, len(sequence_hierarchy_df) + 1)
339+
334340
sequence_hierarchy_df = sequence_hierarchy_df[
335-
["sids"] + [col for col in sequence_hierarchy_df.columns if col != "sids"]
341+
["id", "sids", "sequence"] + encoded_target_columns
336342
]
337343

338344
# This filters the DataFrame to include only the rows where at least one value in the row from 5th column
@@ -355,19 +361,24 @@ def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
355361
return pdb_chain_seq_mapping
356362

357363
@staticmethod
358-
def _update_or_add_sequence(sequence, row, sequence_hierarchy_df):
359-
# Check if sequence already exists as an index
360-
# Slice the series starting from column 2
361-
sliced_data = row.iloc[1:] # Slice starting from the second column (index 1)
362-
363-
# Get the column name with the True value
364-
true_column = sliced_data.idxmax() if sliced_data.any() else None
365-
364+
def _update_or_add_sequence(
365+
sequence, row, sequence_hierarchy_df, encoded_col_names
366+
):
366367
if sequence in sequence_hierarchy_df.index:
367368
# Update encoded columns only if they are True
368-
if row[true_column] is True:
369+
for col in encoded_col_names:
370+
assert (
371+
sum(row[encoded_col_names[col]].tolist()) == 1
372+
), "A instance can belong to only one hierarchy level"
373+
sliced_data = row[
374+
encoded_col_names[col]
375+
] # Select only the one-hot encoded columns belonging to this hierarchy level
376+
# Get the column name with the True value
377+
true_column = sliced_data.idxmax() if sliced_data.any() else None
369378
sequence_hierarchy_df.loc[sequence, true_column] = True
370-
sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"])
379+
380+
sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"])
381+
371382
else:
372383
# Add new row with sequence as the index and hierarchy data
373384
new_row = row
@@ -446,7 +457,7 @@ def raw_file_names_dict(self) -> dict:
446457

447458
class SCOPE(_SCOPeDataExtractor):
448459
READER = ProteinDataReader
449-
THRESHOLD = 1
460+
THRESHOLD = 10000
450461

451462
@property
452463
def _name(self) -> str:

0 commit comments

Comments
 (0)