Skip to content

Commit 30ca5f6

Browse files
committed
optimize _setup_pruned_test_set logic
1 parent 567c68d commit 30ca5f6

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import fastobo
1919
import networkx as nx
20+
import numpy as np
2021
import pandas as pd
2122
import requests
2223
import torch
@@ -465,43 +466,43 @@ def _setup_pruned_test_set(
465466
Returns:
466467
pd.DataFrame: The pruned test dataset.
467468
"""
468-
# TODO: find a more efficient way to do this
469-
filename_old = "classes.txt"
470-
# filename_new = f"classes_v{self.chebi_version_train}.txt"
471-
# dataset = torch.load(os.path.join(self.processed_dir, "test.pt"))
469+
classes_file_name = "classes.txt"
472470

473-
# Load original classes (from the current ChEBI version - chebi_version)
474-
with open(os.path.join(self.processed_dir_main, filename_old), "r") as file:
475-
orig_classes = file.readlines()
476-
477-
# Load new classes (from the training ChEBI version - chebi_version_train)
471+
# Load original and new classes
472+
with open(os.path.join(self.processed_dir_main, classes_file_name), "r") as f:
473+
orig_classes = f.readlines()
478474
with open(
479475
os.path.join(
480-
self._chebi_version_train_obj.processed_dir_main, filename_old
476+
self._chebi_version_train_obj.processed_dir_main, classes_file_name
481477
),
482478
"r",
483-
) as file:
484-
new_classes = file.readlines()
485-
486-
# Create a mapping which give index of a class from chebi_version, if the corresponding
487-
# class exists in chebi_version_train, Size = Number of classes in chebi_version
488-
mapping = [
489-
None if or_class not in new_classes else new_classes.index(or_class)
490-
for or_class in orig_classes
491-
]
479+
) as f:
480+
new_classes = f.readlines()
481+
482+
# Mapping array (-1 means no match in new classes)
483+
mapping_array = np.array(
484+
[
485+
-1 if oc not in new_classes else new_classes.index(oc)
486+
for oc in orig_classes
487+
],
488+
dtype=int,
489+
)
490+
491+
# Convert labels column to 2D NumPy array
492+
labels_matrix = np.array(df_test_chebi_version["labels"].tolist(), dtype=bool)
493+
494+
# Allocate new labels matrix
495+
num_new_classes = len(new_classes)
496+
new_labels_matrix = np.zeros(
497+
(labels_matrix.shape[0], num_new_classes), dtype=bool
498+
)
492499

493-
# Iterate over each data instance in the test set which is derived from chebi_version
494-
for row in df_test_chebi_version.itertuples(index=False):
495-
# Size = Number of classes in chebi_version_train
496-
new_labels = [False for _ in new_classes]
497-
for ind, label in enumerate(row.labels):
498-
# If the chebi_version class exists in the chebi_version_train and has a True label,
499-
# set the corresponding label in new_labels to True
500-
if mapping[ind] is not None and label:
501-
new_labels[mapping[ind]] = label
502-
# Update the labels from test instance from chebi_version to the new labels, which are compatible to both versions
503-
row["labels"] = new_labels
500+
# Copy only valid columns
501+
valid_mask = mapping_array != -1
502+
new_labels_matrix[:, mapping_array[valid_mask]] = labels_matrix[:, valid_mask]
504503

504+
# Assign back
505+
df_test_chebi_version["labels"] = new_labels_matrix.tolist()
505506
return df_test_chebi_version
506507

507508
# ------------------------------ Phase: Raw Properties -----------------------------------

0 commit comments

Comments
 (0)