Skip to content

Commit 001e761

Browse files
committed
fix: fixed docstring
1 parent ee343e8 commit 001e761

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

autointent/modules/embedding/_retrieval.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class LogRegEmbedding(EmbeddingModule):
3636
r"""
3737
Module for managing classification operations using logistic regression.
3838
39-
LogRegEmbedding provides methods for indexing, training, and predicting based on embeddings
39+
LogRegEmbedding provides methods for indexing, and training based on embeddings
4040
for classification tasks.
4141
4242
:ivar classifier: The trained logistic regression model.
@@ -81,12 +81,12 @@ def __init__(
8181
self,
8282
k: int,
8383
embedder_name: str,
84+
cv: int = 3,
8485
db_dir: str | None = None,
8586
embedder_device: str = "cpu",
8687
batch_size: int = 32,
8788
max_length: int | None = None,
8889
embedder_use_cache: bool = False,
89-
**kwargs,
9090
) -> None:
9191
"""
9292
Initialize the RetrievalEmbedding.
@@ -104,7 +104,7 @@ def __init__(
104104
self.batch_size = batch_size
105105
self.max_length = max_length
106106
self.embedder_use_cache = embedder_use_cache
107-
self.classifier = LogisticRegressionCV(**kwargs)
107+
self.classifier = LogisticRegressionCV(cv=cv)
108108
self.label_encoder = LabelEncoder()
109109

110110
super().__init__(k=k)
@@ -114,8 +114,8 @@ def from_context(
114114
cls,
115115
context: Context,
116116
k: int,
117+
cv: int,
117118
embedder_name: str,
118-
**kwargs,
119119
) -> "LogRegEmbedding":
120120
"""
121121
Create a LogRegEmbedding instance using a Context object.
@@ -126,13 +126,13 @@ def from_context(
126126
"""
127127
return cls(
128128
k=k,
129+
cv=cv,
129130
embedder_name=embedder_name,
130131
db_dir=str(context.get_db_dir()),
131132
embedder_device=context.get_device(),
132133
batch_size=context.get_batch_size(),
133134
max_length=context.get_max_length(),
134135
embedder_use_cache=context.get_use_cache(),
135-
**kwargs,
136136
)
137137

138138
@property

0 commit comments

Comments
 (0)