88from typing_extensions import assert_never
99
1010from autointent import Context , VectorIndex
11- from autointent .configs import EmbedderConfig
11+ from autointent .configs import EmbedderConfig , VectorIndexConfig , get_default_vector_index_config
1212from autointent .custom_types import ListOfLabels
1313from autointent .modules .base import BaseScorer
1414
@@ -67,11 +67,13 @@ def __init__(
6767 embedder_config : EmbedderConfig | str | dict [str , Any ] | None = None ,
6868 s : float = 1.0 ,
6969 ignore_first_neighbours : int = 0 ,
70+ vector_index_config : VectorIndexConfig | None = None ,
7071 ) -> None :
7172 self .k = k
7273 self .embedder_config = EmbedderConfig .from_search_config (embedder_config )
7374 self .s = s
7475 self .ignore_first_neighbours = ignore_first_neighbours
76+ self .vector_index_config = vector_index_config or get_default_vector_index_config ()
7577
7678 if self .k < 0 or not isinstance (self .k , int ):
7779 msg = "`k` argument of `MLKnnScorer` must be a positive int"
@@ -109,6 +111,7 @@ def from_context(
109111 embedder_config = embedder_config ,
110112 s = s ,
111113 ignore_first_neighbours = ignore_first_neighbours ,
114+ vector_index_config = context .vector_index_config ,
112115 )
113116
114117 def get_implicit_initialization_params (self ) -> dict [str , Any ]:
@@ -127,15 +130,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
127130 """
128131 self ._validate_task (labels )
129132
130- self ._vector_index = VectorIndex (
131- EmbedderConfig (
132- model_name = self .embedder_config .model_name ,
133- device = self .embedder_config .device ,
134- batch_size = self .embedder_config .batch_size ,
135- tokenizer_config = self .embedder_config .tokenizer_config ,
136- use_cache = self .embedder_config .use_cache ,
137- ),
138- )
133+ self ._vector_index = VectorIndex (embedder_config = self .embedder_config , config = self .vector_index_config )
139134 self ._vector_index .add (utterances , labels )
140135
141136 self ._features = self ._vector_index .get_all_embeddings ()
0 commit comments