@@ -374,6 +374,8 @@ def fit_svd(
374374 prompt : str | Prompt | None = None ,
375375 query : str | bytes | Protein | Query | None = None ,
376376 use_query_structure_in_decoder : bool = True ,
377+ decoder_type : Literal ["mlm" , "clm" ] | None = None ,
378+ ** kwargs ,
377379 ) -> "SVDModel" :
378380 """
379381 Fit an SVD on the embedding results of PoET.
@@ -397,6 +399,10 @@ def fit_svd(
397399 Query to use with prompt.
398400 use_query_structure_in_decoder : bool, optional
399401 Whether to use query structure in decoder. Default is True.
402+ decoder_type : {'mlm', 'clm'} or None, optional
403+ Decoder type. Default is None.
404+ **kwargs
405+ Additional keyword arguments for the model.
400406
401407 Returns
402408 -------
@@ -412,17 +418,21 @@ def fit_svd(
412418 prompt = prompt ,
413419 query_id = query_id ,
414420 use_query_structure_in_decoder = use_query_structure_in_decoder ,
421+ decoder_type = decoder_type ,
422+ ** kwargs ,
415423 )
416424
417425 def fit_umap (
418426 self ,
419427 sequences : list [bytes ] | list [str ] | None = None ,
420428 assay : AssayDataset | None = None ,
421429 n_components : int = 2 ,
422- reduction : ReductionType | None = ReductionType .MEAN ,
430+ reduction : ReductionType = ReductionType .MEAN ,
423431 prompt : str | Prompt | None = None ,
424432 query : str | bytes | Protein | Query | None = None ,
425433 use_query_structure_in_decoder : bool = True ,
434+ decoder_type : Literal ["mlm" , "clm" ] | None = None ,
435+ ** kwargs ,
426436 ) -> "UMAPModel" :
427437 """
428438 Fit a UMAP on assay using PoET and hyperparameters.
@@ -446,6 +456,10 @@ def fit_umap(
446456 Query to use with prompt.
447457 use_query_structure_in_decoder : bool, optional
448458 Whether to use query structure in decoder. Default is True.
459+ decoder_type : {'mlm', 'clm'} or None, optional
460+ Decoder type. Default is None.
461+ **kwargs
462+ Additional keyword arguments for the model.
449463
450464 Returns
451465 -------
@@ -461,6 +475,8 @@ def fit_umap(
461475 prompt = prompt ,
462476 query_id = query_id ,
463477 use_query_structure_in_decoder = use_query_structure_in_decoder ,
478+ decoder_type = decoder_type ,
479+ ** kwargs ,
464480 )
465481
466482 def fit_gp (
@@ -470,6 +486,7 @@ def fit_gp(
470486 prompt : str | Prompt | None = None ,
471487 query : str | bytes | Protein | Query | None = None ,
472488 use_query_structure_in_decoder : bool = True ,
489+ decoder_type : Literal ["mlm" , "clm" ] | None = None ,
473490 ** kwargs ,
474491 ) -> "PredictorModel" :
475492 """
@@ -487,8 +504,10 @@ def fit_gp(
487504 Query to use with prompt.
488505 use_query_structure_in_decoder : bool, optional
489506 Whether to use query structure in decoder. Default is True.
507+ decoder_type : {'mlm', 'clm'} or None, optional
508+ Decoder type. Default is None.
490509 **kwargs
491- Additional keyword arguments.
510+ Additional keyword arguments for the model .
492511
493512 Returns
494513 -------
@@ -502,5 +521,6 @@ def fit_gp(
502521 prompt = prompt ,
503522 query_id = query_id ,
504523 use_query_structure_in_decoder = use_query_structure_in_decoder ,
524+ decoder_type = decoder_type ,
505525 ** kwargs ,
506526 )
0 commit comments