"""
Integration tests comparing local model outputs against the OpenProtein backend.

Assumes models are available at data/models.
Change seeds at the appropriate places to avoid backend caching.
Set the --longrun flag when running pytest to run these tests.
"""
import json
import os
import time
from pathlib import Path
from typing import Optional, Union

import numpy as np

import torch

import fsspec
from fsspec.implementations.dirfs import DirFileSystem

import pytest

from esm import ESM2, ProteinBertModel
from esm.pretrained import load_model_and_alphabet_core, _has_regression_weights

from protembed.alphabets import Uniprot21
from protembed.datasets import pad_tensor_1d
from protembed.factory import ProtembedModelLoader

import openprotein
from openprotein import OpenProtein
from openprotein.api.embedding import SVDModel
from tests.utils.svd import TorchLowRankSVDTransform


ALPHABET = Uniprot21()

def load_model_and_alphabet_local(model_location, device):
    """Load an ESM model from a local path. Regression weights need to be co-located."""
    model_location = Path(model_location)
    model_data = torch.load(str(model_location), map_location=device)
    model_name = model_location.stem
    if _has_regression_weights(model_name):
        regression_location = str(model_location.with_suffix("")) + "-contact-regression.pt"
        regression_data = torch.load(regression_location, map_location=device)
    else:
        regression_data = None
    return load_model_and_alphabet_core(model_name, model_data, regression_data)
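
# Hypothetical usage (the path and device are illustrative):
#   model, alphabet = load_model_and_alphabet_local(
#       "data/models/esm2_t6_8M_UR50D.pt", torch.device("cpu")
#   )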


# secrets.config is expected to be a JSON file with "username" and "password" keys.
@pytest.fixture()
def session() -> OpenProtein:
    with open("secrets.config", "r") as f:
        config = json.load(f)
    return openprotein.connect(
        config["username"],
        config["password"],
        backend="https://dev.api.openprotein.ai/api/",
    )

@pytest.fixture()
def loader() -> ProtembedModelLoader:
    root_fs = fsspec.filesystem('file')
    dir_fs = DirFileSystem("data/models", root_fs)
    return ProtembedModelLoader(dir_fs)

@pytest.fixture()
def sequences() -> list[bytes]:
    # change this seed to avoid backend caching (see module docstring)
    rng = np.random.default_rng(188501)
    # random sequences of length 250-499 over the 21-letter alphabet
    return [
        ALPHABET.decode(rng.integers(
            low=0,
            high=21,
            size=rng.integers(250, 500),
        ))
        for _ in range(5)
    ]

@pytest.fixture()
def same_length_sequences() -> list[bytes]:
    # change this seed to avoid backend caching (see module docstring)
    rng = np.random.default_rng(376735)
    return [
        ALPHABET.decode(rng.integers(low=0, high=21, size=331))
        for _ in range(5)
    ]

@pytest.mark.longrun
@pytest.mark.parametrize("local_model_id,model_id", [
    ("prosst", "prot-seq"),
    ("rotaprot-seq-900m-uniref90-v1", "rotaprot-large-uniref90-ft"),
])
@torch.inference_mode()
def test_protembed(
    loader: ProtembedModelLoader,
    local_model_id: str,
    session: OpenProtein,
    model_id: str,
    sequences: list[bytes],
):
    print("testing...", model_id)
    local_model = loader.load_model(local_model_id, device=torch.device("cuda"))
    model = session.embedding.get_model(model_id=model_id)

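    # pad_tensor_1d (assumed behavior): pads the variable-length index tensors to a
    # common length with the mask token and returns the corresponding padding mask.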
    sequences_as_idxs, mask = pad_tensor_1d(
        [torch.from_numpy(ALPHABET.encode(s)).cuda().long() for s in sequences],
        ALPHABET.mask_token,
        return_padding=True,
    )
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        _, attn = local_model.embed(
            sequences_as_idxs, padding_mask=mask, return_attention=True
        )
        embeddings = local_model.embed(sequences_as_idxs, padding_mask=mask)
        logits = local_model.logits(embeddings)
    if isinstance(attn, list):
        # TODO: somewhat hacky; needed because the interfaces of prosst and
        # rotaformer are not the same
        attn = torch.stack(attn, dim=1)
    attn = [
        x.float().cpu().numpy()[-1][:, :len(s), :len(s)]
        for s, x in zip(sequences, attn)
    ]
    embeddings = [
        x.float().cpu().numpy()[:len(s)]
        for s, x in zip(sequences, embeddings)
    ]
    logits = [x.float().cpu().numpy()[:len(s)] for s, x in zip(sequences, logits)]

    # we can't make these difference tests too stringent, probably due to numerical
    # precision issues (fp16 may be particularly problematic)
    future = model.attn(sequences)
    time.sleep(1)
    future.wait_until_done()
    result = {s: x for s, x in future.get()}
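    # Illustrative sanity check: the delta against a randomly permuted copy of the
    # local attention map should be much larger than the delta against the aligned
    # map; the printed ratio gives a rough signal-to-noise measure.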
    for s, actual in zip(sequences, attn):
        mean_delta = np.abs(result[s] - actual).mean()
        random_mean_delta = np.abs(
            result[s] - actual[np.random.permutation(len(actual))]
        ).mean()
        print(
            "attn",
            mean_delta,
            random_mean_delta,
            random_mean_delta / mean_delta,
        )
        assert np.abs(result[s] - actual).mean() < 1e-4

    for reduction in [None, "MEAN", "SUM"]:
        future = model.embed(sequences, reduction=reduction)
        time.sleep(1)
        future.wait_until_done()
        result = {s: x for s, x in future.get()}
        for s, actual in zip(sequences, embeddings):
            if reduction == "MEAN":
                actual = actual.mean(axis=0)
            elif reduction == "SUM":
                # compare means instead of sums to average out per-position errors
                actual = actual.mean(axis=0)
                result[s] = result[s] / len(s)
            mean_delta = np.abs(result[s] - actual).mean()
            print("embed", reduction, mean_delta)
            assert np.abs(result[s] - actual).mean() < 1e-2

    future = model.logits(sequences)
    time.sleep(1)
    future.wait_until_done()
    result = {s: x for s, x in future.get()}
    for s, actual in zip(sequences, logits):
        mean_delta = np.abs(result[s] - actual).mean()
        print("logits", mean_delta)
        assert np.abs(result[s] - actual).mean() < 1e-2


@pytest.mark.longrun
@pytest.mark.parametrize(
    "model_id", ["esm1b_t33_650M_UR50S", "esm1v_t33_650M_UR90S_1", "esm2_t6_8M_UR50D"]
)
@torch.inference_mode()
def test_esm(session: OpenProtein, model_id: str, sequences: list[bytes]):
    print("testing...", model_id)
    device = (
        torch.device("cpu")  # use cpu for the larger models in case of low VRAM
        if model_id != "esm2_t6_8M_UR50D" else torch.device("cuda")
    )
    local_model: Union[ESM2, ProteinBertModel]
    model_dir = "data/models"
    model_pt_path = os.path.join(model_dir, f"{model_id}.pt")
    local_model, alphabet = load_model_and_alphabet_local(
        model_pt_path, device
    )
    batch_converter = alphabet.get_batch_converter()
    local_model = local_model.eval()  # disables dropout for deterministic results
    if isinstance(local_model, ESM2):
        # half precision inference should be safe, per
        # https://github.com/facebookresearch/esm/issues/283#issuecomment-1254283417
        local_model = local_model.half()
    local_model = local_model.to(device)
    can_predict_contacts = _has_regression_weights(model_id)

    _, _, batch_tokens = batch_converter(list(zip(
        [f"{i}" for i in range(len(sequences))],
        # replace "X" with the mask token (presumably to match how the backend
        # encodes unknown residues)
        [s.decode().replace("X", "<mask>") for s in sequences],
    )))
    results = local_model(
        batch_tokens.to(device),
        repr_layers=[local_model.num_layers],
        need_head_weights=True,
        return_contacts=can_predict_contacts,
    )

    embeddings = results["representations"][local_model.num_layers].float()
    attn = results["attentions"].float()
    logits = results["logits"].float()
    if can_predict_contacts:
        contacts = results["contacts"].float()
    else:
        contacts = None

    # batch_lens include the BOS and EOS tokens added by the batch converter
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    embeddings = [
        embeddings[i, :tokens_len]
        for i, tokens_len in enumerate(batch_lens)
    ]
    # strip BOS/EOS before reducing
    mean_embeddings = torch.vstack([e[1:-1].mean(dim=0) for e in embeddings])
    sum_embeddings = torch.vstack([e[1:-1].sum(dim=0) for e in embeddings])
    attn = [
        attn[i, -1, :, :tokens_len, :tokens_len]
        for i, tokens_len in enumerate(batch_lens)
    ]
    logits = [
        logits[i, :tokens_len]
        for i, tokens_len in enumerate(batch_lens)
    ]
    if contacts is not None:
        contacts = [
            contacts[i, :tokens_len - 2, :tokens_len - 2]
            for i, tokens_len in enumerate(batch_lens)
        ]

    embeddings = [x.float().cpu().numpy() for x in embeddings]
    mean_embeddings = [x.float().cpu().numpy() for x in mean_embeddings]
    sum_embeddings = [x.float().cpu().numpy() for x in sum_embeddings]
    attn = [x.float().cpu().numpy() for x in attn]
    logits = [x.float().cpu().numpy() for x in logits]
    contacts = (
        [x.float().cpu().numpy() for x in contacts]
        if contacts is not None else None
    )
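    # NOTE: contacts are computed locally when regression weights are available,
    # but this test does not currently compare them against the backend.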

    model = session.embedding.get_model(model_id=model_id)
    future = model.attn(sequences)
    time.sleep(1)
    future.wait_until_done()
    result = {s: x for s, x in future.get()}
    for s, actual in zip(sequences, attn):
        mean_delta = np.abs(result[s] - actual).mean()
        random_mean_delta = np.abs(
            result[s] - actual[np.random.permutation(len(actual))]
        ).mean()
        print(
            "attn",
            mean_delta,
            random_mean_delta,
            random_mean_delta / mean_delta,
        )
        assert np.abs(result[s] - actual).mean() < 1e-4

    for reduction in [None, "MEAN", "SUM"]:
        future = model.embed(sequences, reduction=reduction)
        time.sleep(1)
        future.wait_until_done()
        result = {s: x for s, x in future.get()}
        for i, s in enumerate(sequences):
            if reduction is None:
                actual = embeddings[i]
            elif reduction == "MEAN":
                actual = mean_embeddings[i]
            elif reduction == "SUM":
                # compare means instead of sums to average out per-position errors
                actual = sum_embeddings[i] / len(s)
                result[s] = result[s] / len(s)
            mean_delta = np.abs(result[s] - actual).mean()
            print("embed", reduction, mean_delta)
            assert np.abs(result[s] - actual).mean() < 1e-2

    future = model.logits(sequences)
    time.sleep(1)
    future.wait_until_done()
    result = {s: x for s, x in future.get()}
    for s, actual in zip(sequences, logits):
        mean_delta = np.abs(result[s] - actual).mean()
        print("logits", mean_delta)
        assert np.abs(result[s] - actual).mean() < 1e-2


@pytest.mark.parametrize("reduction", [None, "MEAN", "SUM"])
@pytest.mark.parametrize("random_state,should_fail", [(47, False), (100, True)])
def test_svd(
    session: OpenProtein,
    same_length_sequences: list[bytes],
    reduction: Optional[str],
    random_state: int,
    should_fail: bool,
):
    print("testing svd...", reduction, random_state, should_fail)
    sequences = same_length_sequences
    # this is an extremely strong test: it depends on the SVD random_state being
    # the same locally and on the backend
    model_id = "prot-seq"
    n_components = 1024
    model = session.embedding.get_model(model_id=model_id)

    # get embeddings to run the SVD on
    future = model.embed(sequences, reduction=reduction)
    time.sleep(1)
    future.wait_until_done()
    result = {s: x for s, x in future.get()}
    embeddings = np.stack([result[s] for s in sequences])
    if embeddings.ndim > 2:
        # per-residue embeddings (reduction=None): flatten each sequence's embedding,
        # which is only valid because all sequences have the same length
        assert embeddings.ndim == 3
        embeddings = embeddings.reshape(len(sequences), -1)
    assert embeddings.ndim == 2

    # compute the SVD locally
    local_svd = TorchLowRankSVDTransform(
        n_components=n_components, random_state=random_state, device="cpu"
    )
    reduced_embeddings = local_svd.fit_transform(
        torch.from_numpy(embeddings).float()
    ).cpu().numpy()
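    # TorchLowRankSVDTransform is assumed to be deterministic for a fixed
    # random_state; the remote SVD fit is expected to use the same seed, which is
    # why the mismatched-seed case (should_fail=True) is expected to diverge.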

    # get svd from remote
    svd: SVDModel = model.fit_svd(sequences, n_components=n_components, reduction=reduction)
    time.sleep(1)
    svd.get_job().wait_until_done(session=session)
    future = svd.embed(sequences)
    time.sleep(1)
    future.wait_until_done()
    result = {s: x for s, x in future.get()}
    for s, actual in zip(sequences, reduced_embeddings):
        mean_delta = np.abs(result[s] - actual).mean()
        random_mean_delta = np.abs(
            result[s] - actual[np.random.permutation(len(actual))]
        ).mean()
        print(
            "svd embed",
            mean_delta,
            random_mean_delta,
            random_mean_delta / mean_delta,
        )
        if not should_fail:
            assert random_mean_delta / mean_delta > 1e4
        else:
            assert random_mean_delta / mean_delta < 1e2