Skip to content

Commit 4bca8f1

Browse files
committed
Update model loading process to work offline if models are available
1 parent cfd03dd commit 4bca8f1

File tree

4 files changed

+79
-53
lines changed

4 files changed

+79
-53
lines changed

pypef/llm/esm_lora_tune.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,17 @@
2828
from peft import LoraConfig, get_peft_model
2929
from transformers import logging as hf_logging
3030
hf_logging.set_verbosity_error()
31-
from transformers import EsmForMaskedLM, EsmTokenizer
3231

3332
from pypef.utils.helpers import get_device
34-
from pypef.llm.utils import corr_loss
33+
from pypef.llm.utils import corr_loss, load_model_and_tokenizer
3534

3635

3736
def get_esm_models():
38-
base_model = EsmForMaskedLM.from_pretrained(f'facebook/esm1v_t33_650M_UR90S_3')
39-
tokenizer = EsmTokenizer.from_pretrained(f'facebook/esm1v_t33_650M_UR90S_3')
37+
base_model, tokenizer = load_model_and_tokenizer(
38+
f'facebook/esm1v_t33_650M_UR90S_3'
39+
# Just sticking to AutoModelForMaskedLM and AutoTokenizer
40+
# instead of EsmForMaskedLM and EsmTokenizer
41+
)
4042
peft_config = LoraConfig(r=8, target_modules=["query", "value"])
4143
lora_model = get_peft_model(base_model, peft_config)
4244
optimizer = torch.optim.Adam(lora_model.parameters(), lr=0.01)

pypef/llm/prosst_lora_tune.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import numpy as np
2222
from scipy.stats import spearmanr
2323
from tqdm import tqdm
24-
from transformers import AutoModelForMaskedLM, AutoTokenizer
2524
from peft import LoraConfig, get_peft_model
2625
from Bio import SeqIO, BiopythonParserWarning
2726
warnings.filterwarnings(action='ignore', category=BiopythonParserWarning)

