@@ -710,13 +710,17 @@ class _DynamicDataset(XYBaseDataModule, ABC):
710710 dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42.
711711 splits_file_path (str, optional): Path to the splits CSV file. Defaults to None.
712712 apply_label_filter (Optional[str]): Path to a classes.txt file - only labels that are in the labels filter
713- file will be used (in that order). All labels in the label filter have to be present in the dataset.
713+ file will be used (in that order). All labels in the label filter have to be present in the dataset. This filter
714+ is only active when loading splits from a CSV file. Defaults to None.
715+ apply_id_filter (Optional[str]): Path to a data.pt file from a different dataset - only IDs that are in the
716+ id filter file will be used. Defaults to None. This filter is only active when loading splits from a CSV file.
714717 **kwargs: Additional keyword arguments passed to XYBaseDataModule.
715718
716719 Attributes:
717720 dynamic_data_split_seed (int): The seed for random data splitting, default is 42.
718721 splits_file_path (Optional[str]): Path to the CSV file containing split assignments.
719722 apply_label_filter (Optional[str]): Path to a classes.txt file for label filtering.
723+ apply_id_filter (Optional[str]): Path to a data.pt file for ID filtering.
720724 """
721725
722726 # ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------
@@ -727,6 +731,7 @@ class _DynamicDataset(XYBaseDataModule, ABC):
727731 def __init__ (
728732 self ,
729733 apply_label_filter : Optional [str ] = None ,
734+ apply_id_filter : Optional [str ] = None ,
730735 ** kwargs ,
731736 ):
732737 super (_DynamicDataset , self ).__init__ (** kwargs )
@@ -741,6 +746,7 @@ def __init__(
741746 kwargs .get ("splits_file_path" , None )
742747 )
743748 self .apply_label_filter = apply_label_filter
749+ self .apply_id_filter = apply_id_filter
744750
745751 @staticmethod
746752 def _validate_splits_file_path (splits_file_path : Optional [str ]) -> Optional [str ]:
@@ -1140,6 +1146,15 @@ def _retrieve_splits_from_csv(self) -> None:
11401146 )
11411147 df_data = pd .DataFrame (data )
11421148
1149+ if self .apply_id_filter :
1150+ print (f"Applying ID filter from { self .apply_id_filter } ..." )
1151+ with open (self .apply_id_filter , "r" ) as f :
1152+ id_filter = [
1153+ line ["ident" ]
1154+ for line in torch .load (self .apply_id_filter , weights_only = False )
1155+ ]
1156+ df_data = df_data [df_data ["ident" ].isin (id_filter )]
1157+
11431158 if self .apply_label_filter :
11441159 print (f"Applying label filter from { self .apply_label_filter } ..." )
11451160 with open (self .apply_label_filter , "r" ) as f :
0 commit comments