Skip to content

Commit 9e71e2b

Browse files
committed
Fix `model is None` handling and add verbosity parameters to many functions
1 parent 4f4a22d commit 9e71e2b

File tree

8 files changed

+112
-59
lines changed

8 files changed

+112
-59
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,6 @@ def train_llm(self):
399399
input_ids=self.input_ids,
400400
attention_mask=self.llm_attention_mask,
401401
structure_input_ids=self.structure_input_ids,
402-
train=False,
403402
device=self.device
404403
)
405404
y_llm_ttrain = self.llm_inference_function(
@@ -408,7 +407,6 @@ def train_llm(self):
408407
input_ids=self.input_ids,
409408
attention_mask=self.llm_attention_mask,
410409
structure_input_ids=self.structure_input_ids,
411-
train=False,
412410
device=self.device
413411
)
414412
elif self.llm_key == 'esm1v':
@@ -598,7 +596,6 @@ def hybrid_prediction(
598596
self.input_ids,
599597
self.llm_attention_mask,
600598
self.structure_input_ids,
601-
train=False,
602599
verbose=verbose,
603600
device=self.device).detach().cpu().numpy()
604601
y_llm_lora = self.llm_inference_function(
@@ -607,7 +604,6 @@ def hybrid_prediction(
607604
self.input_ids,
608605
self.llm_attention_mask,
609606
self.structure_input_ids,
610-
train=False,
611607
verbose=verbose,
612608
device=self.device).detach().cpu().numpy()
613609
elif self.llm_key == 'esm1v':
@@ -1004,7 +1000,8 @@ def performance_ls_ts(
10041000
x_train_dca=np.array(x_train),
10051001
y_train=np.array(y_train),
10061002
llm_model_input=llm_dict,
1007-
x_wt=x_wt
1003+
x_wt=x_wt,
1004+
device=device
10081005
)
10091006
y_test_pred = hybrid_model.hybrid_prediction(np.array(x_test), x_llm_test)
10101007
logger.info(f'Hybrid performance: {spearmanr(y_test, y_test_pred)[0]:.3f} N={len(y_test)}')

pypef/llm/esm_lora_tune.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def get_y_pred_scores(encoded_sequences, attention_masks,
8181
return log_probs
8282

8383

84-
def esm_test(xs, attention_mask, scores, loss_fn, model, device: str | None = None, verbose: bool = True):
84+
def esm_test(xs, attention_mask, scores, loss_fn, model,
85+
device: str | None = None, verbose: bool = True):
8586
if device is None:
8687
device = get_device()
8788
attention_masks = torch.Tensor(np.full(
@@ -148,15 +149,18 @@ def esm_train(xs, attention_mask, scores, loss_fn, model, optimizer, n_epochs=3,
148149
attention_masks = torch.Tensor(np.full(
149150
shape=np.shape(xs), fill_value=attention_mask)).to(torch.int64)
150151
xs, attention_masks, scores = xs.to(device), attention_masks.to(device), scores.to(device)
151-
pbar_epochs = tqdm(range(1, n_epochs + 1))
152+
pbar_epochs = tqdm(range(1, n_epochs + 1), disable=not verbose)
152153
loss = np.nan
153154
for epoch in pbar_epochs:
154155
try:
155156
pbar_epochs.set_description(f'Epoch: {epoch}/{n_epochs}. Loss: {loss.detach():>1f}')
156157
except AttributeError:
157158
pbar_epochs.set_description(f'Epoch: {epoch}/{n_epochs}')
158159
model.train()
159-
pbar_batches = tqdm(zip(xs, attention_masks, scores), total=len(xs), leave=False, disable=not verbose)
160+
pbar_batches = tqdm(
161+
zip(xs, attention_masks, scores),
162+
total=len(xs), leave=False, disable=not verbose
163+
)
160164
for batch, (xs_b, attns_b, scores_b) in enumerate(pbar_batches):
161165
xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
162166
y_preds_b = get_y_pred_scores(xs_b, attns_b, model, device=device)
@@ -173,11 +177,11 @@ def esm_train(xs, attention_mask, scores, loss_fn, model, optimizer, n_epochs=3,
173177
model.train(False)
174178

175179

176-
def esm_setup(sequences, device: str | None = None):
180+
def esm_setup(sequences, device: str | None = None, verbose: bool = True):
177181
esm_base_model, esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models()
178182
esm_base_model = esm_base_model.to(device)
179183
x_esm, esm_attention_mask = esm_tokenize_sequences(
180-
sequences, esm_tokenizer, max_length=len(sequences[0]))
184+
sequences, esm_tokenizer, max_length=len(sequences[0]), verbose=verbose)
181185
llm_dict_esm = {
182186
'esm1v': {
183187
'llm_base_model': esm_base_model,

pypef/llm/inference.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from pypef.utils.helpers import get_device
1010
from pypef.llm.utils import get_batches
11-
from pypef.llm.esm_lora_tune import esm_setup, esm_tokenize_sequences
12-
from pypef.llm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences
11+
from pypef.llm.esm_lora_tune import esm_setup, esm_tokenize_sequences, esm_infer
12+
from pypef.llm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences, prosst_infer
1313

1414
import logging
1515
logger = logging.getLogger('pypef.llm.inference')
@@ -40,38 +40,42 @@ def inference(
4040
pdb_file: str | None = None,
4141
wt_seq: str | None = None,
4242
device: str| None = None,
43-
model = None
43+
model = None,
44+
verbose: bool = True
4445
):
4546
"""
4647
Inference of input or base model.
4748
"""
4849
if device is None:
4950
device = get_device()
51+
if llm == 'esm':
5052
logger.info("Zero-shot LLM inference on test set using ESM1v...")
51-
llm_dict = esm_setup(sequences)
52-
if llm == 'esm':
53-
if model is None:
54-
model = llm_dict['esm1v']['llm_base_model']
55-
x_llm_test = llm_embedder(llm_dict, sequences)
56-
y_test_pred = llm_dict['esm1v']['llm_inference_function'](
53+
llm_dict = esm_setup(sequences, verbose=verbose)
54+
if model is None:
55+
model = llm_dict['esm1v']['llm_base_model']
56+
x_llm_test = llm_embedder(llm_dict, sequences, verbose)
57+
y_test_pred = esm_infer(#llm_dict['esm1v']['llm_inference_function'](
5758
xs=get_batches(x_llm_test, batch_size=1, dtype=int),
5859
attention_mask=llm_dict['esm1v']['llm_attention_mask'],
5960
model=model,
60-
device=device
61+
device=device,
62+
verbose=verbose
6163
).cpu()
6264
elif llm == 'prosst':
63-
if model is None:
64-
model = llm_dict['prosst']['llm_base_model']
6565
logger.info("Zero-shot LLM inference on test set using ProSST...")
6666
llm_dict = prosst_setup(
67-
wt_seq, pdb_file, sequences=sequences)
68-
x_llm_test = llm_embedder(llm_dict, sequences)
69-
y_test_pred = llm_dict['prosst']['llm_inference_function'](
67+
wt_seq, pdb_file, sequences=sequences, verbose=verbose
68+
)
69+
if model is None:
70+
model = llm_dict['prosst']['llm_base_model']
71+
x_llm_test = llm_embedder(llm_dict, sequences, verbose)
72+
y_test_pred = prosst_infer(#llm_dict['prosst']['llm_inference_function'](
7073
xs=x_llm_test,
7174
model=model,
7275
input_ids=llm_dict['prosst']['input_ids'],
7376
attention_mask=llm_dict['prosst']['llm_attention_mask'],
7477
structure_input_ids=llm_dict['prosst']['structure_input_ids'],
78+
verbose=verbose,
7579
device=device
7680
).cpu()
7781
else:

pypef/llm/prosst_lora_tune.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,27 @@ def get_logits_from_full_seqs(
9494
return log_probs
9595

9696

97+
def prosst_infer(
98+
xs,
99+
model,
100+
input_ids,
101+
attention_mask,
102+
structure_input_ids,
103+
verbose: bool = False,
104+
device: str | None = None
105+
):
106+
return get_logits_from_full_seqs(
107+
xs,
108+
model,
109+
input_ids,
110+
attention_mask,
111+
structure_input_ids,
112+
train = False,
113+
verbose = verbose,
114+
device = device
115+
)
116+
117+
97118
def checkpoint(model, filename):
98119
torch.save(model.state_dict(), filename)
99120

@@ -205,8 +226,8 @@ def get_prosst_models():
205226
return prosst_base_model, prosst_lora_model, tokenizer, optimizer
206227

207228

208-
def get_structure_quantizied(pdb_file, tokenizer, wt_seq):
209-
structure_sequence = PdbQuantizer()(pdb_file=pdb_file)
229+
def get_structure_quantizied(pdb_file, tokenizer, wt_seq, verbose: bool = True):
230+
structure_sequence = PdbQuantizer(verbose=verbose)(pdb_file=pdb_file)
210231
structure_sequence_offset = [i + 3 for i in structure_sequence]
211232
tokenized_res = tokenizer([wt_seq], return_tensors='pt')
212233
input_ids = tokenized_res['input_ids']
@@ -216,7 +237,7 @@ def get_structure_quantizied(pdb_file, tokenizer, wt_seq):
216237
return input_ids, attention_mask, structure_input_ids
217238

218239

219-
def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
240+
def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose: bool = True):
220241
if wt_seq is None:
221242
raise SystemError(
222243
"Running ProSST requires a wild-type sequence "
@@ -240,16 +261,18 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
240261
prosst_base_model = prosst_base_model.to(device)
241262
prosst_optimizer = torch.optim.Adam(prosst_lora_model.parameters(), lr=0.0001)
242263
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
243-
pdb_file, prosst_tokenizer, wt_seq)
264+
pdb_file, prosst_tokenizer, wt_seq, verbose=verbose
265+
)
244266
x_llm_train_prosst = prosst_tokenize_sequences(
245-
sequences=sequences, vocab=prosst_vocab)
267+
sequences=sequences, vocab=prosst_vocab, verbose=verbose
268+
)
246269
llm_dict_prosst = {
247270
'prosst': {
248271
'llm_base_model': prosst_base_model,
249272
'llm_model': prosst_lora_model,
250273
'llm_optimizer': prosst_optimizer,
251274
'llm_train_function': prosst_train,
252-
'llm_inference_function': get_logits_from_full_seqs,
275+
'llm_inference_function': prosst_infer,
253276
'llm_loss_function': corr_loss,
254277
'x_llm' : x_llm_train_prosst,
255278
'llm_attention_mask': prosst_attention_mask,

pypef/llm/prosst_structure/quantizer.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ def process_pdb_file(
408408
pdb_file,
409409
subgraph_depth,
410410
subgraph_interval,
411-
max_distance
411+
max_distance,
412+
verbose: bool = True
412413
):
413414
result_dict, subgraph_dict = {}, {}
414415
result_dict["name"] = Path(pdb_file).name
@@ -436,7 +437,7 @@ def process_subgraph(anchor_node):
436437
subgraph = convert_graph(subgraph)
437438
return anchor_node, subgraph
438439

439-
for anchor_node in tqdm(anchor_nodes, desc='Getting ProSST structure embeddings'):
440+
for anchor_node in tqdm(anchor_nodes, desc='Getting ProSST structure embeddings', disable=not verbose):
440441
anchor, subgraph = process_subgraph(anchor_node)
441442
subgraph_dict[anchor] = subgraph
442443

@@ -449,7 +450,8 @@ def pdb_conventer(
449450
pdb_files,
450451
subgraph_depth,
451452
subgraph_interval,
452-
max_distance
453+
max_distance,
454+
verbose: bool = True
453455
):
454456
error_proteins, error_messages = [], []
455457
dataset, results, node_counts = [], [], []
@@ -460,6 +462,7 @@ def pdb_conventer(
460462
subgraph_depth,
461463
subgraph_interval,
462464
max_distance,
465+
verbose=verbose
463466
)
464467

465468
if pdb_subgraphs is None:
@@ -502,7 +505,8 @@ def __init__(
502505
model_path=None,
503506
cluster_dir=None,
504507
cluster_model=None,
505-
device=None
508+
device=None,
509+
verbose: bool = True
506510
) -> None:
507511
self.max_distance = max_distance
508512
self.subgraph_depth = subgraph_depth
@@ -512,6 +516,7 @@ def __init__(
512516
self.device = get_device()
513517
else:
514518
self.device = device
519+
self.verbose = verbose
515520
if model_path is None:
516521
if self.device == 'cpu':
517522
self.model_path = str(Path(__file__).parent / "static" / "AE_CPU.pt")
@@ -554,7 +559,8 @@ def __call__(self, pdb_file, return_residue_seq=False):
554559
],
555560
self.subgraph_depth,
556561
self.subgraph_interval,
557-
self.max_distance
562+
self.max_distance,
563+
verbose=self.verbose
558564
)
559565
sturctures = predict_structure(
560566
self.model, self.cluster_models, data_loader, self.device

scripts/ProteinGym_runs/protgym_hybrid_perf_test_crossval.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,13 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
140140
gremlin = GREMLIN(alignment=msa_path, opt_iter=100, optimize=True)
141141
sequences_batched = get_batches(sequences, batch_size=1000,
142142
dtype=str, keep_remaining=True, verbose=True)
143-
x_dca = []
144-
for seq_b in tqdm(sequences_batched, desc="Getting GREMLIN sequence encodings"):
143+
x_dca = [] # required later on also
144+
for seq_b in tqdm(sequences_batched, desc="Getting GREMLIN sequence encodings", disable=True):
145145
for x in gremlin.collect_encoded_sequences(seq_b):
146146
x_dca.append(x)
147147
x_wt = gremlin.x_wt
148148
y_pred_dca = get_delta_e_statistical_model(x_dca, x_wt)
149-
print(f'DCA (unsupervised performance): {spearmanr(fitnesses, y_pred_dca)[0]:.3f}')
149+
print(f'DCA (unsupervised performance): {spearmanr(fitnesses, y_pred_dca)[0]:.3f}')
150150
dca_unopt_perf = spearmanr(fitnesses, y_pred_dca)[0]
151151
# ESM unsupervised
152152
try:
@@ -158,7 +158,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
158158
# esm_attention_mask,
159159
# esm_base_model
160160
#)
161-
y_esm = inference(sequences, 'esm', model=esm_base_model)
161+
y_esm = inference(sequences, 'esm', model=esm_base_model, verbose=False)
162162
print(f'ESM1v (unsupervised performance): '
163163
f'{spearmanr(fitnesses, y_esm.cpu())[0]:.3f}')
164164
esm_unopt_perf = spearmanr(fitnesses, y_esm.cpu())[0]
@@ -167,13 +167,14 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
167167
# ProSST unsupervised
168168
try:
169169
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
170-
pdb, prosst_tokenizer, wt_seq)
170+
pdb, prosst_tokenizer, wt_seq, verbose=False
171+
)
171172
x_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab, verbose=False)
172173
#y_prosst = get_logits_from_full_seqs(
173174
# x_prosst, prosst_base_model, input_ids, prosst_attention_mask,
174175
# structure_input_ids, train=False
175176
#)
176-
y_prosst = inference(sequences, 'prosst', pdb_file=pdb, wt_seq=wt_seq, model=prosst_base_model)
177+
y_prosst = inference(sequences, 'prosst', pdb_file=pdb, wt_seq=wt_seq, model=prosst_base_model, verbose=False)
177178
print(f'ProSST (unsupervised performance): '
178179
f'{spearmanr(fitnesses, y_prosst.cpu())[0]:.3f}')
179180
prosst_unopt_perf = spearmanr(fitnesses, y_prosst.cpu())[0]
@@ -195,9 +196,10 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
195196
for i_split, (train_i, test_i) in enumerate(zip(
196197
train_indices, test_indices
197198
)):
198-
print(f'Split: {i_split + 1}')
199+
print(f' Split: {i_split + 1}')
199200
temp_results[category].update({f'Split {i_split}': {}})
200201
try:
202+
_train_sequences, test_sequences = np.asarray(sequences)[train_i], np.asarray(sequences)[test_i]
201203
x_dca_train, x_dca_test = np.asarray(x_dca)[train_i], np.asarray(x_dca)[test_i]
202204
x_llm_train_prosst, x_llm_test_prosst = np.asarray(x_prosst)[train_i], np.asarray(x_prosst)[test_i]
203205
x_llm_train_esm, x_llm_test_esm = np.asarray(x_esm)[train_i], np.asarray(x_esm)[test_i]
@@ -253,7 +255,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
253255
'structure_input_ids': structure_input_ids
254256
}
255257
}
256-
print(f'Train: {len(np.array(y_train))} --> Test: {len(np.array(y_test))}')
258+
print(f' Train: {len(np.array(y_train))} --> Test: {len(np.array(y_test))}')
257259
if len(y_test) <= 20: # TODO: 50
258260
print(f"Only {len(fitnesses)} in total, splitting the data "
259261
f"in N_Train = {len(y_train)} and N_Test = {len(y_test)} "
@@ -264,6 +266,17 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
264266
ns_y_test.append(np.nan)
265267
continue
266268
#get_vram()
269+
270+
y_test_pred_dca = get_delta_e_statistical_model(x_dca_test, x_wt)
271+
temp_results[category][f'Split {i_split}'].update({'DCA': spearmanr(y_test, y_test_pred_dca)[0]})
272+
print(f' DCA ZeroShot (split {i_split + 1}) performance: {spearmanr(y_test, y_test_pred_dca)[0]:.3f}')
273+
y_test_pred_esm = inference(test_sequences, 'esm', model=esm_base_model, verbose=False)
274+
temp_results[category][f'Split {i_split}'].update({'ESM1v': spearmanr(y_test, y_test_pred_esm)[0]})
275+
print(f' ESM1v ZeroShot (split {i_split + 1}) performance: {spearmanr(y_test, y_test_pred_esm)[0]:.3f}')
276+
y_test_pred_prosst = inference(test_sequences, 'prosst', model=prosst_base_model, pdb_file=pdb, wt_seq=wt_seq, verbose=False)
277+
temp_results[category][f'Split {i_split}'].update({'ProSST': spearmanr(y_test, y_test_pred_prosst)[0]})
278+
print(f' ProSST ZeroShot (split {i_split + 1}) performance: {spearmanr(y_test, y_test_pred_prosst)[0]:.3f}')
279+
267280
for i_m, method in enumerate([None, llm_dict_esm, llm_dict_prosst]):
268281
m_str = ['DCA hybrid', 'DCA+ESM1v hybrid', 'DCA+ProSST hybrid'][i_m]
269282
#print('\n~~~ ' + m_str + ' ~~~')
@@ -284,7 +297,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
284297
][i_m],
285298
verbose=False
286299
)
287-
print(f'{m_str} (split {i_split + 1}) performance: {spearmanr(y_test, y_test_pred)[0]:.3f} '
300+
print(f' {m_str} (split {i_split + 1}) performance: {spearmanr(y_test, y_test_pred)[0]:.3f} '
288301
f'(train size={train_size}, test_size={test_size})')
289302
temp_results[category][f'Split {i_split}'].update({m_str: spearmanr(y_test, y_test_pred)[0]})
290303
except RuntimeError as e: # modeling_prosst.py, line 920, in forward

0 commit comments

Comments
 (0)