Skip to content

Commit 151a8ec

Browse files
committed
loading embeddings is fixed.
1 parent afed612 commit 151a8ec

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

barcodebert/datasets.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,15 @@ def representations_from_df(
337337

338338
dna_embeddings = []
339339

340+
# Get model device robustly (works for all PyTorch models)
341+
model_device = next(model.parameters()).device
342+
340343
with torch.no_grad():
341344
for barcode in df["nucleotides"]:
342345
x, att_mask = tokenizer(barcode)
343346

344-
x = x.unsqueeze(0).to(model.device)
345-
att_mask = att_mask.unsqueeze(0).to(model.device)
347+
x = x.unsqueeze(0).to(model_device)
348+
att_mask = att_mask.unsqueeze(0).to(model_device)
346349

347350
# Get model output
348351
output = model(x, att_mask)

0 commit comments

Comments
 (0)