11import json
22from pathlib import Path
3+ from typing import Literal
34
45import torchmetrics
56from jsonargparse import CLI
@@ -115,6 +116,7 @@ def compute_classwise_scores(
115116
116117 def generate_props (
117118 self ,
119+ data_partition : Literal ["train" , "val" , "test" ],
118120 model_ckpt_path : str ,
119121 model_config_file_path : str ,
120122 data_config_file_path : str ,
@@ -124,14 +126,13 @@ def generate_props(
124126 Run inference on validation set, compute TPV/NPV per class, and save to JSON.
125127
126128 Args:
129+ data_partition: Partition of the dataset to use to generate class properties.
127130 model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
128131 model_config_file_path: Path to yaml config file of the model.
129132 data_config_file_path: Path to yaml config file of the data.
130133 output_path: Optional path where to write the JSON metrics file.
131134 Defaults to '<processed_dir_main>/classes.json'.
132135 """
133- print ("Extracting validation data for computation..." )
134-
135136 data_cls_path , data_cls_kwargs = parse_config_file (data_config_file_path )
136137 data_module : XYBaseDataModule = load_data_instance (
137138 data_cls_path , data_cls_kwargs
@@ -157,8 +158,15 @@ def generate_props(
157158 model_ckpt_path , model_class_path , model_kwargs
158159 )
159160
160- val_loader = data_module .val_dataloader ()
161- print ("Running inference on validation data..." )
161+ if data_partition == "train" :
162+ data_loader = data_module .train_dataloader ()
163+ elif data_partition == "val" :
164+ data_loader = data_module .val_dataloader ()
165+ elif data_partition == "test" :
166+ data_loader = data_module .test_dataloader ()
167+ else :
168+ raise ValueError (f"Unknown data partition: { data_partition } " )
169+ print (f"Running inference on { data_partition } data..." )
162170
163171 classes_file = Path (data_module .processed_dir_main ) / "classes.txt"
164172 class_names = self .load_class_labels (classes_file )
@@ -168,7 +176,7 @@ def generate_props(
168176 "f1" : MultilabelF1Score (num_labels = num_classes , average = None ),
169177 }
170178
171- for batch_idx , batch in enumerate (val_loader ):
179+ for batch_idx , batch in enumerate (data_loader ):
172180 data = model ._process_batch (batch , batch_idx = batch_idx )
173181 labels = data ["labels" ]
174182 model_output = model (data , ** data .get ("model_kwargs" , {}))
@@ -180,7 +188,9 @@ def generate_props(
180188
181189 print ("Computing metrics..." )
182190 if output_path is None :
183- output_file = Path (data_module .processed_dir_main ) / "classes.json"
191+ output_file = (
192+ Path (data_module .processed_dir_main ) / f"classes_{ data_partition } .json"
193+ )
184194 else :
185195 output_file = Path (output_path )
186196
@@ -198,6 +208,7 @@ class Main:
198208
199209 def generate (
200210 self ,
211+ data_partition : Literal ["train" , "val" , "test" ],
201212 model_ckpt_path : str ,
202213 model_config_file_path : str ,
203214 data_config_file_path : str ,
@@ -207,14 +218,21 @@ def generate(
207218 CLI command to generate JSON with metrics on validation set.
208219
209220 Args:
221+ data_partition: Partition of dataset to use to generate class properties.
210222 model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
211223 model_config_file_path: Path to yaml config file of the model.
212224 data_config_file_path: Path to yaml config file of the data.
213225 output_path: Optional path where to write the JSON metrics file.
214226 Defaults to '<processed_dir_main>/classes.json'.
215227 """
228+ assert data_partition in [
229+ "train" ,
230+ "val" ,
231+ "test" ,
232+ ], f"Given data partition invalid: { data_partition } , Choose one of the value among `train`, `val`, `test` "
216233 generator = ClassesPropertiesGenerator ()
217234 generator .generate_props (
235+ data_partition ,
218236 model_ckpt_path ,
219237 model_config_file_path ,
220238 data_config_file_path ,
@@ -224,6 +242,7 @@ def generate(
224242
225243if __name__ == "__main__" :
226244 # _generate_classes_props_json.py generate \
245+ # --data_partition "val" \
227246 # --model_ckpt_path "model/ckpt/path" \
228247 # --model_config_file_path "model/config/file/path" \
229248 # --data_config_file_path "data/config/file/path" \
0 commit comments