@@ -121,6 +121,7 @@ def generate_props(
121121 model_config_file_path : str ,
122122 data_config_file_path : str ,
123123 output_path : str | None = None ,
124+ apply_id_filter : str | None = None ,
124125 ) -> None :
125126 """
126127 Run inference on validation set, compute TPV/NPV per class, and save to JSON.
@@ -132,11 +133,13 @@ def generate_props(
132133 data_config_file_path: Path to yaml config file of the data.
133134 output_path: Optional path where to write the JSON metrics file.
134135 Defaults to '<processed_dir_main>/classes.json'.
136+ apply_id_filter: Optional path to a (data.pt) file containing IDs to filter the dataset. This is useful for comparing datasets with different ids.
135137 """
136138 data_cls_path , data_cls_kwargs = parse_config_file (data_config_file_path )
137139 data_module : XYBaseDataModule = load_data_instance (
138140 data_cls_path , data_cls_kwargs
139141 )
142+ data_module .apply_id_filter = apply_id_filter
140143
141144 splits_file_path = Path (data_module .processed_dir_main , "splits.csv" )
142145 if data_module .splits_file_path is None :
@@ -222,6 +225,7 @@ def generate(
222225 model_config_file_path : str ,
223226 data_config_file_path : str ,
224227 output_path : str | None = None ,
228+ apply_id_filter : str | None = None ,
225229 ) -> None :
226230 """
227231 CLI command to generate JSON with metrics on validation set.
@@ -246,6 +250,7 @@ def generate(
246250 model_config_file_path ,
247251 data_config_file_path ,
248252 output_path ,
253+ apply_id_filter = apply_id_filter ,
249254 )
250255
251256
0 commit comments