Skip to content

Commit bcf96f6

Browse files
committed
add id filter
1 parent dfc4db9 commit bcf96f6

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

chebai/result/generate_class_properties.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def generate_props(
121121
model_config_file_path: str,
122122
data_config_file_path: str,
123123
output_path: str | None = None,
124+
apply_id_filter: str | None = None,
124125
) -> None:
125126
"""
126127
Run inference on validation set, compute TPV/NPV per class, and save to JSON.
@@ -132,11 +133,13 @@ def generate_props(
132133
data_config_file_path: Path to yaml config file of the data.
133134
output_path: Optional path where to write the JSON metrics file.
134135
Defaults to '<processed_dir_main>/classes.json'.
136+
apply_id_filter: Optional path to a (data.pt) file containing IDs to filter the dataset. This is useful for comparing datasets with different ids.
135137
"""
136138
data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path)
137139
data_module: XYBaseDataModule = load_data_instance(
138140
data_cls_path, data_cls_kwargs
139141
)
142+
data_module.apply_id_filter = apply_id_filter
140143

141144
splits_file_path = Path(data_module.processed_dir_main, "splits.csv")
142145
if data_module.splits_file_path is None:
@@ -222,6 +225,7 @@ def generate(
222225
model_config_file_path: str,
223226
data_config_file_path: str,
224227
output_path: str | None = None,
228+
apply_id_filter: str | None = None,
225229
) -> None:
226230
"""
227231
CLI command to generate JSON with metrics on validation set.
@@ -246,6 +250,7 @@ def generate(
246250
model_config_file_path,
247251
data_config_file_path,
248252
output_path,
253+
apply_id_filter=apply_id_filter,
249254
)
250255

251256

0 commit comments

Comments
 (0)