Skip to content

Commit c791893

Browse files
committed
scope: avoid data fragmentation and add progress bar
1 parent d3fd0f2 commit c791893

File tree

1 file changed

+20
-2
lines changed
  • chebai/preprocessing/datasets/scope

1 file changed

+20
-2
lines changed

chebai/preprocessing/datasets/scope/scope.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import requests
2424
import torch
2525
from Bio import SeqIO
26+
from tqdm import tqdm
2627

2728
from chebai.preprocessing.datasets.base import _DynamicDataset
2829
from chebai.preprocessing.reader import ProteinDataReader
@@ -365,6 +366,9 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
365366
# Initialize selected target columns
366367
df_encoded = df_cla[["sid", "sunid"]].copy()
367368

369+
# Collect all new columns in a dictionary first (avoids fragmentation)
370+
encoded_df_columns = {}
371+
368372
lvl_to_target_cols_mapping = {}
369373
# Iterate over only the selected sun_ids (nodes) to one-hot encode them
370374
for level, selected_sun_ids in selected_sun_ids_per_lvl.items():
@@ -375,11 +379,17 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
375379
# Create binary encoding for only relevant sun_ids
376380
for sun_id in selected_sun_ids:
377381
col_name = f"{level_column}_{sun_id}"
378-
df_encoded[col_name] = (df_cla[level_column] == sun_id).astype(bool)
382+
encoded_df_columns[col_name] = (
383+
df_cla[level_column] == sun_id
384+
).astype(bool)
385+
379386
lvl_to_target_cols_mapping.setdefault(level_column, []).append(
380387
col_name
381388
)
382389

390+
# Convert the dictionary into a DataFrame and concatenate at once (prevents fragmentation)
391+
df_encoded = pd.concat([df_encoded, pd.DataFrame(encoded_df_columns)], axis=1)
392+
383393
# Filter to select only domains that atleast map to any one selected sunid in any level
384394
df_encoded = df_encoded[df_encoded.iloc[:, 2:].any(axis=1)]
385395

@@ -392,7 +402,15 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
392402
sequence_hierarchy_df = pd.DataFrame(columns=["sids"] + encoded_target_columns)
393403
df_encoded = df_encoded[["sid", "sunid"] + encoded_target_columns]
394404

395-
for _, row in df_encoded.iterrows():
405+
print(
406+
f"{len(encoded_target_columns)} labels has been selected for specified threshold, "
407+
f"Max possible size of dataset is {len(df_encoded)} rows x {len(encoded_target_columns) + 1} columns"
408+
)
409+
print("Constructing data.pkl file .....")
410+
411+
for _, row in tqdm(
412+
df_encoded.iterrows(), total=len(df_encoded), desc="Processing Rows"
413+
):
396414
sid = row["sid"]
397415
# SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple)
398416
# + domain specifier ('_' if not needed))

0 commit comments

Comments
 (0)