Skip to content

Commit add85e3

Browse files
committed
deepgo: raise error if no classes are selected
1 parent 19ab4a7 commit add85e3

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,10 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
394394
)
395395
# Initialize the GO term labels/columns to False
396396
selected_classes = self.select_classes(g, data_df=data_df)
397+
if not selected_classes:
398+
raise ValueError(
399+
f"No classes selected for given threshold {self.THRESHOLD}"
400+
)
397401
new_label_columns = pd.DataFrame(
398402
False, index=data_df.index, columns=selected_classes
399403
)

chebai_proteins/preprocessing/datasets/scope/scope.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -728,8 +728,6 @@ def _setup_pruned_test_set(
728728
"""
729729
# TODO: find a more efficient way to do this
730730
filename_old = "classes.txt"
731-
# filename_new = f"classes_v{self.scope_version_train}.txt"
732-
# dataset = torch.load(os.path.join(self.processed_dir, "test.pt"))
733731

734732
# Load original classes (from the current SCOPe version - scope_version)
735733
with open(os.path.join(self.processed_dir_main, filename_old), "r") as file:
@@ -760,7 +758,7 @@ def _setup_pruned_test_set(
760758
# set the corresponding label in new_labels to True
761759
if mapping[ind] is not None and label:
762760
new_labels[mapping[ind]] = label
763-
# Update the labels from test instance from scope_version to the new labels, which are compatible to both versions
761+
# Update the labels from test instance of scope_version to new labels, which are compatible to both versions
764762
row["labels"] = new_labels
765763

766764
return df_test_scope_version

chebai_proteins/preprocessing/reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ class ESM2EmbeddingReader(DataReader):
166166

167167
def __init__(
168168
self,
169+
# --------- Default Parameters as per DeepGO2 ------------
169170
save_model_dir: str = os.path.join("data", "esm2_reader"),
170171
model_name: str = "esm2_t36_3B_UR50D",
171172
device: Optional[torch.device] = None,

0 commit comments

Comments
 (0)