
Commit ff7077e

Merge branch 'dev' into cnn
2 parents 8e34160 + a04587b commit ff7077e


74 files changed: +1398 -943 lines changed

.github/workflows/test-nodes.yaml

Lines changed: 0 additions & 13 deletions
This file was deleted.

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -179,4 +179,6 @@ tests_logs
 tests/logs
 runs/
 vector_db*
+*.db
+*.sqlite
 /wandb

autointent/_callbacks/wandb.py

Lines changed: 23 additions & 7 deletions
@@ -1,9 +1,12 @@
+import logging
 import os
 from pathlib import Path
 from typing import Any

 from autointent._callbacks.base import OptimizerCallback

+logger = logging.getLogger(__name__)
+

 class WandbCallback(OptimizerCallback):
     """Wandb callback for logging the optimization process to Weights & Biases (W&B).
@@ -94,13 +97,26 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         Args:
             metrics: A dictionary of final performance metrics.
         """
-        self.wandb.init(
-            project=self.project_name,
-            group=self.group,
-            name="final_metrics",
-            config=metrics,
-            settings=self.wandb.Settings(x_stats_sampling_interval=self.log_interval_time),
-        )
+        wandb_run_init_args = {
+            "project": self.project_name,
+            "group": self.group,
+            "name": "final_metrics",
+            "settings": self.wandb.Settings(x_stats_sampling_interval=self.log_interval_time),
+        }
+
+        try:
+            self.wandb.init(config=metrics, **wandb_run_init_args)
+        except Exception as e:
+            if "run config cannot exceed" not in str(e):
+                # https://github.com/deeppavlov/AutoIntent/issues/202
+                raise
+            logger.warning("W&B run config is too large, skipping logging modules configs")
+            logger.warning("'final_metrics' will be logged to W&B with pipeline_metrics only")
+            logger.warning("If you want to access modules configs in future, address to the individual modules runs")
+            self.wandb.init(
+                config={},
+                **wandb_run_init_args,
+            )

         self.wandb.log(metrics.get("pipeline_metrics", {}))
         self.wandb.finish()
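
The change above replaces a single wandb.init call with a guarded retry when the run config is too large. A minimal standalone sketch of the same pattern, assuming a wandb-like client object; the helper name safe_init is illustrative and not part of AutoIntent:

    import logging

    logger = logging.getLogger(__name__)


    def safe_init(wandb_module, config, **run_args):
        """Start a run with the full config, falling back to an empty config
        if the backend rejects it for being too large."""
        try:
            return wandb_module.init(config=config, **run_args)
        except Exception as e:
            if "run config cannot exceed" not in str(e):
                raise
            logger.warning("Run config too large; starting run without module configs")
            return wandb_module.init(config={}, **run_args)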

autointent/_dump_tools.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
-from torch import nn
+
 from transformers import ( # type: ignore[attr-defined]
     AutoModelForSequenceClassification,
     AutoTokenizer,

autointent/_embedder.py

Lines changed: 55 additions & 25 deletions
@@ -7,20 +7,25 @@
 import json
 import logging
 import shutil
+from functools import lru_cache
 from pathlib import Path
 from typing import TypedDict

+import huggingface_hub
 import numpy as np
 import numpy.typing as npt
 import torch
 from appdirs import user_cache_dir
 from sentence_transformers import SentenceTransformer
+from sentence_transformers.similarity_functions import SimilarityFunction

 from ._hash import Hasher
 from .configs import EmbedderConfig, TaskTypeEnum

+logger = logging.getLogger(__name__)

-def get_embeddings_path(filename: str) -> Path:
+
+def _get_embeddings_path(filename: str) -> Path:
     """Get the path to the embeddings file.

     This function constructs the full path to an embeddings file stored
@@ -37,6 +42,23 @@ def get_embeddings_path(filename: str) -> Path:
     return Path(user_cache_dir("autointent")) / "embeddings" / f"{filename}.npy"


+@lru_cache(maxsize=128)
+def _get_latest_commit_hash(model_name: str) -> str:
+    """Get the latest commit hash for a given Hugging Face model.
+
+    Args:
+        model_name: The name of the model to get the latest commit hash for.
+
+    Returns:
+        The latest commit hash for the given model name or the model name if the commit hash is not found.
+    """
+    commit_hash = huggingface_hub.model_info(model_name, revision="main").sha
+    if commit_hash is None:
+        logger.warning("No commit hash found for model %s", model_name)
+        return model_name
+    return commit_hash
+
+
 class EmbedderDumpMetadata(TypedDict):
     """Metadata for saving and loading an Embedder instance."""

@@ -63,7 +85,6 @@ class Embedder:

     _metadata_dict_name: str = "metadata.json"
     _dump_dir: Path | None = None
-    config: EmbedderConfig
     embedding_model: SentenceTransformer

     def __init__(self, embedder_config: EmbedderConfig) -> None:
@@ -74,34 +95,41 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
         """
         self.config = embedder_config

-        self.embedding_model = SentenceTransformer(
-            self.config.model_name,
-            device=self.config.device,
-            prompts=embedder_config.get_prompt_config(),
-            similarity_fn_name=self.config.similarity_fn_name,
-            trust_remote_code=self.config.trust_remote_code,
-        )
-
-        self._logger = logging.getLogger(__name__)
-
     def __hash__(self) -> int:
         """Compute a hash value for the Embedder.

         Returns:
             The hash value of the Embedder.
         """
         hasher = Hasher()
-        for parameter in self.embedding_model.parameters():
-            hasher.update(parameter.detach().cpu().numpy())
+        if self.config.freeze:
+            commit_hash = _get_latest_commit_hash(self.config.model_name)
+            hasher.update(commit_hash)
+        else:
+            self._load_model()
+            for parameter in self.embedding_model.parameters():
+                hasher.update(parameter.detach().cpu().numpy())
         hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()

+    def _load_model(self) -> None:
+        """Load sentence transformers model to device."""
+        if not hasattr(self, "embedding_model"):
+            self.embedding_model = SentenceTransformer(
+                self.config.model_name,
+                device=self.config.device,
+                prompts=self.config.get_prompt_config(),
+                similarity_fn_name=self.config.similarity_fn_name,
+                trust_remote_code=self.config.trust_remote_code,
+            )
+
     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
-        self._logger.debug("Clearing embedder %s from memory", self.config.model_name)
-        self.embedding_model.cpu()
-        del self.embedding_model
-        torch.cuda.empty_cache()
+        if hasattr(self, "embedding_model"):
+            logger.debug("Clearing embedder %s from memory", self.config.model_name)
+            self.embedding_model.cpu()
+            del self.embedding_model
+            torch.cuda.empty_cache()

     def delete(self) -> None:
         """Delete the embedding model and its associated directory."""
@@ -165,11 +193,13 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         hasher.update(self)
         hasher.update(utterances)

-        embeddings_path = get_embeddings_path(hasher.hexdigest())
+        embeddings_path = _get_embeddings_path(hasher.hexdigest())
         if embeddings_path.exists():
             return np.load(embeddings_path)  # type: ignore[no-any-return]

-        self._logger.debug(
+        self._load_model()
+
+        logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
             self.config.model_name,
             self.config.batch_size,
@@ -200,11 +230,11 @@ def similarity(
         """Calculate similarity between two sets of embeddings.

         Args:
-            embeddings1: First set of embeddings.
-            embeddings2: Second set of embeddings.
+            embeddings1: First set of embeddings (size n).
+            embeddings2: Second set of embeddings (size m).

         Returns:
-            A numpy array of similarities.
+            A numpy array of similarities (size n x m).
         """
-        result = self.embedding_model.similarity(embeddings1, embeddings2)
-        return result.detach().cpu().numpy().astype(np.float32)
+        similarity_fn = SimilarityFunction.to_similarity_fn(self.config.similarity_fn_name)
+        return similarity_fn(embeddings1, embeddings2).detach().cpu().numpy().astype(np.float32)
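
As a reading aid, here is a compact sketch of the two ideas this diff introduces: the SentenceTransformer is only constructed when embeddings are actually needed, and a frozen model is identified by its Hub commit hash rather than by hashing its weights. The class and function names below are illustrative, not AutoIntent's API.

    from functools import lru_cache

    import huggingface_hub
    from sentence_transformers import SentenceTransformer


    @lru_cache(maxsize=128)
    def latest_revision(model_name: str) -> str:
        # Fall back to the model name when the Hub reports no commit hash.
        sha = huggingface_hub.model_info(model_name, revision="main").sha
        return sha or model_name


    class LazyEmbedder:
        def __init__(self, model_name: str, freeze: bool = True) -> None:
            self.model_name = model_name
            self.freeze = freeze

        def _load(self) -> SentenceTransformer:
            # Instantiate the heavy model only on first use.
            if not hasattr(self, "_model"):
                self._model = SentenceTransformer(self.model_name)
            return self._model

        def cache_key(self) -> str:
            # Frozen weights never change, so the pinned Hub revision identifies
            # them without downloading or loading the model.
            if self.freeze:
                return latest_revision(self.model_name)
            # Otherwise the weights may have been fine-tuned, so derive the key
            # from the parameters themselves (simplified here).
            model = self._load()
            return str(hash(tuple(p.detach().cpu().numpy().tobytes() for p in model.parameters())))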

autointent/_optimization_config.py

Lines changed: 3 additions & 1 deletion
@@ -2,7 +2,7 @@

 from pydantic import BaseModel, PositiveInt

-from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, LoggingConfig
+from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, LoggingConfig
 from .custom_types import SamplerType


@@ -25,6 +25,8 @@ class OptimizationConfig(BaseModel):

     cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()

+    transformer_config: HFModelConfig = HFModelConfig()
+
     sampler: SamplerType = "brute"
     """See tutorial on optuna and presets."""
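
For orientation, a tiny illustrative pydantic snippet of the pattern used here: a nested config object with a default instance, so callers only override the transformer settings they care about. The class names are placeholders, not the project's.

    from pydantic import BaseModel


    class HFModelSettings(BaseModel):
        model_name: str = "bert-base-uncased"
        batch_size: int = 32


    class OptimizationSettings(BaseModel):
        transformer_config: HFModelSettings = HFModelSettings()
        sampler: str = "brute"


    # Only the overridden field changes; everything else keeps its default.
    settings = OptimizationSettings(transformer_config=HFModelSettings(batch_size=8))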

autointent/_pipeline/_pipeline.py

Lines changed: 17 additions & 1 deletion
@@ -14,6 +14,7 @@
     CrossEncoderConfig,
     DataConfig,
     EmbedderConfig,
+    HFModelConfig,
     InferenceNodeConfig,
     LoggingConfig,
 )
@@ -67,10 +68,13 @@ def __init__(
             self.embedder_config = EmbedderConfig()
             self.cross_encoder_config = CrossEncoderConfig()
             self.data_config = DataConfig()
+            self.transformer_config = HFModelConfig()
         elif not isinstance(nodes[0], InferenceNode):
             assert_never(nodes)

-    def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig) -> None:
+    def set_config(
+        self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig
+    ) -> None:
         """Set the configuration for the pipeline.

         Args:
@@ -84,6 +88,8 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig
             self.cross_encoder_config = config
         elif isinstance(config, DataConfig):
             self.data_config = config
+        elif isinstance(config, HFModelConfig):
+            self.transformer_config = config
         else:
             assert_never(config)

@@ -133,6 +139,7 @@ def from_optimization_config(cls, config: dict[str, Any] | Path | str | Optimiza
         pipeline.set_config(optimization_config.data_config)
         pipeline.set_config(optimization_config.embedder_config)
         pipeline.set_config(optimization_config.cross_encoder_config)
+        pipeline.set_config(optimization_config.transformer_config)
         return pipeline

     def _fit(self, context: Context, sampler: SamplerType) -> None:
@@ -144,6 +151,14 @@ def _fit(self, context: Context, sampler: SamplerType) -> None:
         """
         self.context = context
         self._logger.info("starting pipeline optimization...")
+
+        if not context.logging_config.dump_modules:
+            self._logger.warning(
+                "Memory storage is not compatible with resuming optimization. "
+                "Modules from previous runs won't be available. "
+                "Set dump_modules=True in LoggingConfig to enable proper resuming."
+            )
+
         self.context.callback_handler.start_run(
             run_name=self.context.logging_config.get_run_name(),
             dirpath=self.context.logging_config.dirpath,
@@ -190,6 +205,7 @@ def fit(
         context.configure_logging(self.logging_config)
         context.configure_transformer(self.embedder_config)
         context.configure_transformer(self.cross_encoder_config)
+        context.configure_transformer(self.transformer_config)

         self.validate_modules(dataset, mode=incompatible_search_space)
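
The set_config extension above follows an exhaustive isinstance dispatch over a closed union, with assert_never catching any member that gains no branch. A self-contained sketch of the pattern with placeholder config types (not the pipeline's real classes):

    from dataclasses import dataclass
    from typing import assert_never  # Python 3.11+; typing_extensions on older versions


    @dataclass
    class LoggingCfg:
        dirpath: str = "runs"


    @dataclass
    class TransformerCfg:
        model_name: str = "bert-base-uncased"


    def set_config(store: dict, config: LoggingCfg | TransformerCfg) -> None:
        if isinstance(config, LoggingCfg):
            store["logging"] = config
        elif isinstance(config, TransformerCfg):
            store["transformer"] = config
        else:
            # Static type checkers flag this branch if the union grows
            # without a matching isinstance check.
            assert_never(config)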

autointent/_ranker.py

Lines changed: 9 additions & 3 deletions
@@ -10,7 +10,7 @@
 import logging
 from pathlib import Path
 from random import shuffle
-from typing import Any, TypedDict
+from typing import Any, Literal, TypedDict

 import joblib
 import numpy as np
@@ -101,12 +101,14 @@ def __init__(
         self,
         cross_encoder_config: CrossEncoderConfig | str | dict[str, Any],
         classifier_head: LogisticRegressionCV | None = None,
+        output_range: Literal["sigmoid", "tanh"] = "sigmoid",
     ) -> None:
         """Initialize the Ranker.

         Args:
             cross_encoder_config: Configuration for the cross-encoder model
             classifier_head: Optional pre-trained classifier head
+            output_range: Range of the output probabilities ([0, 1] for sigmoid, [-1, 1] for tanh)
         """
         self.config = CrossEncoderConfig.from_search_config(cross_encoder_config)
         self.cross_encoder = st.CrossEncoder(
@@ -117,6 +119,7 @@ def __init__(
         )
         self._train_head = False
         self._clf = classifier_head
+        self.output_range = output_range

         if classifier_head is not None or self.config.train_head:
             self._train_head = True
@@ -148,7 +151,7 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr
             self.cross_encoder.predict(
                 pairs,
                 batch_size=self.config.batch_size,
-                activation_fct=nn.Sigmoid(),
+                activation_fct=nn.Sigmoid() if self.output_range == "sigmoid" else nn.Tanh(),
             )
         )

@@ -210,7 +213,10 @@ def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]:
         features = self._get_features_or_predictions(pairs)

         if self._clf is not None:
-            return np.array(self._clf.predict_proba(features)[:, 1])
+            probs = np.array(self._clf.predict_proba(features)[:, 1])
+            if self.output_range == "tanh":
+                probs = probs * 2 - 1
+            return probs
         return features

     def rank(
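
A hedged usage sketch of the new output_range switch, assuming Ranker is importable as below and the default cross-encoder checkpoint is available; treat it as illustrative rather than a tested recipe:

    from autointent._ranker import Ranker

    # Raw cross-encoder scores pass through nn.Tanh() instead of nn.Sigmoid(),
    # so predictions land in [-1, 1].
    ranker = Ranker("cross-encoder/ms-marco-MiniLM-L6-v2", output_range="tanh")

    # When a classifier head is attached, its probabilities in [0, 1] are
    # remapped to [-1, 1] via p * 2 - 1, keeping both paths on the same scale.
    scores = ranker.predict([("how do I reset my password", "password reset help")])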

autointent/configs/_transformers.py

Lines changed: 8 additions & 5 deletions
@@ -61,9 +61,11 @@ class EmbedderConfig(HFModelConfig):
     sts_prompt: str | None = Field(None, description="Prompt for finding most similar sentences.")
     query_prompt: str | None = Field(None, description="Prompt for query.")
     passage_prompt: str | None = Field(None, description="Prompt for passage.")
-    similarity_fn_name: str | None = Field(
-        "cosine", description="Name of the similarity function to use (cosine, dot, euclidean, manhattan)."
+    similarity_fn_name: Literal["cosine", "dot", "euclidean", "manhattan"] = Field(
+        "cosine", description="Name of the similarity function to use."
     )
+    use_cache: bool = Field(True, description="Whether to use embeddings caching.")
+    freeze: bool = Field(True, description="Whether to freeze the model parameters.")

     def get_prompt_config(self) -> dict[str, str] | None:
         """Get the prompt config for the given prompt type.
@@ -111,11 +113,12 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # no
             return self.default_prompt
         assert_never(prompt_type)

-    use_cache: bool = Field(False, description="Whether to use embeddings caching.")
-

 class CrossEncoderConfig(HFModelConfig):
-    model_name: str = Field("cross-encoder/ms-marco-MiniLM-L-6-v2", description="Name of the hugging face model.")
+    model_name: str = Field("cross-encoder/ms-marco-MiniLM-L6-v2", description="Name of the hugging face model.")
     train_head: bool = Field(
         False, description="Whether to train the head of the model. If False, LogReg will be trained."
     )
+    tokenizer_config: TokenizerConfig = Field(
+        default_factory=lambda: TokenizerConfig(max_length=512)
+    )  # this is because sentence-transformers doesn't allow you to customize tokenizer settings properly
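
To illustrate what the Literal annotation above buys over str | None, a small self-contained pydantic example (EmbedderSettings is a placeholder name): invalid similarity function names are now rejected at validation time instead of surfacing later inside sentence-transformers.

    from typing import Literal

    from pydantic import BaseModel, Field, ValidationError


    class EmbedderSettings(BaseModel):
        similarity_fn_name: Literal["cosine", "dot", "euclidean", "manhattan"] = Field(
            "cosine", description="Name of the similarity function to use."
        )


    EmbedderSettings(similarity_fn_name="dot")  # accepted
    try:
        EmbedderSettings(similarity_fn_name="chebyshev")  # rejected with a clear error
    except ValidationError as error:
        print(error)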
