@@ -36,7 +36,7 @@ class LogRegEmbedding(EmbeddingModule):
3636 r"""
3737 Module for managing classification operations using logistic regression.
3838
39- LogRegEmbedding provides methods for indexing, training, and predicting based on embeddings
39+ LogRegEmbedding provides methods for indexing, and training based on embeddings
4040 for classification tasks.
4141
4242 :ivar classifier: The trained logistic regression model.
@@ -81,12 +81,12 @@ def __init__(
8181 self ,
8282 k : int ,
8383 embedder_name : str ,
84+ cv : int = 3 ,
8485 db_dir : str | None = None ,
8586 embedder_device : str = "cpu" ,
8687 batch_size : int = 32 ,
8788 max_length : int | None = None ,
8889 embedder_use_cache : bool = False ,
89- ** kwargs ,
9090 ) -> None :
9191 """
9292 Initialize the RetrievalEmbedding.
@@ -104,7 +104,7 @@ def __init__(
104104 self .batch_size = batch_size
105105 self .max_length = max_length
106106 self .embedder_use_cache = embedder_use_cache
107- self .classifier = LogisticRegressionCV (** kwargs )
107+ self .classifier = LogisticRegressionCV (cv = cv )
108108 self .label_encoder = LabelEncoder ()
109109
110110 super ().__init__ (k = k )
@@ -114,8 +114,8 @@ def from_context(
114114 cls ,
115115 context : Context ,
116116 k : int ,
117+ cv : int ,
117118 embedder_name : str ,
118- ** kwargs ,
119119 ) -> "LogRegEmbedding" :
120120 """
121121 Create a LogRegEmbedding instance using a Context object.
@@ -126,13 +126,13 @@ def from_context(
126126 """
127127 return cls (
128128 k = k ,
129+ cv = cv ,
129130 embedder_name = embedder_name ,
130131 db_dir = str (context .get_db_dir ()),
131132 embedder_device = context .get_device (),
132133 batch_size = context .get_batch_size (),
133134 max_length = context .get_max_length (),
134135 embedder_use_cache = context .get_use_cache (),
135- ** kwargs ,
136136 )
137137
138138 @property
0 commit comments