Skip to content

Commit 84fc7a3

Browse files
authored
Merge pull request #13 from OpenProteinAI/ttruong/test-embedding-e2e
embeddings api e2e tests
2 parents d5a03c3 + 37cce68 commit 84fc7a3

File tree

6 files changed

+507
-3
lines changed

6 files changed

+507
-3
lines changed

openprotein/api/jobs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ class Job(BaseModel):
4141
start_date: Optional[datetime]
4242
end_date: Optional[datetime]
4343
prerequisite_job_id: Optional[str]
44-
progress_message: Optional[str]
44+
progress_message: Optional[str] = None
4545
progress_counter: Optional[int]
46-
num_records: Optional[int]
46+
num_records: Optional[int] = None
4747

4848
def refresh(self, session: APISession):
4949
"""refresh job status"""

openprotein/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class SVDMetadata(BaseModel):
5959
model_id: str
6060
n_components: int
6161
reduction: Optional[str]
62-
sequence_length: Optional[int]
62+
sequence_length: Optional[int] = None
6363

6464
def is_done(self):
6565
return self.status.done()

tests/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
def pytest_addoption(parser):
    """Register the ``--longrun`` command-line flag with pytest.

    The flag is a boolean switch stored as ``config.option.longrun`` and
    defaults to ``False``.

    Args:
        parser: the pytest option parser passed to this hook.
    """
    # adapted from https://stackoverflow.com/a/33181491
    parser.addoption(
        "--longrun",
        action="store_true",
        dest="longrun",
        default=False,
        help="enable longrun decorated tests",
    )
5+
6+
7+
def pytest_configure(config):
    """Register the ``longrun`` marker and deselect it unless ``--longrun`` is set.

    When ``--longrun`` was not given, the ``-m`` mark expression is rewritten
    so that tests marked ``longrun`` are excluded, while any expression the
    user supplied is kept.

    Args:
        config: the pytest config object passed to this hook.
    """
    config.addinivalue_line(
        "markers", "longrun: mark tests that take a long time to run"
    )
    if config.option.longrun:
        return

    user_expr = config.option.markexpr
    if user_expr != "":
        # Parenthesize the user's expression: "and" binds tighter than "or",
        # so "a or b and not longrun" would only exclude longrun tests from
        # the "b" side and let longrun tests marked "a" through.
        config.option.markexpr = f"({user_expr}) and not longrun"
    else:
        config.option.markexpr = "not longrun"

tests/test_embeddings_e2e.py

Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
"""
2+
assumes models are available at data/models
3+
change seeds at the appropriate places to avoid backend caching
4+
set --longrun flag when running pytest to run these tests
5+
"""
6+
import json
7+
import os
8+
import time
9+
from pathlib import Path
10+
from typing import Optional, Union
11+
12+
import numpy as np
13+
14+
import torch
15+
16+
import fsspec
17+
from fsspec.implementations.dirfs import DirFileSystem
18+
19+
import pytest
20+
21+
from esm import ESM2, ProteinBertModel
22+
from esm.pretrained import load_model_and_alphabet_core, _has_regression_weights
23+
24+
from protembed.alphabets import Uniprot21
25+
from protembed.datasets import pad_tensor_1d
26+
from protembed.factory import ProtembedModelLoader
27+
28+
import openprotein
29+
from openprotein import OpenProtein
30+
from openprotein.api.embedding import SVDModel
31+
from tests.utils.svd import TorchLowRankSVDTransform
32+
33+
34+
# Alphabet shared by the fixtures and tests in this module: the fixtures use it
# to decode random index arrays into sequences, and test_protembed uses it to
# encode sequences (and to supply the padding token) for the local model.
ALPHABET = Uniprot21()
35+
36+
37+
def load_model_and_alphabet_local(model_location, device):
    """Load an ESM checkpoint and its alphabet from a local file.

    The model name is taken from the file stem. When the model has contact
    regression weights, they are read from a sibling file named
    ``<stem>-contact-regression.pt`` in the same directory.

    Args:
        model_location: path to the ``.pt`` checkpoint.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever ``load_model_and_alphabet_core`` returns for the loaded data.
    """
    checkpoint_path = Path(model_location)
    name = checkpoint_path.stem
    weights = torch.load(str(checkpoint_path), map_location=device)

    regression_weights = None
    if _has_regression_weights(name):
        regression_path = f"{checkpoint_path.with_suffix('')}-contact-regression.pt"
        regression_weights = torch.load(regression_path, map_location=device)

    return load_model_and_alphabet_core(name, weights, regression_weights)
