File tree Expand file tree Collapse file tree 2 files changed +11
-6
lines changed
Expand file tree Collapse file tree 2 files changed +11
-6
lines changed Original file line number Diff line number Diff line change 66
77import numpy as np
88
9+ from pypef .utils .helpers import get_device
910from pypef .llm .utils import get_batches
1011from pypef .llm .esm_lora_tune import esm_setup , esm_tokenize_sequences
1112from pypef .llm .prosst_lora_tune import prosst_setup , prosst_tokenize_sequences
@@ -43,6 +44,8 @@ def inference(
4344 """
4445 Inference of base models.
4546 """
47+ if device is None :
48+ device = get_device ()
4649 if llm == 'esm' :
4750 logger .info ("Zero-shot LLM inference on test set using ESM1v..." )
4851 llm_dict = esm_setup (sequences )
Original file line number Diff line number Diff line change @@ -508,8 +508,15 @@ def __init__(
508508 self .subgraph_depth = subgraph_depth
509509 self .subgraph_interval = subgraph_interval
510510 self .anchor_nodes = anchor_nodes
511+ if device is None :
512+ self .device = get_device ()
513+ else :
514+ self .device = device
511515 if model_path is None :
512- self .model_path = str (Path (__file__ ).parent / "static" / "AE.pt" )
516+ if self .device is 'cpu' :
517+ self .model_path = str (Path (__file__ ).parent / "static" / "AE_CPU.pt" )
518+ else :
519+ self .model_path = str (Path (__file__ ).parent / "static" / "AE.pt" )
513520 else :
514521 self .model_path = model_path
515522
@@ -522,11 +529,6 @@ def __init__(
522529 self .cluster_dir = cluster_dir
523530 self .cluster_model = cluster_model
524531
525- if device is None :
526- self .device = get_device ()
527- else :
528- self .device = device
529-
530532 # Load model
531533 node_dim = (256 , 32 )
532534 edge_dim = (64 , 2 )
You can’t perform that action at this time.
0 commit comments