Skip to content

Commit 541772e

Browse files
committed
ProSST PDB quantizer fix: load cpu weights if gpu not available
1 parent f3f5915 commit 541772e

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

pypef/llm/inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88

9+
from pypef.utils.helpers import get_device
910
from pypef.llm.utils import get_batches
1011
from pypef.llm.esm_lora_tune import esm_setup, esm_tokenize_sequences
1112
from pypef.llm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences
@@ -43,6 +44,8 @@ def inference(
4344
"""
4445
Inference of base models.
4546
"""
47+
if device is None:
48+
device = get_device()
4649
if llm == 'esm':
4750
logger.info("Zero-shot LLM inference on test set using ESM1v...")
4851
llm_dict = esm_setup(sequences)

pypef/llm/prosst_structure/quantizer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,15 @@ def __init__(
508508
self.subgraph_depth = subgraph_depth
509509
self.subgraph_interval = subgraph_interval
510510
self.anchor_nodes = anchor_nodes
511+
if device is None:
512+
self.device = get_device()
513+
else:
514+
self.device = device
511515
if model_path is None:
512-
self.model_path = str(Path(__file__).parent / "static" / "AE.pt")
516+
if self.device is 'cpu':
517+
self.model_path = str(Path(__file__).parent / "static" / "AE_CPU.pt")
518+
else:
519+
self.model_path = str(Path(__file__).parent / "static" / "AE.pt")
513520
else:
514521
self.model_path = model_path
515522

@@ -522,11 +529,6 @@ def __init__(
522529
self.cluster_dir = cluster_dir
523530
self.cluster_model = cluster_model
524531

525-
if device is None:
526-
self.device = get_device()
527-
else:
528-
self.device = device
529-
530532
# Load model
531533
node_dim = (256, 32)
532534
edge_dim = (64, 2)

0 commit comments

Comments
 (0)