48+
49+
50+
@pytest.fixture()
def session() -> OpenProtein:
    """Connect to the dev backend using credentials from ``secrets.config``.

    ``secrets.config`` is read from the current working directory and must be
    JSON containing ``username`` and ``password`` keys.
    """
    with open("secrets.config", "r") as config_file:
        credentials = json.load(config_file)
    username = credentials["username"]
    password = credentials["password"]
    return openprotein.connect(
        username,
        password,
        backend="https://dev.api.openprotein.ai/api/",
    )
59+
60+
61+
@pytest.fixture()
def loader() -> ProtembedModelLoader:
    """Build a model loader over the local ``data/models`` directory."""
    models_fs = DirFileSystem("data/models", fsspec.filesystem('file'))
    return ProtembedModelLoader(models_fs)
66+
67+
68+
@pytest.fixture()
def sequences() -> list[bytes]:
    """Generate five random sequences with lengths drawn from [250, 500).

    The generator is seeded, so the same sequences are produced on every run;
    change the seed to obtain different ones.
    """
    rng = np.random.default_rng(188501)
    generated = []
    for _ in range(5):
        # Draw the length first, then the residue indices, so the stream of
        # random numbers is consumed in a fixed order.
        length = rng.integers(250, 500)
        residue_idxs = rng.integers(low=0, high=21, size=length)
        generated.append(ALPHABET.decode(residue_idxs))
    return generated
79+
80+
81+
@pytest.fixture()
def same_length_sequences() -> list[bytes]:
    """Generate five random sequences that all have length 331.

    The generator is seeded, so the same sequences are produced on every run.
    """
    rng = np.random.default_rng(376735)
    generated = []
    for _ in range(5):
        residue_idxs = rng.integers(low=0, high=21, size=331)
        generated.append(ALPHABET.decode(residue_idxs))
    return generated
88+
89+
90+
@pytest.mark.longrun
@pytest.mark.parametrize("local_model_id,model_id", [
    ("prosst", "prot-seq"),
    ("rotaprot-seq-900m-uniref90-v1", "rotaprot-large-uniref90-ft"),
])
@torch.inference_mode()
def test_protembed(
    loader: ProtembedModelLoader,
    local_model_id: str,
    session: OpenProtein,
    model_id: str,
    sequences: list[bytes],
):
    """Compare API attention, embeddings and logits with a locally run model.

    The local model is loaded on CUDA and run under fp16 autocast. Each API
    result is compared per sequence with the local output via the mean
    absolute difference.
    """

    def _fetch(job):
        # Pause briefly, block until the job is done, then key the results
        # by sequence.
        time.sleep(1)
        job.wait_until_done()
        return dict(job.get())

    print("testing...", model_id)
    local_model = loader.load_model(local_model_id, device=torch.device("cuda"))
    model = session.embedding.get_model(model_id=model_id)

    encoded = [
        torch.from_numpy(ALPHABET.encode(seq)).cuda().long() for seq in sequences
    ]
    sequences_as_idxs, mask = pad_tensor_1d(
        encoded,
        ALPHABET.mask_token,
        return_padding=True,
    )
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        _, attn = local_model.embed(
            sequences_as_idxs, padding_mask=mask, return_attention=True
        )
        embeddings = local_model.embed(sequences_as_idxs, padding_mask=mask)
        logits = local_model.logits(embeddings)
    if isinstance(attn, list):
        # TODO: kinda hacky. doing this b/c the interfaces of prosst and
        # rotaformer are not the same
        attn = torch.stack(attn, dim=1)

    # Move everything to numpy and trim each item back to its sequence length.
    lengths = [len(seq) for seq in sequences]
    attn = [
        a.float().cpu().numpy()[-1][:, :n, :n]
        for n, a in zip(lengths, attn)
    ]
    embeddings = [
        e.float().cpu().numpy()[:n] for n, e in zip(lengths, embeddings)
    ]
    logits = [lg.float().cpu().numpy()[:n] for n, lg in zip(lengths, logits)]

    # we can't really make these difference tests too stringent, probably due
    # to numerical precision issues (fp16 may be particularly problematic)
    result = _fetch(model.attn(sequences))
    for seq, actual in zip(sequences, attn):
        remote = result[seq]
        mean_delta = np.abs(remote - actual).mean()
        # Baseline for the printout: the same comparison against a shuffled
        # copy of the local attention.
        shuffled = actual[np.random.permutation(len(actual))]
        random_mean_delta = np.abs(remote - shuffled).mean()
        print(
            "attn",
            mean_delta,
            random_mean_delta,
            random_mean_delta / mean_delta,
        )
        assert mean_delta < 1e-4

    for reduction in [None, "MEAN", "SUM"]:
        result = _fetch(model.embed(sequences, reduction=reduction))
        for seq, actual in zip(sequences, embeddings):
            if reduction in ("MEAN", "SUM"):
                actual = actual.mean(axis=0)
            if reduction == "SUM":
                # compare means to average out errors
                result[seq] = result[seq] / len(seq)
            mean_delta = np.abs(result[seq] - actual).mean()
            print("embed", reduction, mean_delta)
            assert mean_delta < 1e-2

    result = _fetch(model.logits(sequences))
    for seq, actual in zip(sequences, logits):
        mean_delta = np.abs(result[seq] - actual).mean()
        print("logits", mean_delta)
        assert mean_delta < 1e-2
