Skip to content

Commit 138f827

Browse files
committed
Release v0.8.11
1 parent 29d1538 commit 138f827

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

openprotein/embeddings/poet2.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ def fit_svd(
374374
prompt: str | Prompt | None = None,
375375
query: str | bytes | Protein | Query | None = None,
376376
use_query_structure_in_decoder: bool = True,
377+
decoder_type: Literal["mlm", "clm"] | None = None,
378+
**kwargs,
377379
) -> "SVDModel":
378380
"""
379381
Fit an SVD on the embedding results of PoET.
@@ -397,6 +399,10 @@ def fit_svd(
397399
Query to use with prompt.
398400
use_query_structure_in_decoder : bool, optional
399401
Whether to use query structure in decoder. Default is True.
402+
decoder_type : {'mlm', 'clm'} or None, optional
403+
Decoder type. Default is None.
404+
**kwargs
405+
Additional keyword arguments for the model.
400406
401407
Returns
402408
-------
@@ -412,17 +418,21 @@ def fit_svd(
412418
prompt=prompt,
413419
query_id=query_id,
414420
use_query_structure_in_decoder=use_query_structure_in_decoder,
421+
decoder_type=decoder_type,
422+
**kwargs,
415423
)
416424

417425
def fit_umap(
418426
self,
419427
sequences: list[bytes] | list[str] | None = None,
420428
assay: AssayDataset | None = None,
421429
n_components: int = 2,
422-
reduction: ReductionType | None = ReductionType.MEAN,
430+
reduction: ReductionType = ReductionType.MEAN,
423431
prompt: str | Prompt | None = None,
424432
query: str | bytes | Protein | Query | None = None,
425433
use_query_structure_in_decoder: bool = True,
434+
decoder_type: Literal["mlm", "clm"] | None = None,
435+
**kwargs,
426436
) -> "UMAPModel":
427437
"""
428438
Fit a UMAP on assay using PoET and hyperparameters.
@@ -446,6 +456,10 @@ def fit_umap(
446456
Query to use with prompt.
447457
use_query_structure_in_decoder : bool, optional
448458
Whether to use query structure in decoder. Default is True.
459+
decoder_type : {'mlm', 'clm'} or None, optional
460+
Decoder type. Default is None.
461+
**kwargs
462+
Additional keyword arguments for the model.
449463
450464
Returns
451465
-------
@@ -461,6 +475,8 @@ def fit_umap(
461475
prompt=prompt,
462476
query_id=query_id,
463477
use_query_structure_in_decoder=use_query_structure_in_decoder,
478+
decoder_type=decoder_type,
479+
**kwargs,
464480
)
465481

466482
def fit_gp(
@@ -470,6 +486,7 @@ def fit_gp(
470486
prompt: str | Prompt | None = None,
471487
query: str | bytes | Protein | Query | None = None,
472488
use_query_structure_in_decoder: bool = True,
489+
decoder_type: Literal["mlm", "clm"] | None = None,
473490
**kwargs,
474491
) -> "PredictorModel":
475492
"""
@@ -487,8 +504,10 @@ def fit_gp(
487504
Query to use with prompt.
488505
use_query_structure_in_decoder : bool, optional
489506
Whether to use query structure in decoder. Default is True.
507+
decoder_type : {'mlm', 'clm'} or None, optional
508+
Decoder type. Default is None.
490509
**kwargs
491-
Additional keyword arguments.
510+
Additional keyword arguments for the model.
492511
493512
Returns
494513
-------
@@ -502,5 +521,6 @@ def fit_gp(
502521
prompt=prompt,
503522
query_id=query_id,
504523
use_query_structure_in_decoder=use_query_structure_in_decoder,
524+
decoder_type=decoder_type,
505525
**kwargs,
506526
)

0 commit comments

Comments
 (0)