Skip to content

Commit 335256c

Browse files
Committed: "Update API test"
1 parent (31c15e2); this commit: 335256c

File tree

2 files changed

+43
-25
lines changed

2 files changed

+43
-25
lines changed

pypef/llm/esm_lora_tune.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,14 @@ def get_batches(a, dtype, batch_size=5, keep_numpy: bool = False, verbose: bool
107107
return torch.Tensor(a).to(dtype)
108108

109109

110-
def esm_test(xs, attns, scores, loss_fn, model, device: str | None = None):
110+
def esm_test(xs, attention_mask, scores, loss_fn, model, device: str | None = None):
111111
if device is None:
112112
device = ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
113+
attention_masks = torch.Tensor(np.full(shape=np.shape(xs), fill_value=attention_mask)).to(torch.int64)
113114
print(f'Infering model for testing using {device.upper()} device...')
114115
model = model.to(device)
115-
xs, attns, scores = xs.to(device), attns.to(device), scores.to(torch.float).to(device)
116-
pbar_epochs = tqdm(zip(xs, attns, scores), total=len(xs))
116+
xs, attention_masks, scores = torch.Tensor(xs).to(device), attention_masks.to(device), torch.Tensor(scores).to(torch.float).to(device)
117+
pbar_epochs = tqdm(zip(xs, attention_masks, scores), total=len(xs))
117118
for i ,(xs_b, attns_b, scores_b) in enumerate(pbar_epochs):
118119
xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
119120
with torch.no_grad():

tests/test_api_functions.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_gremlin():
7171
)
7272

7373