175+
176+
177+
@pytest.mark.longrun
@pytest.mark.parametrize(
    "model_id", ["esm1b_t33_650M_UR50S", "esm1v_t33_650M_UR90S_1", "esm2_t6_8M_UR50D"]
)
@torch.inference_mode()
def test_esm(session: OpenProtein, model_id: str, sequences: list[bytes]):
    """Compare API attention, embeddings and logits with a locally run ESM model.

    The checkpoint is read from ``data/models/<model_id>.pt``. Each API result
    is compared per sequence with the local output via the mean absolute
    difference.
    """

    def _fetch(job):
        # Pause briefly, block until the job is done, then key the results
        # by sequence.
        time.sleep(1)
        job.wait_until_done()
        return dict(job.get())

    print("testing...", model_id)
    # Only the small ESM2 model goes on the GPU; using cpu for the others in
    # case of low vram.
    device = torch.device("cuda" if model_id == "esm2_t6_8M_UR50D" else "cpu")
    local_model: Union[ESM2, ProteinBertModel]
    model_pt_path = os.path.join("data/models", f"{model_id}.pt")
    local_model, alphabet = load_model_and_alphabet_local(model_pt_path, device)
    batch_converter = alphabet.get_batch_converter()
    local_model = local_model.eval()  # disables dropout for deterministic results
    if isinstance(local_model, ESM2):
        # half precision inference should be safe, per
        # https://github.com/facebookresearch/esm/issues/283#issuecomment-1254283417
        local_model = local_model.half()
    local_model = local_model.to(device)
    can_predict_contacts = _has_regression_weights(model_id)

    labels = [f"{i}" for i, _ in enumerate(sequences)]
    masked_sequences = [seq.decode().replace("X", "<mask>") for seq in sequences]
    _, _, batch_tokens = batch_converter(list(zip(labels, masked_sequences)))
    final_layer = local_model.num_layers
    results = local_model(
        batch_tokens.to(device),
        repr_layers=[final_layer],
        need_head_weights=True,
        return_contacts=can_predict_contacts,
    )

    embeddings = results["representations"][final_layer].float()
    attn = results["attentions"].float()
    logits = results["logits"].float()
    contacts = results["contacts"].float() if can_predict_contacts else None

    # Number of non-padding tokens per batch row, used to trim the outputs.
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    embeddings = [embeddings[row, :n] for row, n in enumerate(batch_lens)]
    # [1:-1] drops the first and last token positions before reducing
    # (presumably the BOS/EOS tokens -- confirm against the alphabet).
    mean_embeddings = torch.vstack([emb[1:-1].mean(dim=0) for emb in embeddings])
    sum_embeddings = torch.vstack([emb[1:-1].sum(dim=0) for emb in embeddings])
    attn = [attn[row, -1, :, :n, :n] for row, n in enumerate(batch_lens)]
    logits = [logits[row, :n] for row, n in enumerate(batch_lens)]
    if contacts is not None:
        contacts = [
            contacts[row, :n-2, :n-2] for row, n in enumerate(batch_lens)
        ]

    def _to_numpy(tensors):
        return [t.float().cpu().numpy() for t in tensors]

    embeddings = _to_numpy(embeddings)
    mean_embeddings = _to_numpy(mean_embeddings)
    sum_embeddings = _to_numpy(sum_embeddings)
    attn = _to_numpy(attn)
    logits = _to_numpy(logits)
    # NOTE: contacts are computed locally but not compared against the API below.
    contacts = _to_numpy(contacts) if contacts is not None else None

    model = session.embedding.get_model(model_id=model_id)
    result = _fetch(model.attn(sequences))
    for seq, actual in zip(sequences, attn):
        remote = result[seq]
        mean_delta = np.abs(remote - actual).mean()
        # Baseline for the printout: the same comparison against a shuffled
        # copy of the local attention.
        shuffled = actual[np.random.permutation(len(actual))]
        random_mean_delta = np.abs(remote - shuffled).mean()
        print(
            "attn",
            mean_delta,
            random_mean_delta,
            random_mean_delta / mean_delta,
        )
        assert mean_delta < 1e-4

    local_by_reduction = {
        None: embeddings,
        "MEAN": mean_embeddings,
        "SUM": sum_embeddings,
    }
    for reduction in [None, "MEAN", "SUM"]:
        result = _fetch(model.embed(sequences, reduction=reduction))
        for i, seq in enumerate(sequences):
            actual = local_by_reduction[reduction][i]
            if reduction == "SUM":
                # compare means to average out errors
                actual = actual / len(seq)
                result[seq] = result[seq] / len(seq)
            mean_delta = np.abs(result[seq] - actual).mean()
            print("embed", reduction, mean_delta)
            assert mean_delta < 1e-2

    result = _fetch(model.logits(sequences))
    for seq, actual in zip(sequences, logits):
        mean_delta = np.abs(result[seq] - actual).mean()
        print("logits", mean_delta)
        assert mean_delta < 1e-2
