Skip to content

Commit 44c6e7e

Browse files
authored
Merge pull request #120 from ChEB-AI/fix/avoid_iterrows
Avoid using iterrows, use vectorization wherever possible
2 parents d52b422 + 13afc28 commit 44c6e7e

File tree

3 files changed

+59
-58
lines changed

3 files changed

+59
-58
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from itertools import cycle, permutations, product
1818
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
1919

20+
import numpy as np
2021
import pandas as pd
2122
import torch
2223
from rdkit import Chem
@@ -332,7 +333,7 @@ def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame:
332333

333334
data = pd.DataFrame(data)
334335
data = data[~data["SMILES"].isnull()]
335-
data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]]
336+
data = data[~data["name"].isin(CHEBI_BLACKLIST)]
336337

337338
return data
338339

@@ -459,18 +460,18 @@ def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, No
459460
"""
460461
with open(input_file_path, "rb") as input_file:
461462
df = pd.read_pickle(input_file)
462-
if self.single_class is not None:
463-
single_cls_index = list(df.columns).index(int(self.single_class))
464-
for row in df.values:
465-
if self.single_class is None:
466-
labels = row[self._LABELS_START_IDX :].astype(bool)
467-
else:
468-
labels = [bool(row[single_cls_index])]
469-
yield dict(
470-
features=row[self._DATA_REPRESENTATION_IDX],
471-
labels=labels,
472-
ident=row[self._ID_IDX],
473-
)
463+
464+
if self.single_class is None:
465+
all_labels = df.iloc[:, self._LABELS_START_IDX :].to_numpy(dtype=bool)
466+
else:
467+
single_cls_index = df.columns.get_loc(int(self.single_class))
468+
all_labels = df.iloc[:, [single_cls_index]].to_numpy(dtype=bool)
469+
470+
features = df.iloc[:, self._DATA_REPRESENTATION_IDX].to_numpy()
471+
idents = df.iloc[:, self._ID_IDX].to_numpy()
472+
473+
for feat, labels, ident in zip(features, all_labels, idents):
474+
yield dict(features=feat, labels=labels, ident=ident)
474475

475476
# ------------------------------ Phase: Dynamic Splits -----------------------------------
476477
def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
@@ -566,43 +567,43 @@ def _setup_pruned_test_set(
566567
Returns:
567568
pd.DataFrame: The pruned test dataset.
568569
"""
569-
# TODO: find a more efficient way to do this
570-
filename_old = "classes.txt"
571-
# filename_new = f"classes_v{self.chebi_version_train}.txt"
572-
# dataset = torch.load(os.path.join(self.processed_dir, "test.pt"))
573-
574-
# Load original classes (from the current ChEBI version - chebi_version)
575-
with open(os.path.join(self.processed_dir_main, filename_old), "r") as file:
576-
orig_classes = file.readlines()
570+
classes_file_name = "classes.txt"
577571

578-
# Load new classes (from the training ChEBI version - chebi_version_train)
572+
# Load original and new classes
573+
with open(os.path.join(self.processed_dir_main, classes_file_name), "r") as f:
574+
orig_classes = f.readlines()
579575
with open(
580576
os.path.join(
581-
self._chebi_version_train_obj.processed_dir_main, filename_old
577+
self._chebi_version_train_obj.processed_dir_main, classes_file_name
582578
),
583579
"r",
584-
) as file:
585-
new_classes = file.readlines()
586-
587-
# Create a mapping which give index of a class from chebi_version, if the corresponding
588-
# class exists in chebi_version_train, Size = Number of classes in chebi_version
589-
mapping = [
590-
None if or_class not in new_classes else new_classes.index(or_class)
591-
for or_class in orig_classes
592-
]
580+
) as f:
581+
new_classes = f.readlines()
582+
583+
# Mapping array (-1 means no match in new classes)
584+
mapping_array = np.array(
585+
[
586+
-1 if oc not in new_classes else new_classes.index(oc)
587+
for oc in orig_classes
588+
],
589+
dtype=int,
590+
)
591+
592+
# Convert labels column to 2D NumPy array
593+
labels_matrix = np.array(df_test_chebi_version["labels"].tolist(), dtype=bool)
594+
595+
# Allocate new labels matrix
596+
num_new_classes = len(new_classes)
597+
new_labels_matrix = np.zeros(
598+
(labels_matrix.shape[0], num_new_classes), dtype=bool
599+
)
593600

