Skip to content

Commit 1267264

Browse files
multi-gpu inference. Adds 'batch index' to the resulting prediction
dictionary. this allows users to reconstruct the original ordering of predictions with multi-gpu inference. Signed-off-by: Steven <skothenhill@nvidia.com>
1 parent 29ce3bc commit 1267264

File tree

1 file changed

+36
-3
lines changed

1 file changed

+36
-3
lines changed

sub-packages/bionemo-llm/src/bionemo/llm/utils/callbacks.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@
2929

3030

3131
class PredictionWriter(BasePredictionWriter, pl.Callback):
32-
"""A callback that writes predictions to disk at specified intervals during training."""
32+
"""A callback that writes predictions to disk at specified intervals during training.
33+
34+
Logits, Embeddings, Hiddens, Input IDs, and Labels may all be saved to the disk depending on trainer configuration.
35+
Batch indices are provided for each prediction in the same dictionary. These must be used to maintain order between
36+
multi device predictions and single device predictions.
37+
"""
3338

3439
def __init__(
3540
self,
@@ -42,15 +47,28 @@ def __init__(
4247
4348
Args:
4449
output_dir: The directory where predictions will be written.
45-
write_interval: The interval at which predictions will be written. (batch, epoch)
50+
write_interval: The interval at which predictions will be written (batch, epoch). Epoch may not be used with multi-device trainers.
4651
batch_dim_key_defaults: The default batch dimension for each key, if different from the standard 0.
4752
seq_dim_key_defaults: The default sequence dimension for each key, if different from the standard 1.
4853
"""
4954
super().__init__(write_interval)
55+
self.write_interval = write_interval
5056
self.output_dir = str(output_dir)
5157
self.batch_dim_key_defaults = batch_dim_key_defaults
5258
self.seq_dim_key_defaults = seq_dim_key_defaults
5359

60+
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None: # noqa: D417
61+
"""Invoked when Trainer.fit, validate, test, or predict is called. Will immediately fail when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1.
62+
63+
Args:
64+
trainer: The Trainer instance.
65+
pl_module: The LightningModule instance.
66+
"""
67+
if trainer.num_devices > 1 and self.write_interval == "epoch":
68+
raise ValueError(
69+
"Multi-GPU predictions are not permitted as outputs are not ordered and batch indices are lost."
70+
)
71+
5472
def write_on_batch_end(
5573
self,
5674
trainer: pl.Trainer,
@@ -63,6 +81,9 @@ def write_on_batch_end(
6381
) -> None:
6482
"""Writes predictions to disk at the end of each batch.
6583
84+
Prediction files follow the naming pattern below, where rank is the rank of the GPU on which the predictions were made.
85+
predictions__rank_{rank}__batch_{batch_idx}.pt
86+
6687
Args:
6788
trainer: The Trainer instance.
6889
pl_module: The LightningModule instance.
@@ -78,6 +99,8 @@ def write_on_batch_end(
7899

79100
# batch_indices is not captured due to a lightning bug when return_predictions = False
80101
# we use input IDs in the prediction to map the result to input
102+
prediction["batch_idx"] = batch_idx
103+
81104
torch.save(prediction, result_path)
82105
logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")
83106

@@ -90,14 +113,23 @@ def write_on_epoch_end(
90113
) -> None:
91114
"""Writes predictions to disk at the end of each epoch.
92115
116+
Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for
117+
large predictions.
118+
119+
Multi-device predictions will likely yield predictions in an order that is inconsistent with single device predictions and the input data.
120+
93121
Args:
94122
trainer: The Trainer instance.
95123
pl_module: The LightningModule instance.
96124
predictions: The predictions made by the model.
97125
batch_indices: The indices of the batch.
126+
127+
Note:
128+
Multi-GPU predictions are output in an inconsistent order with multiple devices.
98129
"""
99130
# this will create N (num processes) files in `output_dir` each containing
100131
# the predictions of it's respective rank
132+
101133
result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt")
102134

103135
# collate multiple batches / ignore empty ones
@@ -106,13 +138,14 @@ def write_on_epoch_end(
106138
collate_kwargs["batch_dim_key_defaults"] = self.batch_dim_key_defaults
107139
if self.seq_dim_key_defaults is not None:
108140
collate_kwargs["seq_dim_key_defaults"] = self.seq_dim_key_defaults
141+
109142
prediction = batch_collator([item for item in predictions if item is not None], **collate_kwargs)
110143

111144
# batch_indices is not captured due to a lightning bug when return_predictions = False
112145
# we use input IDs in the prediction to map the result to input
113-
torch.save(prediction, result_path)
114146
if isinstance(prediction, dict):
115147
keys = prediction.keys()
116148
else:
117149
keys = "tensor"
150+
torch.save(prediction, result_path)
118151
logging.info(f"Inference predictions are stored in {result_path}\n{keys}")

0 commit comments

Comments
 (0)