@@ -90,7 +90,7 @@ def __init__(
9090 k : int ,
9191 weights : WEIGHT_TYPES = "distance" ,
9292 db_dir : str | None = None ,
93- device : str = "cpu" ,
93+ embedder_device : str = "cpu" ,
9494 batch_size : int = 32 ,
9595 max_length : int | None = None ,
9696 embedder_use_cache : bool = False ,
@@ -105,7 +105,7 @@ def __init__(
105105 - "distance" (or True): Weight inversely proportional to distance.
106106 - "closest": Only the closest neighbor of each class is weighted.
107107 :param db_dir: Path to the database directory, or None to use default.
108- :param device : Device to run operations on, e.g., "cpu" or "cuda".
108+ :param embedder_device : Device to run operations on, e.g., "cpu" or "cuda".
109109 :param batch_size: Batch size for embedding generation, defaults to 32.
110110 :param max_length: Maximum sequence length for embedding, or None for default.
111111 :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
@@ -114,7 +114,7 @@ def __init__(
114114 self .k = k
115115 self .weights = weights
116116 self ._db_dir = db_dir
117- self .device = device
117+ self .embedder_device = embedder_device
118118 self .batch_size = batch_size
119119 self .max_length = max_length
120120 self .embedder_use_cache = embedder_use_cache
@@ -158,7 +158,7 @@ def from_context(
158158 k = k ,
159159 weights = weights ,
160160 db_dir = str (context .get_db_dir ()),
161- device = context .get_device (),
161+ embedder_device = context .get_device (),
162162 batch_size = context .get_batch_size (),
163163 max_length = context .get_max_length (),
164164 embedder_use_cache = context .get_use_cache (),
@@ -188,7 +188,9 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
188188 else :
189189 self .n_classes = len (set (labels ))
190190 self .multilabel = False
191- vector_index_client = VectorIndexClient (self .device , self .db_dir , embedder_use_cache = self .embedder_use_cache )
191+ vector_index_client = VectorIndexClient (
192+ self .embedder_device , self .db_dir , embedder_use_cache = self .embedder_use_cache
193+ )
192194
193195 if self .prebuilt_index :
194196 # this happens only after RetrievalNode optimization
@@ -265,7 +267,7 @@ def _restore_state_from_metadata(self, metadata: KNNScorerDumpMetadata) -> None:
265267 self .multilabel = metadata ["multilabel" ]
266268
267269 vector_index_client = VectorIndexClient (
268- device = self .device ,
270+ embedder_device = self .embedder_device ,
269271 db_dir = metadata ["db_dir" ],
270272 embedder_batch_size = metadata ["batch_size" ],
271273 embedder_max_length = metadata ["max_length" ],
0 commit comments