Skip to content

Commit a15d492

Browse files
committed
deepgo1: create non-exclusive val set as a placeholder
1 parent 99b5af1 commit a15d492

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ def _load_data(self) -> None:
8989
# self._validation_df = pd.DataFrame(
9090
# pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl"))
9191
# )
92+
93+
# DeepGO1 data does not include a separate validation split, but our data structure requires one.
94+
# To accommodate this, we will create a placeholder validation split by duplicating a small subset (the last five rows) of the
95+
# training data. However, to ensure a fair comparison with DeepGO1, we will retain the full training set
96+
# without creating an exclusive validation split from it.
97+
# Therefore, any metrics calculated on this placeholder validation set should be disregarded, as they do not
98+
# reflect true validation performance.
99+
self._validation_df = self._train_df[len(self._train_df) - 5 :]
92100
self._terms_df = pd.DataFrame(
93101
pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl"))
94102
)
@@ -106,9 +114,9 @@ def _record_splits(self) -> pd.DataFrame:
106114
print("Recording splits...")
107115
split_assignment_list: List[pd.DataFrame] = [
108116
pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}),
109-
# pd.DataFrame(
110-
# {"id": self._validation_df["proteins"], "split": "validation"}
111-
# ),
117+
pd.DataFrame(
118+
{"id": self._validation_df["proteins"], "split": "validation"}
119+
),
112120
pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}),
113121
]
114122

@@ -125,7 +133,7 @@ def migrate(self) -> None:
125133
df is not None
126134
for df in [
127135
self._train_df,
128-
# self._validation_df,
136+
self._validation_df,
129137
self._test_df,
130138
self._terms_df,
131139
]
@@ -166,7 +174,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
166174
new_df = pd.concat(
167175
[
168176
self._train_df[required_columns],
169-
# self._validation_df[required_columns],
177+
self._validation_df[required_columns],
170178
self._test_df[required_columns],
171179
],
172180
ignore_index=True,

0 commit comments

Comments
 (0)