298+
299+
300+
@pytest.mark.longrun
@pytest.mark.parametrize("reduction", [None, "MEAN", "SUM"])
@pytest.mark.parametrize("random_state,should_fail", [(47, False), (100, True)])
def test_svd(
    session: OpenProtein,
    same_length_sequences: list[bytes],
    reduction: Optional[str],
    random_state: int,
    should_fail: bool,
):
    """Compare an SVD fitted through the API with one fitted locally.

    Embeddings are fetched from the API, reduced locally with
    ``TorchLowRankSVDTransform`` and compared with the output of the SVD model
    fitted through the API on the same sequences.

    Like the other end-to-end tests in this module, this test needs
    credentials and the API, so it is marked ``longrun`` and only runs when
    pytest is invoked with ``--longrun``.

    Args:
        session: connected API session.
        same_length_sequences: equal-length sequences, so that per-residue
            embeddings can be stacked and flattened into one matrix.
        reduction: embedding reduction passed to the API (None, "MEAN", "SUM").
        random_state: seed for the local SVD.
        should_fail: when True, the local and API results are expected NOT to
            agree closely (used with a seed that presumably differs from the
            backend's -- confirm).
    """
    print("testing svd...", reduction, random_state, should_fail)
    sequences = same_length_sequences
    # this is an extremely strong test!
    # it depends on the svd random_state being the same
    model_id = "prot-seq"
    n_components = 1024
    model = session.embedding.get_model(model_id=model_id)

    # get embeddings to svd
    future = model.embed(sequences, reduction=reduction)
    time.sleep(1)
    future.wait_until_done()
    result = dict(future.get())
    embeddings = np.stack([result[s] for s in sequences])
    if embeddings.ndim > 2:
        # Per-residue embeddings: flatten each sequence into a single row.
        assert embeddings.ndim == 3
        embeddings = embeddings.reshape(len(sequences), -1)
    assert embeddings.ndim == 2

    # compute svd locally
    local_svd = TorchLowRankSVDTransform(
        n_components=n_components, random_state=random_state, device="cpu"
    )
    reduced_embeddings = local_svd.fit_transform(
        torch.from_numpy(embeddings).float()
    ).cpu().numpy()

    # get svd from remote
    svd: SVDModel = model.fit_svd(
        sequences, n_components=n_components, reduction=reduction
    )
    time.sleep(1)
    svd.get_job().wait_until_done(session=session)
    future = svd.embed(sequences)
    time.sleep(1)
    future.wait_until_done()
    result = dict(future.get())
    for s, actual in zip(sequences, reduced_embeddings):
        mean_delta = np.abs(result[s] - actual).mean()
        # Baseline: the same comparison against a shuffled copy of the local
        # result; the ratio says how much better the real match is.
        random_mean_delta = np.abs(
            result[s] - actual[np.random.permutation(len(actual))]
        ).mean()
        ratio = random_mean_delta / mean_delta
        print(
            "svd embed",
            mean_delta,
            random_mean_delta,
            ratio,
        )
        if not should_fail:
            assert ratio > 1e4
        else:
            assert ratio < 1e2

tests/utils/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)