File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
sub-packages/bionemo-llm/src/bionemo/llm/utils Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -98,8 +98,11 @@ def write_on_batch_end(
9898 result_path = os .path .join (self .output_dir , f"predictions__rank_{ trainer .global_rank } __batch_{ batch_idx } .pt" )
9999
100100 # batch_indices is not captured due to a lightning bug when return_predictions = False
101- # we use input IDs in the prediction to map the result to input
102- prediction ["batch_idx" ] = batch_idx
101+ # we use input IDs in the prediction to map the result to input.
102+
103+ # NOTE store the batch_idx so we do not need to rely on filenames for reconstruction of inputs. This is wrapped
104+ # in a tensor and list container to ensure compatibility with batch_collator.
105+ prediction ["batch_idx" ] = torch .tensor ([batch_idx ], dtype = torch .int64 )
103106
104107 torch .save (prediction , result_path )
105108 logging .info (f"Inference predictions are stored in { result_path } \n { prediction .keys ()} " )
You can’t perform that action at this time.
0 commit comments