Skip to content

Commit dfc4db9

Browse files
committed
add id filter
1 parent 4ab760e commit dfc4db9

File tree

1 file changed

+16
-1
lines changed
  • chebai/preprocessing/datasets

1 file changed

+16
-1
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,13 +710,17 @@ class _DynamicDataset(XYBaseDataModule, ABC):
710710
dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42.
711711
splits_file_path (str, optional): Path to the splits CSV file. Defaults to None.
712712
apply_label_filter (Optional[str]): Path to a classes.txt file - only labels that are in the labels filter
713-
file will be used (in that order). All labels in the label filter have to be present in the dataset.
713+
file will be used (in that order). All labels in the label filter have to be present in the dataset. This filter
714+
is only active when loading splits from a CSV file. Defaults to None.
715+
apply_id_filter (Optional[str]): Path to a data.pt file from a different dataset - only IDs that are in the
716+
id filter file will be used. Defaults to None. This filter is only active when loading splits from a CSV file.
714717
**kwargs: Additional keyword arguments passed to XYBaseDataModule.
715718
716719
Attributes:
717720
dynamic_data_split_seed (int): The seed for random data splitting, default is 42.
718721
splits_file_path (Optional[str]): Path to the CSV file containing split assignments.
719722
apply_label_filter (Optional[str]): Path to a classes.txt file for label filtering.
723+
apply_id_filter (Optional[str]): Path to a data.pt file for ID filtering.
720724
"""
721725

722726
# ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------
@@ -727,6 +731,7 @@ class _DynamicDataset(XYBaseDataModule, ABC):
727731
def __init__(
728732
self,
729733
apply_label_filter: Optional[str] = None,
734+
apply_id_filter: Optional[str] = None,
730735
**kwargs,
731736
):
732737
super(_DynamicDataset, self).__init__(**kwargs)
@@ -741,6 +746,7 @@ def __init__(
741746
kwargs.get("splits_file_path", None)
742747
)
743748
self.apply_label_filter = apply_label_filter
749+
self.apply_id_filter = apply_id_filter
744750

745751
@staticmethod
746752
def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
@@ -1140,6 +1146,15 @@ def _retrieve_splits_from_csv(self) -> None:
11401146
)
11411147
df_data = pd.DataFrame(data)
11421148

1149+
if self.apply_id_filter:
1150+
print(f"Applying ID filter from {self.apply_id_filter}...")
1151+
with open(self.apply_id_filter, "r") as f:
1152+
id_filter = [
1153+
line["ident"]
1154+
for line in torch.load(self.apply_id_filter, weights_only=False)
1155+
]
1156+
df_data = df_data[df_data["ident"].isin(id_filter)]
1157+
11431158
if self.apply_label_filter:
11441159
print(f"Applying label filter from {self.apply_label_filter}...")
11451160
with open(self.apply_label_filter, "r") as f:

0 commit comments

Comments
 (0)