Commit 924d426
wraps the batch_idx value in a shape [1] torch tensor to be compatible with batch_collator
Signed-off-by: Steven <skothenhill@nvidia.com>
Parent: 1267264

1 file changed: +5 additions, -2 deletions

sub-packages/bionemo-llm/src/bionemo/llm/utils/callbacks.py

Lines changed: 5 additions & 2 deletions

@@ -98,8 +98,11 @@ def write_on_batch_end(
         result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}__batch_{batch_idx}.pt")

         # batch_indices is not captured due to a lightning bug when return_predictions = False
-        # we use input IDs in the prediction to map the result to input
-        prediction["batch_idx"] = batch_idx
+        # we use input IDs in the prediction to map the result to input.
+        # NOTE store the batch_idx so we do not need to rely on filenames for reconstruction of inputs. This is wrapped
+        # in a tensor and list container to ensure compatibility with batch_collator.
+        prediction["batch_idx"] = torch.tensor([batch_idx], dtype=torch.int64)

         torch.save(prediction, result_path)
         logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")
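The motivation for the change can be illustrated with a minimal sketch. Assuming `batch_collator` behaves like a dict-wise concatenation of tensor values across per-batch prediction dicts (the `collate_predictions` helper below is a hypothetical stand-in, not the actual bionemo implementation), a bare Python int stored under `"batch_idx"` would not concatenate, whereas a shape `[1]` tensor does:

```python
import torch


def collate_predictions(batches):
    # Hypothetical stand-in for batch_collator: concatenate each key's
    # tensors across the per-batch prediction dicts along dim 0.
    return {key: torch.cat([b[key] for b in batches], dim=0) for key in batches[0]}


# Each saved prediction dict carries its batch index as a shape [1] tensor,
# mirroring the committed change, so no filename parsing is needed later.
pred_a = {"batch_idx": torch.tensor([0], dtype=torch.int64)}
pred_b = {"batch_idx": torch.tensor([1], dtype=torch.int64)}

merged = collate_predictions([pred_a, pred_b])
print(merged["batch_idx"])  # tensor([0, 1])
```

Storing the index inside the prediction dict (rather than only in the filename) means the mapping from results back to inputs survives any downstream collation or re-saving of the per-rank prediction files.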