74-
def test_hybrid_model():
74+
def test_hybrid_model_dca_esm():
7575
g = GREMLIN(
7676
alignment=msa_file_aneh,
7777
char_alphabet="ARNDCQEGHILKMFPSTWYV-",
@@ -90,45 +90,58 @@ def test_hybrid_model():
9090
print(len(train_seqs[0]), train_seqs[0])
9191
assert len(train_seqs[0]) == len(g.wt_seq)
9292
base_model, lora_model, tokenizer, optimizer = get_esm_models()
93-
encoded_seqs_train, attention_masks_train = esm_tokenize_sequences(list(train_seqs), tokenizer, max_length=len(train_seqs[0]))
94-
x_esm_b, attention_masks_b, train_ys_b = (
95-
get_batches(encoded_seqs_train, dtype=int),
96-
get_batches(attention_masks_train, dtype=int),
93+
x_train_esm, attention_mask_esm = esm_tokenize_sequences(
94+
list(train_seqs), tokenizer, max_length=len(train_seqs[0]))
95+
x_esm_b, train_ys_b = (
96+
get_batches(x_train_esm, dtype=int),
9797
get_batches(train_ys, dtype=float)
9898
)
99-
y_true, y_pred_esm = esm_test(x_esm_b, attention_masks_b, train_ys_b, loss_fn=corr_loss, model=base_model)
99+
y_true, y_pred_esm = esm_test(
100+
x_esm_b, attention_mask_esm, train_ys_b,
101+
loss_fn=corr_loss, model=base_model
102+
)
100103
print(spearmanr(
101104
y_true,
102105
y_pred_esm
103106
), len(y_true))
104107

108+
llm_dict_esm = {
109+
'esm1v': {
110+
'llm_base_model': base_model,
111+
'llm_model': lora_model,
112+
'llm_optimizer': optimizer,
113+
'llm_train_function': esm_train,
114+
'llm_inference_function': esm_infer,
115+
'llm_loss_function': corr_loss,
116+
'x_llm_train' : x_train_esm,
117+
'llm_attention_mask': attention_mask_esm
118+
}
119+
}
120+
105121
hm = DCALLMHybridModel(
106122
x_train_dca=np.array(x_dca_train),
107-
x_train_llm=np.array(encoded_seqs_train),
108-
x_train_llm_attention_mask=np.array(attention_masks_train),
109123
y_train=train_ys,
110-
llm_model=lora_model,
111-
llm_base_model=base_model,
112-
llm_optimizer=optimizer,
113-
llm_train_function=esm_train,
114-
llm_inference_function=esm_infer,
115-
llm_loss_function=corr_loss,
124+
llm_model_input=llm_dict_esm,
116125
x_wt=g.x_wt,
117126
seed=42
118127
)
119128

120129
x_dca_test = g.get_scores(test_seqs, encode=True)
121-
encoded_seqs_test, attention_masks_test = esm_tokenize_sequences(list(test_seqs), tokenizer, max_length=len(test_seqs[0]))
122-
y_pred_test = hm.hybrid_prediction(x_dca=x_dca_test, x_llm=encoded_seqs_test, attns_llm=attention_masks_test)
130+
encoded_seqs_test, attention_masks_test = esm_tokenize_sequences(
131+
list(test_seqs), tokenizer, max_length=len(test_seqs[0]))
132+
y_pred_test = hm.hybrid_prediction(x_dca=x_dca_test, x_llm=encoded_seqs_test)
123133
print(hm.beta1, hm.beta2, hm.beta3, hm.beta4, hm.ridge_opt)
124134
print('hm.y_dca_ttest', spearmanr(hm.y_ttest, hm.y_dca_ttest), len(hm.y_ttest))
125135
print('hm.y_dca_ridge_ttest', spearmanr(hm.y_ttest, hm.y_dca_ridge_ttest), len(hm.y_ttest))
126136
print('hm.y_llm_ttest', spearmanr(hm.y_ttest, hm.y_llm_ttest), len(hm.y_ttest))
127137
print('hm.y_llm_lora_ttest', spearmanr(hm.y_ttest, hm.y_llm_lora_ttest), len(hm.y_ttest))
128138
print('Hybrid', spearmanr(test_ys, y_pred_test), len(test_ys))
129-
np.testing.assert_almost_equal(spearmanr(hm.y_ttest, hm.y_dca_ttest)[0], -0.5342743713116743, decimal=5)
130-
np.testing.assert_almost_equal(spearmanr(hm.y_ttest, hm.y_dca_ridge_ttest)[0], 0.717333573331078, decimal=5)
131-
np.testing.assert_almost_equal(spearmanr(hm.y_ttest, hm.y_llm_ttest)[0], -0.21761360470606333, decimal=5)
139+
np.testing.assert_almost_equal(spearmanr(
140+
hm.y_ttest, hm.y_dca_ttest)[0], -0.5342743713116743, decimal=5)
141+
np.testing.assert_almost_equal(spearmanr(
142+
hm.y_ttest, hm.y_dca_ridge_ttest)[0], 0.717333573331078, decimal=5)
143+
np.testing.assert_almost_equal(spearmanr(
144+
hm.y_ttest, hm.y_llm_ttest)[0], -0.21761360470606333, decimal=5)
132145
# Nondeterministic behavior, should be about ~ 0.8, checking if not NaN
133146
# Torch reproducibility documentation: https://pytorch.org/docs/stable/notes/randomness.html
134147
assert -1.0 <= spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0] <= 1.0
@@ -137,8 +150,12 @@ def test_hybrid_model():
137150

138151
def test_dataset_b_results():
139152
aaindex = "WOLR810101.txt"
140-
x_fft_train, _ = AAIndexEncoding(full_aaidx_txt_path(aaindex), train_seqs).collect_encoded_sequences()
141-
x_fft_test, _ = AAIndexEncoding(full_aaidx_txt_path(aaindex), test_seqs).collect_encoded_sequences()
153+
x_fft_train, _ = AAIndexEncoding(
154+
full_aaidx_txt_path(aaindex), train_seqs
155+
).collect_encoded_sequences()
156+
x_fft_test, _ = AAIndexEncoding(
157+
full_aaidx_txt_path(aaindex), test_seqs
158+
).collect_encoded_sequences()
142159
performances = get_regressor_performances(
143160
x_learn=x_fft_train,
144161
x_test=x_fft_test,
@@ -157,6 +174,6 @@ def test_dataset_b_results():
157174

158175
if __name__ == "__main__":
159176
test_gremlin()
160-
test_hybrid_model()
177+
test_hybrid_model_dca_esm()
161178
test_dataset_b_results()
162179

0 commit comments

Comments (0)