diff --git a/autointent/modules/embedding/_logreg.py b/autointent/modules/embedding/_logreg.py index d729b8741..02d3fb0da 100644 --- a/autointent/modules/embedding/_logreg.py +++ b/autointent/modules/embedding/_logreg.py @@ -77,8 +77,8 @@ def __init__( def from_context( cls, context: Context, - cv: int, embedder_name: str, + cv: int = 3, ) -> "LogregAimedEmbedding": """ Create a LogregAimedEmbedding instance using a Context object. @@ -89,8 +89,8 @@ def from_context( :return: Initialized LogregAimedEmbedding instance. """ return cls( - cv=cv, embedder_name=embedder_name, + cv=cv, embedder_device=context.get_device(), embedder_batch_size=context.get_batch_size(), embedder_max_length=context.get_max_length(),