33
44from lightning .pytorch .callbacks import BasePredictionWriter
55import torch
6+ from typing import Any , Dict , List , Union
67
78
89class ChebaiPredictionWriter (BasePredictionWriter ):
9- def __init__ (self , output_dir , write_interval , target_file = "predictions.json" ):
10+ """
11+ A custom prediction writer for saving per-batch and per-epoch predictions produced during prediction runs.
12+
13+ This class inherits from `BasePredictionWriter` and is designed to save predictions
14+ in a specified output directory at specified intervals.
15+
16+ Args:
17+ output_dir (str): The directory where predictions will be saved.
18+ write_interval (str): The interval at which predictions will be written.
19+ target_file (str): The name of the file where epoch predictions will be saved (default: "predictions.json").
20+ """
21+
def __init__(
    self,
    output_dir: str,
    write_interval: str,
    target_file: str = "predictions.json",
) -> None:
    """
    Set up the writer with its output location.

    Args:
        output_dir (str): Directory under which prediction files are written.
        write_interval (str): Passed unchanged to the base-class constructor.
        target_file (str): File name, joined onto `output_dir`, that receives
            the epoch-level JSON dump (default: "predictions.json").
    """
    # The base class receives the interval setting.
    super().__init__(write_interval)
    self.target_file = target_file
    self.output_dir = output_dir
1331
def write_on_batch_end(
    self,
    trainer: Any,
    pl_module: Any,
    prediction: Union[torch.Tensor, List[torch.Tensor]],
    batch_indices: List[int],
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """
    Save the prediction of a single batch to its own `.pt` file.

    The file is written with `torch.save` to
    `<output_dir>/<dataloader_idx>/<batch_idx>.pt`; the per-dataloader
    directory is created on demand.

    Args:
        trainer (Any): The trainer instance (unused here).
        pl_module (Any): The LightningModule instance (unused here).
        prediction (Union[torch.Tensor, List[torch.Tensor]]): The object to serialize.
        batch_indices (List[int]): The indices of the batch (unused here).
        batch (Any): The current batch (unused here).
        batch_idx (int): The index of the batch; used as the file stem.
        dataloader_idx (int): The index of the dataloader; used as the sub-directory name.
    """
    # os.path.join only accepts path-like parts, so the integer dataloader
    # index has to be converted to a string first.
    batch_dir = os.path.join(self.output_dir, str(dataloader_idx))
    # Create the containing directory, not the target path itself; otherwise
    # torch.save would find a directory where the file should go.
    os.makedirs(batch_dir, exist_ok=True)
    # No whitespace between the index and the extension in the file name.
    outpath = os.path.join(batch_dir, f"{batch_idx}.pt")
    torch.save(prediction, outpath)
2757
28- def write_on_epoch_end (self , trainer , pl_module , predictions , batch_indices ):
58+ def write_on_epoch_end (
59+ self ,
60+ trainer : Any ,
61+ pl_module : Any ,
62+ predictions : List [Dict [str , Any ]],
63+ batch_indices : List [int ],
64+ ) -> None :
65+ """
66+ Saves all predictions at the end of each epoch in a JSON file.
67+
68+ Args:
69+ trainer (Any): The trainer instance.
70+ pl_module (Any): The LightningModule instance.
71+ predictions (List[Dict[str, Any]]): The list of prediction outputs from the model.
72+ batch_indices (List[int]): The indices of the batches.
73+ """
2974 pred_list = []
3075 for p in predictions :
3176 idents = p ["data" ]["idents" ]
@@ -35,7 +80,7 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
3580 else :
3681 labels = [None for _ in idents ]
3782 output = torch .sigmoid (p ["output" ]["logits" ]).tolist ()
38- for i , l , p in zip (idents , labels , output ):
39- pred_list .append (dict (ident = i , labels = l , predictions = p ))
83+ for i , l , o in zip (idents , labels , output ):
84+ pred_list .append (dict (ident = i , labels = l , predictions = o ))
4085 with open (os .path .join (self .output_dir , self .target_file ), "wt" ) as fout :
4186 json .dump (pred_list , fout )
0 commit comments