|
17 | 17 |
|
18 | 18 | import fastobo |
19 | 19 | import networkx as nx |
| 20 | +import numpy as np |
20 | 21 | import pandas as pd |
21 | 22 | import requests |
22 | 23 | import torch |
@@ -465,43 +466,43 @@ def _setup_pruned_test_set( |
465 | 466 | Returns: |
466 | 467 | pd.DataFrame: The pruned test dataset. |
467 | 468 | """ |
468 | | - # TODO: find a more efficient way to do this |
469 | | - filename_old = "classes.txt" |
470 | | - # filename_new = f"classes_v{self.chebi_version_train}.txt" |
471 | | - # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) |
| 469 | + classes_file_name = "classes.txt" |
472 | 470 |
|
473 | | - # Load original classes (from the current ChEBI version - chebi_version) |
474 | | - with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: |
475 | | - orig_classes = file.readlines() |
476 | | - |
477 | | - # Load new classes (from the training ChEBI version - chebi_version_train) |
| 471 | + # Load original and new classes |
| 472 | + with open(os.path.join(self.processed_dir_main, classes_file_name), "r") as f: |
| 473 | + orig_classes = f.readlines() |
478 | 474 | with open( |
479 | 475 | os.path.join( |
480 | | - self._chebi_version_train_obj.processed_dir_main, filename_old |
| 476 | + self._chebi_version_train_obj.processed_dir_main, classes_file_name |
481 | 477 | ), |
482 | 478 | "r", |
483 | | - ) as file: |
484 | | - new_classes = file.readlines() |
485 | | - |
486 | | - # Create a mapping which give index of a class from chebi_version, if the corresponding |
487 | | - # class exists in chebi_version_train, Size = Number of classes in chebi_version |
488 | | - mapping = [ |
489 | | - None if or_class not in new_classes else new_classes.index(or_class) |
490 | | - for or_class in orig_classes |
491 | | - ] |
| 479 | + ) as f: |
| 480 | + new_classes = f.readlines() |
| 481 | + |
| 482 | + # Mapping array (-1 means no match in new classes) |
| 483 | + mapping_array = np.array( |
| 484 | + [ |
| 485 | + -1 if oc not in new_classes else new_classes.index(oc) |
| 486 | + for oc in orig_classes |
| 487 | + ], |
| 488 | + dtype=int, |
| 489 | + ) |
| 490 | + |
| 491 | + # Convert labels column to 2D NumPy array |
| 492 | + labels_matrix = np.array(df_test_chebi_version["labels"].tolist(), dtype=bool) |
| 493 | + |
| 494 | + # Allocate new labels matrix |
| 495 | + num_new_classes = len(new_classes) |
| 496 | + new_labels_matrix = np.zeros( |
| 497 | + (labels_matrix.shape[0], num_new_classes), dtype=bool |
| 498 | + ) |
492 | 499 |
|
493 | | - # Iterate over each data instance in the test set which is derived from chebi_version |
494 | | - for row in df_test_chebi_version.itertuples(index=False): |
495 | | - # Size = Number of classes in chebi_version_train |
496 | | - new_labels = [False for _ in new_classes] |
497 | | - for ind, label in enumerate(row.labels): |
498 | | - # If the chebi_version class exists in the chebi_version_train and has a True label, |
499 | | - # set the corresponding label in new_labels to True |
500 | | - if mapping[ind] is not None and label: |
501 | | - new_labels[mapping[ind]] = label |
502 | | - # Update the labels from test instance from chebi_version to the new labels, which are compatible to both versions |
503 | | - row["labels"] = new_labels |
| 500 | + # Copy only valid columns |
| 501 | + valid_mask = mapping_array != -1 |
| 502 | + new_labels_matrix[:, mapping_array[valid_mask]] = labels_matrix[:, valid_mask] |
504 | 503 |
|
| 504 | + # Assign back |
| 505 | + df_test_chebi_version["labels"] = new_labels_matrix.tolist() |
505 | 506 | return df_test_chebi_version |
506 | 507 |
|
507 | 508 | # ------------------------------ Phase: Raw Properties ----------------------------------- |
|
0 commit comments