44from typing import TYPE_CHECKING , Any , Dict , Generator , List , Optional , Tuple , Union
55
66import lightning as pl
7+ import numpy as np
78import pandas as pd
89import torch
910import tqdm
@@ -708,11 +709,14 @@ class _DynamicDataset(XYBaseDataModule, ABC):
708709 Args:
709710 dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42.
710711 splits_file_path (str, optional): Path to the splits CSV file. Defaults to None.
712+ apply_label_filter (Optional[str]): Path to a label filter file in the same format as classes.txt (one label per line).
713+ Only the labels listed in that file are kept, in the order given there. Every label in the filter file must be present in the dataset's classes.txt.
711714 **kwargs: Additional keyword arguments passed to XYBaseDataModule.
712715
713716 Attributes:
714717 dynamic_data_split_seed (int): The seed for random data splitting, default is 42.
715718 splits_file_path (Optional[str]): Path to the CSV file containing split assignments.
719+ apply_label_filter (Optional[str]): Path to the label filter file (classes.txt format); None disables label filtering.
716720 """
717721
718722 # ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------
@@ -722,6 +726,7 @@ class _DynamicDataset(XYBaseDataModule, ABC):
722726
723727 def __init__ (
724728 self ,
729+ apply_label_filter : Optional [str ] = None ,
725730 ** kwargs ,
726731 ):
727732 super (_DynamicDataset , self ).__init__ (** kwargs )
@@ -735,6 +740,7 @@ def __init__(
735740 self .splits_file_path = self ._validate_splits_file_path (
736741 kwargs .get ("splits_file_path" , None )
737742 )
743+ self .apply_label_filter = apply_label_filter
738744
739745 @staticmethod
740746 def _validate_splits_file_path (splits_file_path : Optional [str ]) -> Optional [str ]:
@@ -1134,6 +1140,18 @@ def _retrieve_splits_from_csv(self) -> None:
11341140 )
11351141 df_data = pd .DataFrame (data )
11361142
1143+ if self .apply_label_filter :
1144+ print (f"Applying label filter from { self .apply_label_filter } ..." )
1145+ with open (self .apply_label_filter , "r" ) as f :
1146+ label_filter = [line .strip () for line in f ]
1147+ with open (os .path .join (self .processed_dir_main , "classes.txt" ), "r" ) as cf :
1148+ classes = [line .strip () for line in cf ]
1149+ # reorder labels
1150+ old_labels = np .stack (df_data ["labels" ])
1151+ label_mapping = [classes .index (lbl ) for lbl in label_filter ]
1152+ new_labels = old_labels [:, label_mapping ]
1153+ df_data ["labels" ] = list (new_labels )
1154+
11371155 train_ids = splits_df [splits_df ["split" ] == "train" ]["id" ]
11381156 validation_ids = splits_df [splits_df ["split" ] == "validation" ]["id" ]
11391157 test_ids = splits_df [splits_df ["split" ] == "test" ]["id" ]
0 commit comments