Commit 4ab760e

add label filter
1 parent eb86e3f commit 4ab760e

2 files changed (+23 −1 lines)


chebai/preprocessing/datasets/base.py

Lines changed: 18 additions & 0 deletions
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

 import lightning as pl
+import numpy as np
 import pandas as pd
 import torch
 import tqdm
@@ -708,11 +709,14 @@ class _DynamicDataset(XYBaseDataModule, ABC):
     Args:
         dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42.
         splits_file_path (str, optional): Path to the splits CSV file. Defaults to None.
+        apply_label_filter (Optional[str]): Path to a classes.txt file - only labels that are in the label filter
+            file will be used (in that order). All labels in the label filter have to be present in the dataset.
         **kwargs: Additional keyword arguments passed to XYBaseDataModule.

     Attributes:
         dynamic_data_split_seed (int): The seed for random data splitting, default is 42.
         splits_file_path (Optional[str]): Path to the CSV file containing split assignments.
+        apply_label_filter (Optional[str]): Path to a classes.txt file for label filtering.
     """

     # ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------
@@ -722,6 +726,7 @@ class _DynamicDataset(XYBaseDataModule, ABC):

     def __init__(
         self,
+        apply_label_filter: Optional[str] = None,
         **kwargs,
     ):
         super(_DynamicDataset, self).__init__(**kwargs)
@@ -735,6 +740,7 @@ def __init__(
         self.splits_file_path = self._validate_splits_file_path(
             kwargs.get("splits_file_path", None)
         )
+        self.apply_label_filter = apply_label_filter

     @staticmethod
     def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
@@ -1134,6 +1140,18 @@ def _retrieve_splits_from_csv(self) -> None:
         )
         df_data = pd.DataFrame(data)

+        if self.apply_label_filter:
+            print(f"Applying label filter from {self.apply_label_filter}...")
+            with open(self.apply_label_filter, "r") as f:
+                label_filter = [line.strip() for line in f]
+            with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf:
+                classes = [line.strip() for line in cf]
+            # reorder labels
+            old_labels = np.stack(df_data["labels"])
+            label_mapping = [classes.index(lbl) for lbl in label_filter]
+            new_labels = old_labels[:, label_mapping]
+            df_data["labels"] = list(new_labels)
+
         train_ids = splits_df[splits_df["split"] == "train"]["id"]
         validation_ids = splits_df[splits_df["split"] == "validation"]["id"]
         test_ids = splits_df[splits_df["split"] == "test"]["id"]
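
Editor's sketch (not part of the commit): the new block in _retrieve_splits_from_csv reorders the label matrix by column so that it matches the class order given in the filter file. A minimal, self-contained illustration of that indexing step, with made-up class names and label values:

# Sketch of the column-reordering logic; class names and labels are made up.
import numpy as np

classes = ["CHEBI:1", "CHEBI:2", "CHEBI:3"]      # column order in the existing dataset (classes.txt)
label_filter = ["CHEBI:3", "CHEBI:1"]            # desired subset and order from the filter file

old_labels = np.array([[1, 0, 1],
                       [0, 1, 0]])               # one row per sample, one column per class

# classes.index(...) raises ValueError for an unknown name, which enforces the
# requirement that every entry of the label filter is present in the dataset.
label_mapping = [classes.index(lbl) for lbl in label_filter]   # -> [2, 0]
new_labels = old_labels[:, label_mapping]                       # columns filtered and reordered
print(new_labels)                                               # [[1 1], [0 0]]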

chebai/result/generate_class_properties.py

Lines changed: 5 additions & 1 deletion
@@ -168,7 +168,10 @@ def generate_props(
             raise ValueError(f"Unknown data partition: {data_partition}")
         print(f"Running inference on {data_partition} data...")

-        classes_file = Path(data_module.processed_dir_main) / "classes.txt"
+        if data_module.apply_label_filter is not None:
+            classes_file = data_module.apply_label_filter
+        else:
+            classes_file = Path(data_module.processed_dir_main) / "classes.txt"
         class_names = self.load_class_labels(classes_file)
         num_classes = len(class_names)
         metrics_obj_dict: dict[str, torchmetrics.Metric] = {
@@ -181,6 +184,7 @@ def generate_props(
         }

         for batch_idx, batch in enumerate(data_loader):
+            batch = batch.to(device=model.device)
             data = model._process_batch(batch, batch_idx=batch_idx)
             labels = data["labels"].to(device=model.device)
             data["features"][0].to(device=model.device)
