|
11 | 11 | ) |
12 | 12 | from openprotein.futures import FutureBase, FutureFactory |
13 | 13 |
|
14 | | -from pydantic import BaseModel, parse_obj_as |
| 14 | +from openprotein.pydantic import BaseModel, parse_obj_as |
15 | 15 | import numpy as np |
16 | 16 | from typing import Optional, List, Union, Any |
17 | 17 | import io |
@@ -247,7 +247,7 @@ def __init__( |
247 | 247 |
|
248 | 248 | def get(self, verbose=False) -> List: |
249 | 249 | return super().get(verbose=verbose) |
250 | | - |
| 250 | + |
251 | 251 | @property |
252 | 252 | def sequences(self): |
253 | 253 | if self._sequences is None: |
@@ -305,9 +305,7 @@ def embedding_model_post( |
305 | 305 | """ |
306 | 306 | endpoint = PATH_PREFIX + f"/models/{model_id}/embed" |
307 | 307 |
|
308 | | - sequences_unicode = [ |
309 | | - (s if isinstance(s, str) else s.decode()) for s in sequences |
310 | | - ] |
| 308 | + sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences] |
311 | 309 | body = { |
312 | 310 | "sequences": sequences_unicode, |
313 | 311 | } |
@@ -345,9 +343,7 @@ def embedding_model_logits_post( |
345 | 343 | """ |
346 | 344 | endpoint = PATH_PREFIX + f"/models/{model_id}/logits" |
347 | 345 |
|
348 | | - sequences_unicode = [ |
349 | | - (s if isinstance(s, str) else s.decode()) for s in sequences |
350 | | - ] |
| 346 | + sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences] |
351 | 347 | body = { |
352 | 348 | "sequences": sequences_unicode, |
353 | 349 | } |
@@ -385,9 +381,7 @@ def embedding_model_attn_post( |
385 | 381 | """ |
386 | 382 | endpoint = PATH_PREFIX + f"/models/{model_id}/attn" |
387 | 383 |
|
388 | | - sequences_unicode = [ |
389 | | - (s if isinstance(s, str) else s.decode()) for s in sequences |
390 | | - ] |
| 384 | + sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences] |
391 | 385 | body = { |
392 | 386 | "sequences": sequences_unicode, |
393 | 387 | } |
@@ -500,9 +494,7 @@ def svd_embed_post(session: APISession, svd_id: str, sequences: List[bytes]) -> |
500 | 494 | """ |
501 | 495 | endpoint = PATH_PREFIX + f"/svd/{svd_id}/embed" |
502 | 496 |
|
503 | | - sequences_unicode = [ |
504 | | - (s if isinstance(s, str) else s.decode()) for s in sequences |
505 | | - ] |
| 497 | + sequences_unicode = [(s if isinstance(s, str) else s.decode()) for s in sequences] |
506 | 498 | body = { |
507 | 499 | "sequences": sequences_unicode, |
508 | 500 | } |
@@ -715,7 +707,7 @@ def get_job(self) -> Job: |
715 | 707 | """Get job associated with this SVD model""" |
716 | 708 | return job_get(self.session, self.id) |
717 | 709 |
|
718 | | - def get(self): |
| 710 | + def get(self, verbose: bool = False): |
719 | 711 | # overload for AsyncJobFuture |
720 | 712 | return self |
721 | 713 |
|
@@ -963,7 +955,7 @@ def fit_svd( |
963 | 955 | sequences: List[bytes], |
964 | 956 | n_components: int = 1024, |
965 | 957 | reduction: Optional[str] = None, |
966 | | - ) -> SVDModel: |
| 958 | + ) -> SVDModel: # type: ignore |
967 | 959 | """ |
968 | 960 | Fit an SVD on the embedding results of this model. |
969 | 961 |
|
|
0 commit comments