|
24 | 24 | import requests |
25 | 25 | import torch |
26 | 26 | from Bio import SwissProt |
| 27 | +from torch.utils.data import DataLoader |
27 | 28 |
|
28 | 29 | from chebai.preprocessing import reader as dr |
29 | 30 | from chebai.preprocessing.datasets.base import _DynamicDataset |
@@ -74,6 +75,11 @@ def __init__(self, **kwargs): |
74 | 75 | self.go_branch: str = self._get_go_branch(**kwargs) |
75 | 76 | super(_GOUniProtDataExtractor, self).__init__(**kwargs) |
76 | 77 |
|
| 78 | + self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002)) |
| 79 | + assert ( |
| 80 | + self.max_sequence_length >= 1 |
| 81 | + ), "Max sequence length should be greater than or equal to 1." |
| 82 | + |
77 | 83 | @classmethod |
78 | 84 | def _get_go_branch(cls, **kwargs) -> str: |
79 | 85 | """ |
@@ -397,15 +403,14 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: |
397 | 403 | } |
398 | 404 | # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 |
399 | 405 | AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"} |
400 | | - MAX_LENGTH = 1002 |
401 | 406 |
|
402 | 407 | for record in swiss_data: |
403 | 408 | if record.data_class != "Reviewed": |
404 | 409 | # To consider only manually-annotated swiss data |
405 | 410 | continue |
406 | 411 |
|
407 | | - if not record.sequence or record.sequence_length > MAX_LENGTH: |
408 | | - # Consider protein with only sequence representation and a maximum length of 1002 |
| 412 | + if not record.sequence: |
| 413 | + # Consider protein with only sequence representation |
409 | 414 | continue |
410 | 415 |
|
411 | 416 | if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence): |
@@ -524,6 +529,29 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: |
524 | 529 |
|
525 | 530 | return df_train, df_val, df_test |
526 | 531 |
|
| 532 | + # ------------------------------ Phase: DataLoaders ----------------------------------- |
def dataloader(self, kind: str, **kwargs) -> DataLoader:
    """
    Build the DataLoader for one data split and clip every sequence to the configured length.

    The loader itself is created by the superclass implementation. Afterwards, the
    'features' entry of each instance in the loader's dataset is replaced by its
    first `self.max_sequence_length` elements.

    Args:
        kind (str): The kind of data to load (e.g., 'train', 'val', 'test').
        **kwargs: Additional keyword arguments forwarded unchanged to the superclass
            dataloader method.

    Returns:
        DataLoader: The loader produced by the superclass, whose dataset instances
        have had their 'features' truncated.
    """
    loader = super().dataloader(kind, **kwargs)

    # One slice object, equivalent to `[: self.max_sequence_length]`, reused for every instance.
    keep = slice(self.max_sequence_length)

    # NOTE(review): this assumes iterating the dataset yields the stored instance
    # objects themselves (not copies), so the reassignment below is visible to the
    # loader - confirm against the superclass implementation.
    for sample in loader.dataset:
        sample["features"] = sample["features"][keep]

    return loader
| 554 | + |
527 | 555 | # ------------------------------ Phase: Raw Properties ----------------------------------- |
528 | 556 | @property |
529 | 557 | def base_dir(self) -> str: |
@@ -619,7 +647,7 @@ def select_classes( |
619 | 647 | - The `THRESHOLD` attribute, which defines the minimum number of annotations required to select a GO term, should be defined in the subclass. |
620 | 648 | """ |
621 | 649 | # Retrieve the DataFrame containing GO annotations per protein from the keyword arguments |
622 | | - data_df: pd.DataFrame = kwargs.get("data_df", None) |
| 650 | + data_df = kwargs.get("data_df", None) |
623 | 651 | if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty: |
624 | 652 | raise AttributeError( |
625 | 653 | "The 'data_df' argument must be provided and must be a non-empty pandas DataFrame." |
|
0 commit comments