Open
Description
In `1_inference.ipynb`, the function `predict_on_seqs` runs out of memory (OOM) on a GPU with 16 GB of VRAM when given the sequences specified in the tutorial, whereas `grelu.interpret.score.ISM_predict` does not, even though both functions perform inference. I looked into the source code and changed `predict_on_seqs` as follows, which fixed the problem for me:
```python
# Imports assumed from the surrounding gReLU module
from typing import List, Optional, Union

import numpy as np
import pytorch_lightning as pl
import torch

from grelu.sequence.format import strings_to_one_hot


# Patched version of the predict_on_seqs method (defined inside the model class)
def predict_on_seqs(
    self,
    seqs: Union[str, List[str]],
    devices: Union[int, str, List[int]] = "cpu",
    num_workers: int = 1,
    batch_size: int = 1,
    precision: Optional[str] = None,
) -> np.ndarray:
    """
    A simple function to return model predictions directly
    on a single sequence in string format or on multiple
    sequences in a string format in a list.

    Args:
        seqs: DNA sequence as a string or DNA sequences as a list.
        devices: Index of the devices to use
        num_workers: number of workers for inference
        batch_size: batch size for model inference
        precision: Precision of the trainer e.g. '32' or 'bf16-mixed'.

    Returns:
        A numpy array of predictions.
    """
    # One-hot encode the sequence(s) into a (batch, 4, length) tensor
    seqs = strings_to_one_hot(seqs, add_batch_axis=True)

    # Wrap the sequences in a DataLoader so inference runs on batch_size
    # sequences at a time instead of on all sequences at once
    dataloader = self.make_predict_loader(
        seqs,
        num_workers=num_workers,
        batch_size=batch_size,
    )

    accelerator, devices = self.parse_devices(devices)
    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        logger=None,
        precision=precision,
    )

    # Predict: trainer.predict returns one output tensor per batch,
    # so concatenate them along the batch axis
    preds = torch.concat(trainer.predict(self, dataloader))
    return preds.detach().cpu().numpy()
```

This is essentially the approach used in `predict_on_dataset`, which is the function called by `grelu.interpret.score.ISM_predict`. I have only tested it with the sequences from this tutorial, but the predictions look comparable.
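The memory difference comes down to how much data sits on the GPU at once. Below is a minimal plain-PyTorch sketch of the same batched pattern, independent of gReLU and Lightning. The helper `batched_predict`, the toy convolutional model, and the random one-hot inputs are all hypothetical illustrations, not part of gReLU's API.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def batched_predict(model: nn.Module, x: torch.Tensor, batch_size: int = 32,
                    device: str = "cpu") -> np.ndarray:
    """Hypothetical helper: run `model` on `x` in chunks of `batch_size`.

    Only one batch lives on `device` at a time, and each batch's output is
    moved back to the CPU before the next batch is processed, so peak device
    memory scales with batch_size rather than with len(x).
    """
    model = model.to(device).eval()
    loader = DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=False)
    outputs = []
    with torch.no_grad():  # no autograd graph is kept during inference
        for (batch,) in loader:
            outputs.append(model(batch.to(device)).cpu())
    return torch.cat(outputs).numpy()


# Usage with a toy model on random one-hot DNA of shape (10, 4, 50)
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Conv1d(4, 2, kernel_size=3), nn.AdaptiveAvgPool1d(1))
seqs = nn.functional.one_hot(torch.randint(0, 4, (10, 50)), num_classes=4)
seqs = seqs.permute(0, 2, 1).float()
preds = batched_predict(toy_model, seqs, batch_size=3)
print(preds.shape)  # (10, 2, 1)
```

Passing `device="cuda"` would keep at most three sequences on the GPU per step here; the batched result matches a single full-size forward pass, only the peak memory differs.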