pypef/llm/utils.py

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch
66
import os
77
import platform
8-
from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST
98
from transformers import AutoModelForMaskedLM, AutoTokenizer
109
from transformers.utils import logging as ts_logging
1110
ts_logging.set_verbosity_error()
@@ -64,42 +63,83 @@ def get_default_cache_dir():
6463
"""
6564
system = platform.system()
6665
if system == "Windows":
67-
return os.path.join(os.environ.get("USERPROFILE", ""), ".cache", "huggingface", "transformers")
66+
return os.path.join(
67+
os.environ.get("USERPROFILE", ""), ".cache",
68+
"huggingface", "transformers"
69+
)
6870
elif system == "Darwin":
6971
return os.path.expanduser("~/.cache/huggingface/transformers")
7072
else: # Assume Linux or other Unix-like systems
7173
return os.path.expanduser("~/.cache/huggingface/transformers")
7274

7375

74-
def is_model_cached(repo_id: str, cache_dir: str) -> bool:
76+
def is_model_cached(repo_id: str, cache_dir: str):
7577
"""
7678
Check if the required model and tokenizer files are cached locally.
7779
"""
78-
79-
filepath = try_to_load_from_cache(repo_id=repo_id, filename='model.safetensors', cache_dir=cache_dir)
80-
if isinstance(filepath, str):
81-
return True # file is cached
82-
elif filepath is _CACHED_NO_EXIST:
83-
return False # non-existence of file is cached
80+
snapshot_dir = None
81+
if os.path.isdir(cache_dir):
82+
ref_file = os.path.join(
83+
cache_dir, f'models--{repo_id.replace("/", '--')}', 'refs', 'main'
84+
)
85+
if os.path.isfile(ref_file):
86+
with open(ref_file, 'r') as fh:
87+
t = fh.readlines()
88+
ref = t[0].strip()
89+
else:
90+
return False, snapshot_dir
91+
snapshot_dir = os.path.join(
92+
cache_dir, f'models--{repo_id.replace("/", '--')}', 'snapshots', ref
93+
)
94+
if os.path.isdir(snapshot_dir):
95+
return True, snapshot_dir
96+
else:
97+
return False, None
8498
else:
85-
return False # file is not cached and not in non-existence cache
99+
return False, snapshot_dir
86100

87101

88-
def load_model_and_tokenizer(model_name, cache_dir: str | os.PathLike | None = None):
102+
def load_model_and_tokenizer(
103+
model_name: str,
104+
cache_dir: str | os.PathLike | None = None,
105+
model_loader=None,
106+
tokenizer_loader=None
107+
):
89108
"""
90109
Load the model and tokenizer from cache directory. Downloads to cache if not present.
91110
"""
92111
if cache_dir is None:
93112
cache_dir = get_default_cache_dir()
94-
if is_model_cached(model_name, cache_dir):
95-
logger.info(f"Loading model and tokenizer from cache {cache_dir}...")
113+
if model_loader is None:
114+
model_loader = AutoModelForMaskedLM
115+
if tokenizer_loader is None:
116+
tokenizer_loader = AutoTokenizer
117+
exists, exists_at = is_model_cached(model_name, cache_dir)
118+
if exists:
119+
try:
120+
logger.info(f"Loading model and tokenizer from cache {exists_at}...")
121+
model = model_loader.from_pretrained(
122+
exists_at, trust_remote_code=True
123+
)
124+
tokenizer = tokenizer_loader.from_pretrained(
125+
exists_at, trust_remote_code=True
126+
)
127+
except OSError as e:
128+
logger.info(f"Faced error \"{e}\": Trying to load with regular cache load path...")
129+
model = model_loader.from_pretrained(
130+
model_name, cache_dir=cache_dir, trust_remote_code=True
131+
)
132+
tokenizer = tokenizer_loader.from_pretrained(
133+
model_name, cache_dir=cache_dir, trust_remote_code=True
134+
)
96135
else:
97136
logger.info(f"Did not find model and tokenizer in cache directory, downloading model "
98-
f"and tokenizer from the internet and storing in cache {cache_dir}...")
99-
model = AutoModelForMaskedLM.from_pretrained(
100-
model_name, cache_dir=cache_dir, trust_remote_code=True
101-
)
102-
tokenizer = AutoTokenizer.from_pretrained(
103-
model_name, cache_dir=cache_dir, trust_remote_code=True
104-
)
137+
f"and tokenizer from the internet and storing in cache {cache_dir}...")
138+
model = model_loader.from_pretrained(
139+
model_name, cache_dir=cache_dir, trust_remote_code=True
140+
)
141+
tokenizer = tokenizer_loader.from_pretrained(
142+
model_name, cache_dir=cache_dir, trust_remote_code=True
143+
)
144+
logger.info("Model and tokenizer loaded successfully...")
105145
return model, tokenizer

scripts/ProteinGym_runs/protgym_hybrid_perf_test_crossval.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
5050
get_vram()
5151
MAX_WT_SEQUENCE_LENGTH = 600 # TODO: 1000
5252
MAX_VARIANT_FITNESS_PAIRS = 5000
53+
N_CV = 5
5354
print(f"Maximum sequence length: {MAX_WT_SEQUENCE_LENGTH}")
5455
print(f"Loading LLM models into {device} device...")
5556
prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
@@ -160,11 +161,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
160161
x_esm, esm_attention_mask = esm_tokenize_sequences(
161162
sequences, esm_tokenizer, max_length=len(wt_seq), verbose=False
162163
)
163-
#y_esm = esm_infer(
164-
# get_batches(x_esm, dtype=float, batch_size=1),
165-
# esm_attention_mask,
166-
# esm_base_model
167-
#)
168164
y_esm = inference(sequences, 'esm', model=esm_base_model, verbose=False)
169165
print(f'ESM1v (unsupervised performance): '
170166
f'{spearmanr(fitnesses, y_esm.cpu())[0]:.3f}')
@@ -177,10 +173,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
177173
pdb, prosst_tokenizer, wt_seq, verbose=False
178174
)
179175
x_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab, verbose=False)
180-
#y_prosst = get_logits_from_full_seqs(
181-
# x_prosst, prosst_base_model, input_ids, prosst_attention_mask,
182-
# structure_input_ids, train=False
183-
#)
184176
y_prosst = inference(sequences, 'prosst', pdb_file=pdb, wt_seq=wt_seq, model=prosst_base_model, verbose=False)
185177
print(f'ProSST (unsupervised performance): '
186178
f'{spearmanr(fitnesses, y_prosst.cpu())[0]:.3f}')
@@ -192,8 +184,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
192184
print('Both LLM\'s had RunTimeErrors, skipping dataset...')
193185
continue
194186

195-
ns_y_test = [len(variants)]
196-
ds = DatasetSplitter(df_or_csv_file=csv_path, n_cv=5, mutation_separator=mut_sep)
187+
ds = DatasetSplitter(df_or_csv_file=csv_path, n_cv=N_CV, mutation_separator=mut_sep)
197188
ds.plot_distributions()
198189
if max_muts >= 2: # Only using random cross-validation splits
199190
print("Only performing random splits as data contains multi-substituted variants...")
@@ -202,17 +193,20 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
202193
print("Only single substituted variants found, performing random, modulo, and continuous data splits...")
203194
target_split_indices = ds.get_all_split_indices()
204195
temp_results = {}
205-
# TODO: Get correct indices for full df for multi-muts using DatasetSplitter!
196+
for c in ["Random", "Modulo", "Continuous"]:
197+
temp_results.update({c: {}})
198+
for s in range(N_CV):
199+
temp_results[c].update({f'Split {s}': {}})
200+
for m in ['DCA', 'ESM1v', 'ProSST', 'DCA hybrid', 'DCA+ESM1v hybrid', 'DCA+ProSST hybrid']:
201+
# Prefill with NaN's
202+
temp_results[c][f'Split {s}'].update({m: np.nan})
206203
for i_category, (train_indices, test_indices) in enumerate(target_split_indices):
207204
category = ["Random", "Modulo", "Continuous"][i_category]
208205
print(f'Category: {category}')
209-
temp_results.update({category: {}})
210206
for i_split, (train_i, test_i) in enumerate(zip(
211207
train_indices, test_indices
212208
)):
213209
print(f' Split: {i_split + 1}')
214-
print(test_i)
215-
temp_results[category].update({f'Split {i_split}': {}})
216210
try:
217211
_train_sequences, test_sequences = np.asarray(sequences)[train_i], np.asarray(sequences)[test_i]
218212
x_dca_train, x_dca_test = np.asarray(x_dca)[train_i], np.asarray(x_dca)[test_i]
@@ -224,14 +218,10 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
224218
esm_lora_model_2 = copy.deepcopy(esm_lora_model)
225219
esm_optimizer = torch.optim.Adam(esm_lora_model_2.parameters(), lr=0.0001)
226220
train_size, test_size = len(train_i), len(test_i)
227-
#get_vram()
228221
except ValueError as e:
229222
print(f"Only {len(fitnesses)} variant-fitness pairs in total, "
230223
f"cannot split the data in N_Train = {train_size} and N_Test "
231224
f"(N_Total - N_Train) [Excepted error: {e}].")
232-
for m in ['DCA', 'ESM1v', 'ProSST', 'DCA hybrid', 'DCA+ESM1v hybrid', 'DCA+ProSST hybrid']:
233-
temp_results[category][f'Split {i_split}'].update({m: np.nan})
234-
ns_y_test.append(np.nan)
235225
continue
236226
(
237227
x_dca_train,
@@ -276,9 +266,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
276266
f"in N_Train = {len(y_train)} and N_Test = {len(y_test)} "
277267
f"results in N_Test <= 50 variants - not getting "
278268
f"performance for N_Train = {len(y_train)}...")
279-
ns_y_test.append(np.nan)
280-
for m in ['DCA', 'ESM1v', 'ProSST', 'DCA hybrid', 'DCA+ESM1v hybrid', 'DCA+ProSST hybrid']:
281-
temp_results[category][f'Split {i_split}'].update({m: np.nan})
282269
continue
283270

284271
y_test_pred_dca = get_delta_e_statistical_model(x_dca_test, x_wt)
@@ -313,10 +300,8 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
313300
print(f' {m_str} (split {i_split + 1}) performance: {spearmanr(y_test, y_test_pred)[0]:.3f} '
314301
f'(train size={train_size}, test_size={test_size})')
315302
temp_results[category][f'Split {i_split}'].update({m_str: spearmanr(y_test, y_test_pred)[0]})
316-
except RuntimeError as e: # modeling_prosst.py, line 920, in forward
317-
# or UnboundLocalError in prosst_lora_tune.py, line 167
318-
temp_results[category][f'Split {i_split}'].update({m_str: np.nan})
319-
ns_y_test.append(len(y_test_pred))
303+
except RuntimeError as e: # modeling_prosst.py in forward
304+
continue
320305
del prosst_lora_model_2
321306
del esm_lora_model_2
322307
torch.cuda.empty_cache()
@@ -358,7 +343,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
358343
f'{int(dt)}\n')
359344

360345

361-
def plot_csv_data(csv, plot_name):
346+
def plot_csv_data(csv):
362347
plt.figure(figsize=(24, 12))
363348
sns.set_style("whitegrid")
364349
df = pd.read_csv(csv, sep=',')
@@ -487,4 +472,4 @@ def plot_csv_data(csv, plot_name):
487472
):
488473
fh2.write(line)
489474

490-
plot_csv_data(csv=clean_out_results_csv, plot_name='mut_performance')
475+
plot_csv_data(csv=clean_out_results_csv)

0 commit comments

Comments
 (0)