2929
3030
3131class PredictionWriter (BasePredictionWriter , pl .Callback ):
32- """A callback that writes predictions to disk at specified intervals during training."""
32+ """A callback that writes predictions to disk at specified intervals during training.
33+
34+ Logits, Embeddings, Hiddens, Input IDs, and Labels may all be saved to the disk depending on trainer configuration.
35+ Batch indices are provided for each prediction in the same dictionary. These must be used to maintain
36+ ordering between multi-device predictions and single-device predictions.
37+ """
3338
3439 def __init__ (
3540 self ,
@@ -42,15 +47,28 @@ def __init__(
4247
4348 Args:
4449 output_dir: The directory where predictions will be written.
45- write_interval: The interval at which predictions will be written. (batch, epoch)
50+ write_interval: The interval at which predictions will be written (batch, epoch). Epoch may not be used with multi-device trainers.
4651 batch_dim_key_defaults: The default batch dimension for each key, if different from the standard 0.
4752 seq_dim_key_defaults: The default sequence dimension for each key, if different from the standard 1.
4853 """
4954 super ().__init__ (write_interval )
55+ self .write_interval = write_interval
5056 self .output_dir = str (output_dir )
5157 self .batch_dim_key_defaults = batch_dim_key_defaults
5258 self .seq_dim_key_defaults = seq_dim_key_defaults
5359
60+ def setup (self , trainer : pl .Trainer , pl_module : pl .LightningModule , * args , ** kwargs ) -> None : # noqa: D417
61+ """Invoked with Trainer.fit, validate, test, and predict are called. Will immediately fail when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1.
62+
63+ Args:
64+ trainer: The Trainer instance.
65+ pl_module: The LightningModule instance.
66+ """
67+ if trainer .num_devices > 1 and self .write_interval == "epoch" :
68+ raise ValueError (
69+ "Multi-GPU predictions are not permitted as outputs are not ordered and batch indices are lost."
70+ )
71+
5472 def write_on_batch_end (
5573 self ,
5674 trainer : pl .Trainer ,
@@ -63,6 +81,9 @@ def write_on_batch_end(
6381 ) -> None :
6482 """Writes predictions to disk at the end of each batch.
6583
84+ Prediction files follow the naming pattern below, where rank is the rank of the device on which the predictions were made:
85+ predictions__rank_{rank}__batch_{batch_idx}.pt
86+
6687 Args:
6788 trainer: The Trainer instance.
6889 pl_module: The LightningModule instance.
@@ -78,6 +99,8 @@ def write_on_batch_end(
7899
79100 # batch_indices is not captured due to a lightning bug when return_predictions = False
80101 # we use input IDs in the prediction to map the result to input
102+ prediction ["batch_idx" ] = batch_idx
103+
81104 torch .save (prediction , result_path )
82105 logging .info (f"Inference predictions are stored in { result_path } \n { prediction .keys ()} " )
83106
@@ -90,14 +113,23 @@ def write_on_epoch_end(
90113 ) -> None :
91114 """Writes predictions to disk at the end of each epoch.
92115
116+ Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for
117+ large predictions.
118+
119+ Multi-device predictions will likely yield predictions in an order that is inconsistent with single device predictions and the input data.
120+
93121 Args:
94122 trainer: The Trainer instance.
95123 pl_module: The LightningModule instance.
96124 predictions: The predictions made by the model.
97125 batch_indices: The indices of the batch.
126+
127+ Raises:
128+ Multi-GPU predictions are output in an inconsistent order with multiple devices.
98129 """
99130 # this will create N (num processes) files in `output_dir` each containing
100131 # the predictions of its respective rank
132+
101133 result_path = os .path .join (self .output_dir , f"predictions__rank_{ trainer .global_rank } .pt" )
102134
103135 # collate multiple batches / ignore empty ones
@@ -106,13 +138,14 @@ def write_on_epoch_end(
106138 collate_kwargs ["batch_dim_key_defaults" ] = self .batch_dim_key_defaults
107139 if self .seq_dim_key_defaults is not None :
108140 collate_kwargs ["seq_dim_key_defaults" ] = self .seq_dim_key_defaults
141+
109142 prediction = batch_collator ([item for item in predictions if item is not None ], ** collate_kwargs )
110143
111144 # batch_indices is not captured due to a lightning bug when return_predictions = False
112145 # we use input IDs in the prediction to map the result to input
113- torch .save (prediction , result_path )
114146 if isinstance (prediction , dict ):
115147 keys = prediction .keys ()
116148 else :
117149 keys = "tensor"
150+ torch .save (prediction , result_path )
118151 logging .info (f"Inference predictions are stored in { result_path } \n { keys } " )
0 commit comments