Skip to content

Commit a3cc197

Browse files
committed
run inference on given partition of the data
1 parent 2da2149 commit a3cc197

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

chebai/result/generate_class_properties.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from pathlib import Path
3+
from typing import Literal
34

45
import torchmetrics
56
from jsonargparse import CLI
@@ -115,6 +116,7 @@ def compute_classwise_scores(
115116

116117
def generate_props(
117118
self,
119+
data_partition: Literal["train", "val", "test"],
118120
model_ckpt_path: str,
119121
model_config_file_path: str,
120122
data_config_file_path: str,
@@ -124,14 +126,13 @@ def generate_props(
124126
Run inference on the selected data partition, compute TPV/NPV per class, and save to JSON.
125127
126128
Args:
129+
data_partition: Partition of the dataset to use to generate class properties.
127130
model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
128131
model_config_file_path: Path to yaml config file of the model.
129132
data_config_file_path: Path to yaml config file of the data.
130133
output_path: Optional path where to write the JSON metrics file.
131134
Defaults to '<processed_dir_main>/classes_<data_partition>.json'.
132135
"""
133-
print("Extracting validation data for computation...")
134-
135136
data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path)
136137
data_module: XYBaseDataModule = load_data_instance(
137138
data_cls_path, data_cls_kwargs
@@ -157,8 +158,15 @@ def generate_props(
157158
model_ckpt_path, model_class_path, model_kwargs
158159
)
159160

160-
val_loader = data_module.val_dataloader()
161-
print("Running inference on validation data...")
161+
if data_partition == "train":
162+
data_loader = data_module.train_dataloader()
163+
elif data_partition == "val":
164+
data_loader = data_module.val_dataloader()
165+
elif data_partition == "test":
166+
data_loader = data_module.test_dataloader()
167+
else:
168+
raise ValueError(f"Unknown data partition: {data_partition}")
169+
print(f"Running inference on {data_partition} data...")
162170

163171
classes_file = Path(data_module.processed_dir_main) / "classes.txt"
164172
class_names = self.load_class_labels(classes_file)
@@ -168,7 +176,7 @@ def generate_props(
168176
"f1": MultilabelF1Score(num_labels=num_classes, average=None),
169177
}
170178

171-
for batch_idx, batch in enumerate(val_loader):
179+
for batch_idx, batch in enumerate(data_loader):
172180
data = model._process_batch(batch, batch_idx=batch_idx)
173181
labels = data["labels"]
174182
model_output = model(data, **data.get("model_kwargs", {}))
@@ -180,7 +188,9 @@ def generate_props(
180188

181189
print("Computing metrics...")
182190
if output_path is None:
183-
output_file = Path(data_module.processed_dir_main) / "classes.json"
191+
output_file = (
192+
Path(data_module.processed_dir_main) / f"classes_{data_partition}.json"
193+
)
184194
else:
185195
output_file = Path(output_path)
186196

@@ -198,6 +208,7 @@ class Main:
198208

199209
def generate(
200210
self,
211+
data_partition: Literal["train", "val", "test"],
201212
model_ckpt_path: str,
202213
model_config_file_path: str,
203214
data_config_file_path: str,
@@ -207,14 +218,21 @@ def generate(
207218
CLI command to generate JSON with metrics on the selected data partition.
208219
209220
Args:
221+
data_partition: Partition of the dataset to use to generate class properties.
210222
model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
211223
model_config_file_path: Path to yaml config file of the model.
212224
data_config_file_path: Path to yaml config file of the data.
213225
output_path: Optional path where to write the JSON metrics file.
214226
Defaults to '<processed_dir_main>/classes_<data_partition>.json'.
215227
"""
228+
assert data_partition in [
229+
"train",
230+
"val",
231+
"test",
232+
], f"Given data partition invalid: {data_partition}, Choose one of the value among `train`, `val`, `test` "
216233
generator = ClassesPropertiesGenerator()
217234
generator.generate_props(
235+
data_partition,
218236
model_ckpt_path,
219237
model_config_file_path,
220238
data_config_file_path,
@@ -224,6 +242,7 @@ def generate(
224242

225243
if __name__ == "__main__":
226244
# _generate_classes_props_json.py generate \
245+
# --data_partition "val" \
227246
# --model_ckpt_path "model/ckpt/path" \
228247
# --model_config_file_path "model/config/file/path" \
229248
# --data_config_file_path "data/config/file/path" \

0 commit comments

Comments
 (0)