594-
# Iterate over each data instance in the test set which is derived from chebi_version
595-
for _, row in df_test_chebi_version.iterrows():
596-
# Size = Number of classes in chebi_version_train
597-
new_labels = [False for _ in new_classes]
598-
for ind, label in enumerate(row["labels"]):
599-
# If the chebi_version class exists in the chebi_version_train and has a True label,
600-
# set the corresponding label in new_labels to True
601-
if mapping[ind] is not None and label:
602-
new_labels[mapping[ind]] = label
603-
# Update the labels from test instance from chebi_version to the new labels, which are compatible to both versions
604-
row["labels"] = new_labels
601+
# Copy only valid columns
602+
valid_mask = mapping_array != -1
603+
new_labels_matrix[:, mapping_array[valid_mask]] = labels_matrix[:, valid_mask]
605604

605+
# Assign back
606+
df_test_chebi_version["labels"] = new_labels_matrix.tolist()
606607
return df_test_chebi_version
607608

608609
# ------------------------------ Phase: Raw Properties -----------------------------------

chebai/preprocessing/datasets/pubchem.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -628,8 +628,8 @@ def download(self):
628628
if not os.path.exists(os.path.join(self.raw_dir, f"{name}.txt")):
629629
open(os.path.join(self.raw_dir, f"{name}.txt"), "x").close()
630630
with open(os.path.join(self.raw_dir, f"{name}.txt"), "w") as f:
631-
for id, row in splits[i].iterrows():
632-
f.writelines(f"{id}\t{row['smiles']}\n")
631+
for id, row in zip(splits[i].index, splits[i].itertuples(index=False)):
632+
f.writelines(f"{id}\t{row.smiles}\n")
633633

634634

635635
class PubChemDissimilarSMILES(PubChemDissimilar):
@@ -809,12 +809,12 @@ def download(self):
809809
csv_path = os.path.join(self.raw_dir, "pubchem_hazardous_compound_list.csv")
810810
compounds = pd.read_csv(csv_path)
811811
smiles_list = []
812-
for id, compound in compounds.iterrows():
812+
for compound in compounds.itertuples(index=False):
813813
if (
814-
not isinstance(compound["cmpdsynonym"], str)
815-
or "CHEBI" not in compound["cmpdsynonym"]
814+
not isinstance(compound.cmpdsynonym, str)
815+
or "CHEBI" not in compound.cmpdsynonym
816816
):
817-
smiles_list.append(f"{compound['cid']}\t{compound['isosmiles']}")
817+
smiles_list.append(f"{compound.cid}\t{compound.isosmiles}")
818818
with open(os.path.join(self.raw_dir, "smiles.txt"), "w") as f:
819819
f.write("\n".join(smiles_list))
820820

chebai/train.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,11 @@ def prepare_data(infile: pickle.Pickler) -> pd.DataFrame:
246246
data_frame[col] = data_frame[col].astype(int)
247247

248248
train_data = []
249-
for index, row in data_frame.iterrows():
249+
for row in data_frame.itertuples(index=False):
250250
train_data.append(
251251
[
252-
data_frame.iloc[index].values[1],
253-
data_frame.iloc[index].values[2:502].tolist(),
252+
row.SMILES,
253+
row.LABELS,
254254
]
255255
)
256256

@@ -309,28 +309,28 @@ def load_data() -> (
309309
train_dataset = []
310310
train_actual_labels = []
311311

312-
for index, row in prepare_data(train_infile).iterrows():
312+
for row in prepare_data(train_infile).itertuples(index=False):
313313
try:
314-
mol = Molecule(row["SMILES"], True)
314+
mol = Molecule(row.SMILES, True)
315315

316316
# DAGs_meta_info = mol.dag_to_node
317317
train_dataset.append(mol)
318-
train_actual_labels.append(torch.tensor(row["LABELS"]).float())
318+
train_actual_labels.append(torch.tensor(row.LABELS).float())
319319
except Exception:
320320
pass
321321

322322
print("prepare validation data!")
323323
validation_dataset = []
324324
validation_actual_labels = []
325325

326-
for index, row in prepare_data(validation_infile).iterrows():
326+
for row in prepare_data(validation_infile).itertuples(index=False):
327327
try:
328-
mol = Molecule(row["SMILES"], True)
328+
mol = Molecule(row.SMILES, True)
329329

330330
# DAGs_meta_info = mol.dag_to_node
331331

332332
validation_dataset.append(mol)
333-
validation_actual_labels.append(torch.tensor(row["LABELS"]).float())
333+
validation_actual_labels.append(torch.tensor(row.LABELS).float())
334334
except Exception:
335335
pass
336336

0 commit comments

Comments
 (0)