|
17 | 17 | from itertools import cycle, permutations, product |
18 | 18 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union |
19 | 19 |
|
| 20 | +import numpy as np |
20 | 21 | import pandas as pd |
21 | 22 | import torch |
22 | 23 | from rdkit import Chem |
@@ -332,7 +333,7 @@ def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame: |
332 | 333 |
|
333 | 334 | data = pd.DataFrame(data) |
334 | 335 | data = data[~data["SMILES"].isnull()] |
335 | | - data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]] |
| 336 | + data = data[~data["name"].isin(CHEBI_BLACKLIST)] |
336 | 337 |
|
337 | 338 | return data |
338 | 339 |
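For reference, a minimal standalone sketch of the vectorised blacklist filter introduced in this hunk. The DataFrame contents and the `CHEBI_BLACKLIST` values are made up for illustration; only the two filtering lines mirror the change above.

```python
import pandas as pd

# Hypothetical stand-in for the real CHEBI_BLACKLIST constant.
CHEBI_BLACKLIST = {"CHEBI:999"}

data = pd.DataFrame(
    {
        "name": ["CHEBI:001", "CHEBI:999", "CHEBI:042"],
        "SMILES": ["C", "CCO", None],
    }
)

# Drop rows without a SMILES string, then drop blacklisted names in one
# vectorised pass instead of iterating over rows.
data = data[~data["SMILES"].isnull()]
data = data[~data["name"].isin(CHEBI_BLACKLIST)]

print(data["name"].tolist())  # ['CHEBI:001']
```

Compared with the old `iterrows()` comprehension, `Series.isin` performs the membership test in a single vectorised pass over the column.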
|
@@ -459,18 +460,18 @@ def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, None]: |
459 | 460 | """ |
460 | 461 | with open(input_file_path, "rb") as input_file: |
461 | 462 | df = pd.read_pickle(input_file) |
462 | | - if self.single_class is not None: |
463 | | - single_cls_index = list(df.columns).index(int(self.single_class)) |
464 | | - for row in df.values: |
465 | | - if self.single_class is None: |
466 | | - labels = row[self._LABELS_START_IDX :].astype(bool) |
467 | | - else: |
468 | | - labels = [bool(row[single_cls_index])] |
469 | | - yield dict( |
470 | | - features=row[self._DATA_REPRESENTATION_IDX], |
471 | | - labels=labels, |
472 | | - ident=row[self._ID_IDX], |
473 | | - ) |
| 463 | + |
| 464 | + if self.single_class is None: |
| 465 | + all_labels = df.iloc[:, self._LABELS_START_IDX :].to_numpy(dtype=bool) |
| 466 | + else: |
| 467 | + single_cls_index = df.columns.get_loc(int(self.single_class)) |
| 468 | + all_labels = df.iloc[:, [single_cls_index]].to_numpy(dtype=bool) |
| 469 | + |
| 470 | + features = df.iloc[:, self._DATA_REPRESENTATION_IDX].to_numpy() |
| 471 | + idents = df.iloc[:, self._ID_IDX].to_numpy() |
| 472 | + |
| 473 | + for feat, labels, ident in zip(features, all_labels, idents): |
| 474 | + yield dict(features=feat, labels=labels, ident=ident) |
474 | 475 |
|
475 | 476 | # ------------------------------ Phase: Dynamic Splits ----------------------------------- |
476 | 477 | def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: |
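A rough, self-contained sketch of the vectorised `_load_dict` generator from this hunk. The column layout (ident, features, then one label column per class) and the three index constants are assumptions made for the example; in the dataset class they come from `_ID_IDX`, `_DATA_REPRESENTATION_IDX` and `_LABELS_START_IDX`.

```python
import numpy as np
import pandas as pd

# Assumed column layout for the example: ident, features, then label columns.
_ID_IDX = 0
_DATA_REPRESENTATION_IDX = 1
_LABELS_START_IDX = 2

df = pd.DataFrame(
    {
        "ident": [101, 102],
        "features": [[1, 2, 3], [4, 5]],
        1000: [1, 0],  # label columns keyed by class id
        2000: [0, 1],
    }
)


def load_dicts(df, single_class=None):
    # Cast the whole label block to a 2D boolean array once,
    # instead of converting each row inside the loop.
    if single_class is None:
        all_labels = df.iloc[:, _LABELS_START_IDX:].to_numpy(dtype=bool)
    else:
        col = df.columns.get_loc(int(single_class))
        all_labels = df.iloc[:, [col]].to_numpy(dtype=bool)

    features = df.iloc[:, _DATA_REPRESENTATION_IDX].to_numpy()
    idents = df.iloc[:, _ID_IDX].to_numpy()

    for feat, labels, ident in zip(features, all_labels, idents):
        yield dict(features=feat, labels=labels, ident=ident)


for entry in load_dicts(df):
    print(entry["ident"], entry["labels"])
# 101 [ True False]
# 102 [False  True]
```

The key change is that the label block is cast to a boolean matrix once, up front, instead of calling `astype(bool)` on every row inside the loop.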
@@ -566,43 +567,43 @@ def _setup_pruned_test_set( |
566 | 567 | Returns: |
567 | 568 | pd.DataFrame: The pruned test dataset. |
568 | 569 | """ |
569 | | - # TODO: find a more efficient way to do this |
570 | | - filename_old = "classes.txt" |
571 | | - # filename_new = f"classes_v{self.chebi_version_train}.txt" |
572 | | - # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) |
573 | | - |
574 | | - # Load original classes (from the current ChEBI version - chebi_version) |
575 | | - with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: |
576 | | - orig_classes = file.readlines() |
| 570 | + classes_file_name = "classes.txt" |
577 | 571 |
|
578 | | - # Load new classes (from the training ChEBI version - chebi_version_train) |
| 572 | + # Load original and new classes |
| 573 | + with open(os.path.join(self.processed_dir_main, classes_file_name), "r") as f: |
| 574 | + orig_classes = f.readlines() |
579 | 575 | with open( |
580 | 576 | os.path.join( |
581 | | - self._chebi_version_train_obj.processed_dir_main, filename_old |
| 577 | + self._chebi_version_train_obj.processed_dir_main, classes_file_name |
582 | 578 | ), |
583 | 579 | "r", |
584 | | - ) as file: |
585 | | - new_classes = file.readlines() |
586 | | - |
587 | | - # Create a mapping which give index of a class from chebi_version, if the corresponding |
588 | | - # class exists in chebi_version_train, Size = Number of classes in chebi_version |
589 | | - mapping = [ |
590 | | - None if or_class not in new_classes else new_classes.index(or_class) |
591 | | - for or_class in orig_classes |
592 | | - ] |
| 580 | + ) as f: |
| 581 | + new_classes = f.readlines() |
| 582 | + |
| 583 | + # Mapping array (-1 means no match in new classes) |
| 584 | + mapping_array = np.array( |
| 585 | + [ |
| 586 | + -1 if oc not in new_classes else new_classes.index(oc) |
| 587 | + for oc in orig_classes |
| 588 | + ], |
| 589 | + dtype=int, |
| 590 | + ) |
| 591 | + |
| 592 | + # Convert labels column to 2D NumPy array |
| 593 | + labels_matrix = np.array(df_test_chebi_version["labels"].tolist(), dtype=bool) |
| 594 | + |
| 595 | + # Allocate new labels matrix |
| 596 | + num_new_classes = len(new_classes) |
| 597 | + new_labels_matrix = np.zeros( |
| 598 | + (labels_matrix.shape[0], num_new_classes), dtype=bool |
| 599 | + ) |
593 | 600 |
|
594 | | - # Iterate over each data instance in the test set which is derived from chebi_version |
595 | | - for _, row in df_test_chebi_version.iterrows(): |
596 | | - # Size = Number of classes in chebi_version_train |
597 | | - new_labels = [False for _ in new_classes] |
598 | | - for ind, label in enumerate(row["labels"]): |
599 | | - # If the chebi_version class exists in the chebi_version_train and has a True label, |
600 | | - # set the corresponding label in new_labels to True |
601 | | - if mapping[ind] is not None and label: |
602 | | - new_labels[mapping[ind]] = label |
603 | | - # Update the labels from test instance from chebi_version to the new labels, which are compatible to both versions |
604 | | - row["labels"] = new_labels |
| 601 | + # Copy only valid columns |
| 602 | + valid_mask = mapping_array != -1 |
| 603 | + new_labels_matrix[:, mapping_array[valid_mask]] = labels_matrix[:, valid_mask] |
605 | 604 |
|
| 605 | + # Assign the remapped labels back to the test set |
| 606 | + df_test_chebi_version["labels"] = new_labels_matrix.tolist() |
606 | 607 | return df_test_chebi_version |
607 | 608 |
|
608 | 609 | # ------------------------------ Phase: Raw Properties ----------------------------------- |
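A small standalone sketch of the vectorised label remapping from this hunk, with toy class vocabularies. It builds the mapping with a dict lookup rather than `new_classes.index(...)`, which avoids a quadratic scan but is otherwise the same idea; all names and values here are illustrative only.

```python
import numpy as np
import pandas as pd

# Toy class vocabularies standing in for the two ChEBI versions.
orig_classes = ["A", "B", "C", "D"]  # classes of chebi_version
new_classes = ["B", "D", "E"]        # classes of chebi_version_train

# Map each original class to its column in the new vocabulary (-1 = no match).
new_index = {cls: i for i, cls in enumerate(new_classes)}
mapping_array = np.array([new_index.get(cls, -1) for cls in orig_classes], dtype=int)

df_test = pd.DataFrame(
    {"labels": [[True, False, True, True], [False, True, False, False]]}
)
labels_matrix = np.array(df_test["labels"].tolist(), dtype=bool)

# Scatter the columns shared by both vocabularies in one indexed assignment.
new_labels_matrix = np.zeros((labels_matrix.shape[0], len(new_classes)), dtype=bool)
valid_mask = mapping_array != -1
new_labels_matrix[:, mapping_array[valid_mask]] = labels_matrix[:, valid_mask]

df_test["labels"] = new_labels_matrix.tolist()
print(df_test["labels"].tolist())
# [[False, True, False], [True, False, False]]
```

The per-row Python loop of the old version is replaced by a single fancy-indexed assignment over the columns present in both ChEBI versions.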
|