Skip to content

Commit 79ed3e8

Browse files
authored
Refactor/another revision (#73)
* change default `clear_ram` value * fix some references * convert search space guide to ipynb * fix some references * rename tutorials, fix refs, begin with python api tutorial * stage progress on python api tutorial * move device option to embedder config * stage progress on tutorials
1 parent ac5ae75 commit 79ed3e8

File tree

33 files changed

+431
-344
lines changed

33 files changed

+431
-344
lines changed

autointent/_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
158158
return np.load(embeddings_path) # type: ignore[no-any-return]
159159

160160
self.logger.debug(
161-
"Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, device=%s",
161+
"Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
162162
self.model_name,
163163
self.batch_size,
164164
str(self.max_length),

autointent/configs/_optimization_cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class LoggingConfig:
4343
"""Path to the directory where the modules will be dumped. If None, the modules will not be dumped"""
4444
dump_modules: bool = False
4545
"""Whether to dump the modules or not"""
46-
clear_ram: bool = True
46+
clear_ram: bool = False
4747
"""Whether to clear the RAM after dumping the modules"""
4848

4949
def __post_init__(self) -> None:
@@ -77,8 +77,6 @@ class VectorIndexConfig:
7777

7878
db_dir: Path | None = None
7979
"""Path to the directory where the vector index database will be saved. If None, the database will not be saved"""
80-
device: str = "cpu"
81-
"""Device to use for the vector index. Can be 'cpu', 'cuda', 'cuda:0', 'mps', etc."""
8280
save_db: bool = False
8381
"""Whether to save the vector index database or not"""
8482

@@ -109,6 +107,8 @@ class EmbedderConfig:
109107
"""Max length for the embedder. If None, the max length will be taken from model config"""
110108
use_cache: bool = False
111109
"""Flag indicating whether to cache embeddings for reuse, improving performance in repeated operations."""
110+
device: str = "cpu"
111+
"""Device to use for the vector index. Can be 'cpu', 'cuda', 'cuda:0', 'mps', etc."""
112112

113113

114114
@dataclass

autointent/context/_context.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def configure_vector_index(self, config: VectorIndexConfig, embedder_config: Emb
6565
self.embedder_config = embedder_config
6666

6767
self.vector_index_client = VectorIndexClient(
68-
self.vector_index_config.device,
68+
self.embedder_config.device,
6969
self.vector_index_config.db_dir,
7070
self.embedder_config.batch_size,
7171
self.embedder_config.max_length,
@@ -115,7 +115,7 @@ def get_inference_config(self) -> dict[str, Any]:
115115
nodes_configs = [asdict(cfg) for cfg in self.optimization_info.get_inference_nodes_config()]
116116
return {
117117
"metadata": {
118-
"device": self.get_device(),
118+
"embedder_device": self.get_device(),
119119
"multilabel": self.is_multilabel(),
120120
"n_classes": self.get_n_classes(),
121121
"seed": self.seed,
@@ -168,11 +168,11 @@ def get_db_dir(self) -> Path:
168168

169169
def get_device(self) -> str:
170170
"""
171-
Get the device used by the vector index client.
171+
Get the embedder device used by the vector index client.
172172
173173
:return: Device name.
174174
"""
175-
return self.vector_index_client.device
175+
return self.vector_index_client.embedder_device
176176

177177
def get_batch_size(self) -> int:
178178
"""

autointent/context/vector_index_client/_vector_index.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class VectorIndex:
2828
def __init__(
2929
self,
3030
model_name: str,
31-
device: str,
31+
embedder_device: str,
3232
embedder_batch_size: int = 32,
3333
embedder_max_length: int | None = None,
3434
embedder_use_cache: bool = False,
@@ -37,7 +37,7 @@ def __init__(
3737
Initialize the vector index.
3838
3939
:param model_name: Name of the embedding model to use.
40-
:param device: Device for running the embedding model (e.g., "cpu", "cuda").
40+
:param embedder_device: Device for running the embedding model (e.g., "cpu", "cuda").
4141
:param embedder_batch_size: Batch size for the embedder.
4242
:param embedder_max_length: Maximum sequence length for the embedder.
4343
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
@@ -46,11 +46,11 @@ def __init__(
4646
self.embedder = Embedder(
4747
model_name=model_name,
4848
batch_size=embedder_batch_size,
49-
device=device,
49+
device=embedder_device,
5050
max_length=embedder_max_length,
5151
use_cache=embedder_use_cache,
5252
)
53-
self.device = device
53+
self.embedder_device = embedder_device
5454

5555
self.labels: list[LabelType] = [] # (n_samples,) or (n_samples, n_classes)
5656
self.texts: list[str] = []
@@ -200,7 +200,7 @@ def load(self, dir_path: Path) -> None:
200200
"""
201201
self.dump_dir = Path(dir_path)
202202
self.index = faiss.read_index(str(dir_path / "index.faiss"))
203-
self.embedder = Embedder(model_name=dir_path / "embedding_model", device=self.device)
203+
self.embedder = Embedder(model_name=dir_path / "embedding_model", device=self.embedder_device)
204204
with (dir_path / "texts.json").open() as file:
205205
self.texts = json.load(file)
206206
with (dir_path / "labels.json").open() as file:

autointent/context/vector_index_client/_vector_index_client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class VectorIndexClient:
2828

2929
def __init__(
3030
self,
31-
device: str,
31+
embedder_device: str,
3232
db_dir: str | Path | None,
3333
embedder_batch_size: int = 32,
3434
embedder_max_length: int | None = None,
@@ -37,14 +37,14 @@ def __init__(
3737
"""
3838
Initialize the VectorIndexClient.
3939
40-
:param device: Device to run the embedding model on.
40+
:param embedder_device: Device to run the embedding model on.
4141
:param db_dir: Directory for storing vector indexes. Defaults to a cache directory.
4242
:param embedder_batch_size: Batch size for the embedding model.
4343
:param embedder_max_length: Maximum sequence length for the embedding model.
4444
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
4545
"""
4646
self._logger = logging.getLogger(__name__)
47-
self.device = device
47+
self.embedder_device = embedder_device
4848
self.db_dir = get_db_dir(db_dir)
4949
self.embedder_batch_size = embedder_batch_size
5050
self.embedder_max_length = embedder_max_length
@@ -69,7 +69,7 @@ def create_index(
6969

7070
index = VectorIndex(
7171
model_name,
72-
self.device,
72+
self.embedder_device,
7373
self.embedder_batch_size,
7474
self.embedder_max_length,
7575
self.embedder_use_cache,
@@ -176,7 +176,7 @@ def get_index(self, model_name: str) -> VectorIndex:
176176
if dirpath is not None:
177177
index = VectorIndex(
178178
model_name,
179-
self.device,
179+
self.embedder_device,
180180
self.embedder_batch_size,
181181
self.embedder_max_length,
182182
self.embedder_use_cache,

autointent/modules/retrieval/_vectordb.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
k: int,
7070
embedder_name: str,
7171
db_dir: str | None = None,
72-
device: str = "cpu",
72+
embedder_device: str = "cpu",
7373
batch_size: int = 32,
7474
max_length: int | None = None,
7575
embedder_use_cache: bool = False,
@@ -80,13 +80,13 @@ def __init__(
8080
:param k: Number of nearest neighbors to retrieve.
8181
:param embedder_name: Name of the embedder used for creating embeddings.
8282
:param db_dir: Path to the database directory. If None, defaults will be used.
83-
:param device: Device to run operations on, e.g., "cpu" or "cuda".
83+
:param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
8484
:param batch_size: Batch size for embedding generation.
8585
:param max_length: Maximum sequence length for embeddings. None if not set.
8686
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
8787
"""
8888
self.embedder_name = embedder_name
89-
self.device = device
89+
self.embedder_device = embedder_device
9090
self._db_dir = db_dir
9191
self.batch_size = batch_size
9292
self.max_length = max_length
@@ -113,7 +113,7 @@ def from_context(
113113
k=k,
114114
embedder_name=embedder_name,
115115
db_dir=str(context.get_db_dir()),
116-
device=context.get_device(),
116+
embedder_device=context.get_device(),
117117
batch_size=context.get_batch_size(),
118118
max_length=context.get_max_length(),
119119
embedder_use_cache=context.get_use_cache(),
@@ -138,7 +138,7 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
138138
:param labels: List of corresponding labels for the utterances.
139139
"""
140140
vector_index_client = VectorIndexClient(
141-
self.device,
141+
self.embedder_device,
142142
self.db_dir,
143143
embedder_batch_size=self.batch_size,
144144
embedder_max_length=self.max_length,
@@ -212,7 +212,7 @@ def load(self, path: str) -> None:
212212
self.metadata: VectorDBMetadata = json.load(file)
213213

214214
vector_index_client = VectorIndexClient(
215-
device=self.device,
215+
embedder_device=self.embedder_device,
216216
db_dir=self.metadata["db_dir"],
217217
embedder_batch_size=self.metadata["batch_size"],
218218
embedder_max_length=self.metadata["max_length"],

autointent/modules/scoring/_description/description.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
self,
5555
embedder_name: str,
5656
temperature: float = 1.0,
57-
device: str = "cpu",
57+
embedder_device: str = "cpu",
5858
batch_size: int = 32,
5959
max_length: int | None = None,
6060
embedder_use_cache: bool = False,
@@ -64,13 +64,13 @@ def __init__(
6464
6565
:param embedder_name: Name of the embedder model.
6666
:param temperature: Temperature parameter for scaling logits, defaults to 1.0.
67-
:param device: Device to run the embedder on, e.g., "cpu" or "cuda".
67+
:param embedder_device: Device to run the embedder on, e.g., "cpu" or "cuda".
6868
:param batch_size: Batch size for embedding generation, defaults to 32.
6969
:param max_length: Maximum sequence length for embedding, defaults to None.
7070
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
7171
"""
7272
self.temperature = temperature
73-
self.device = device
73+
self.embedder_device = embedder_device
7474
self.embedder_name = embedder_name
7575
self.batch_size = batch_size
7676
self.max_length = max_length
@@ -99,7 +99,7 @@ def from_context(
9999

100100
instance = cls(
101101
temperature=temperature,
102-
device=context.get_device(),
102+
embedder_device=context.get_device(),
103103
embedder_name=embedder_name,
104104
embedder_use_cache=context.get_use_cache(),
105105
)
@@ -139,7 +139,7 @@ def fit(
139139
if self.precomputed_embeddings:
140140
# this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization
141141
vector_index_client = VectorIndexClient(
142-
self.device,
142+
self.embedder_device,
143143
self.db_dir,
144144
self.batch_size,
145145
self.max_length,
@@ -153,7 +153,7 @@ def fit(
153153
embedder = vector_index.embedder
154154
else:
155155
embedder = Embedder(
156-
device=self.device,
156+
device=self.embedder_device,
157157
model_name=self.embedder_name,
158158
batch_size=self.batch_size,
159159
max_length=self.max_length,
@@ -230,7 +230,7 @@ def load(self, path: str) -> None:
230230

231231
embedder_dir = dump_dir / self.embedding_model_subdir
232232
self.embedder = Embedder(
233-
device=self.device,
233+
device=self.embedder_device,
234234
model_name=embedder_dir,
235235
batch_size=self.metadata["batch_size"],
236236
max_length=self.metadata["max_length"],

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def load(self, path: str) -> None:
313313
self.n_classes = self.metadata["n_classes"]
314314

315315
vector_index_client = VectorIndexClient(
316-
device=self.device,
316+
embedder_device=self.device,
317317
db_dir=self.metadata["db_dir"],
318318
embedder_batch_size=self.metadata["batch_size"],
319319
embedder_max_length=self.metadata["max_length"],

autointent/modules/scoring/_knn/knn.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __init__(
9090
k: int,
9191
weights: WEIGHT_TYPES = "distance",
9292
db_dir: str | None = None,
93-
device: str = "cpu",
93+
embedder_device: str = "cpu",
9494
batch_size: int = 32,
9595
max_length: int | None = None,
9696
embedder_use_cache: bool = False,
@@ -105,7 +105,7 @@ def __init__(
105105
- "distance" (or True): Weight inversely proportional to distance.
106106
- "closest": Only the closest neighbor of each class is weighted.
107107
:param db_dir: Path to the database directory, or None to use default.
108-
:param device: Device to run operations on, e.g., "cpu" or "cuda".
108+
:param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
109109
:param batch_size: Batch size for embedding generation, defaults to 32.
110110
:param max_length: Maximum sequence length for embedding, or None for default.
111111
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
@@ -114,7 +114,7 @@ def __init__(
114114
self.k = k
115115
self.weights = weights
116116
self._db_dir = db_dir
117-
self.device = device
117+
self.embedder_device = embedder_device
118118
self.batch_size = batch_size
119119
self.max_length = max_length
120120
self.embedder_use_cache = embedder_use_cache
@@ -158,7 +158,7 @@ def from_context(
158158
k=k,
159159
weights=weights,
160160
db_dir=str(context.get_db_dir()),
161-
device=context.get_device(),
161+
embedder_device=context.get_device(),
162162
batch_size=context.get_batch_size(),
163163
max_length=context.get_max_length(),
164164
embedder_use_cache=context.get_use_cache(),
@@ -188,7 +188,9 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
188188
else:
189189
self.n_classes = len(set(labels))
190190
self.multilabel = False
191-
vector_index_client = VectorIndexClient(self.device, self.db_dir, embedder_use_cache=self.embedder_use_cache)
191+
vector_index_client = VectorIndexClient(
192+
self.embedder_device, self.db_dir, embedder_use_cache=self.embedder_use_cache
193+
)
192194

193195
if self.prebuilt_index:
194196
# this happens only after RetrievalNode optimization
@@ -265,7 +267,7 @@ def _restore_state_from_metadata(self, metadata: KNNScorerDumpMetadata) -> None:
265267
self.multilabel = metadata["multilabel"]
266268

267269
vector_index_client = VectorIndexClient(
268-
device=self.device,
270+
embedder_device=self.embedder_device,
269271
db_dir=metadata["db_dir"],
270272
embedder_batch_size=metadata["batch_size"],
271273
embedder_max_length=metadata["max_length"],

autointent/modules/scoring/_knn/rerank_scorer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
m: int | None = None,
5252
rank_threshold_cutoff: int | None = None,
5353
db_dir: str | None = None,
54-
device: str = "cpu",
54+
embedder_device: str = "cpu",
5555
batch_size: int = 32,
5656
max_length: int | None = None,
5757
) -> None:
@@ -68,7 +68,7 @@ def __init__(
6868
:param m: Number of top-ranked neighbors to consider, or None to use k.
6969
:param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None.
7070
:param db_dir: Path to the database directory, or None to use default.
71-
:param device: Device to run operations on, e.g., "cpu" or "cuda".
71+
:param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
7272
:param batch_size: Batch size for embedding generation, defaults to 32.
7373
:param max_length: Maximum sequence length for embedding, or None for default.
7474
"""
@@ -77,7 +77,7 @@ def __init__(
7777
k=k,
7878
weights=weights,
7979
db_dir=db_dir,
80-
device=device,
80+
embedder_device=embedder_device,
8181
batch_size=batch_size,
8282
max_length=max_length,
8383
)
@@ -123,7 +123,7 @@ def from_context(
123123
m=m,
124124
rank_threshold_cutoff=rank_threshold_cutoff,
125125
db_dir=str(context.get_db_dir()),
126-
device=context.get_device(),
126+
embedder_device=context.get_device(),
127127
batch_size=context.get_batch_size(),
128128
max_length=context.get_max_length(),
129129
)
@@ -138,7 +138,7 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
138138
:param utterances: List of utterances to fit the scorer.
139139
:param labels: List of labels corresponding to the utterances.
140140
"""
141-
self._scorer = CrossEncoder(self.cross_encoder_name, device=self.device, max_length=self.max_length) # type: ignore[arg-type]
141+
self._scorer = CrossEncoder(self.cross_encoder_name, device=self.embedder_device, max_length=self.max_length) # type: ignore[arg-type]
142142

143143
super().fit(utterances, labels)
144144

@@ -179,7 +179,7 @@ def _restore_state_from_metadata(self, metadata: RerankScorerDumpMetadata) -> No
179179
self.m = metadata["m"] if metadata["m"] else self.k
180180
self.cross_encoder_name = metadata["cross_encoder_name"]
181181
self.rank_threshold_cutoff = metadata["rank_threshold_cutoff"]
182-
self._scorer = CrossEncoder(self.cross_encoder_name, device=self.device, max_length=self.max_length) # type: ignore[arg-type]
182+
self._scorer = CrossEncoder(self.cross_encoder_name, device=self.embedder_device, max_length=self.max_length) # type: ignore[arg-type]
183183

184184
def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]:
185185
"""

0 commit comments

Comments
 (0)