11"""DescriptionScorer class for scoring utterances based on intent descriptions."""
22
3- from typing import Any
3+ from typing import Any , Literal
44
55import numpy as np
66import scipy
77from numpy .typing import NDArray
88from pydantic import PositiveFloat
9- from sklearn .metrics .pairwise import cosine_similarity
109
11- from autointent import Context , Embedder
12- from autointent .configs import EmbedderConfig , TaskTypeEnum
10+ from autointent import Context , Embedder , Ranker
11+ from autointent .configs import CrossEncoderConfig , EmbedderConfig , TaskTypeEnum
1312from autointent .context .optimization_info import ScorerArtifact
1413from autointent .custom_types import ListOfLabels
1514from autointent .metrics import SCORING_METRICS_MULTICLASS , SCORING_METRICS_MULTILABEL
@@ -19,29 +18,38 @@
 class DescriptionScorer(BaseScorer):
     """Scoring module that scores utterances based on similarity to intent descriptions.
 
-    DescriptionScorer embeds both the utterances and the intent descriptions, then computes
-    a similarity score between the two, using either cosine similarity and softmax.
+    DescriptionScorer can use either a bi-encoder or cross-encoder architecture:
+    - Bi-encoder: Embeds both utterances and descriptions separately, then computes cosine similarity
+    - Cross-encoder: Directly computes similarity between each utterance-description pair
 
     Args:
-        embedder_config: Config of the embedder model
+        embedder_config: Config of the embedder model (for bi-encoder mode)
+        cross_encoder_config: Config of the cross-encoder model (for cross-encoder mode)
+        encoder_type: Type of encoder to use, either "bi" or "cross"
         temperature: Temperature parameter for scaling logits, defaults to 1.0
     """
 
-    _embedder: Embedder
+    _embedder: Embedder | None = None
+    _cross_encoder: Ranker | None = None
     name = "description"
     _n_classes: int
     _multilabel: bool
-    _description_vectors: NDArray[Any]
+    _description_vectors: NDArray[Any] | None = None
+    _description_texts: list[str] | None = None
     supports_multiclass = True
     supports_multilabel = True
 
     def __init__(
         self,
         embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
+        encoder_type: Literal["bi", "cross"] = "bi",
         temperature: PositiveFloat = 1.0,
     ) -> None:
         self.temperature = temperature
         self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
+        self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
+        self._encoder_type = encoder_type
 
         if self.temperature < 0 or not isinstance(self.temperature, float | int):
             msg = "`temperature` argument of `DescriptionScorer` must be a positive float"
@@ -51,35 +59,51 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        temperature: PositiveFloat,
+        temperature: PositiveFloat = 1.0,
         embedder_config: EmbedderConfig | str | None = None,
+        cross_encoder_config: CrossEncoderConfig | str | None = None,
+        encoder_type: Literal["bi", "cross"] = "bi",
     ) -> "DescriptionScorer":
         """Create a DescriptionScorer instance using a Context object.
 
         Args:
             context: Context containing configurations and utilities
             temperature: Temperature parameter for scaling logits
             embedder_config: Config of the embedder model. If None, the best embedder is used
+            cross_encoder_config: Config of the cross-encoder model. If None, the default config is used
+            encoder_type: Type of encoder to use, either "bi" or "cross"
 
         Returns:
             Initialized DescriptionScorer instance
         """
         if embedder_config is None:
             embedder_config = context.resolve_embedder()
+        if cross_encoder_config is None:
+            cross_encoder_config = context.resolve_ranker()
 
         return cls(
             temperature=temperature,
             embedder_config=embedder_config,
+            cross_encoder_config=cross_encoder_config,
+            encoder_type=encoder_type,
         )
 
     def get_embedder_config(self) -> dict[str, Any]:
-        """Get the name of the embedder.
+        """Get the configuration of the embedder.
 
         Returns:
-            Embedder name
+            Embedder configuration
         """
         return self.embedder_config.model_dump()
 
+    def get_cross_encoder_config(self) -> dict[str, Any]:
+        """Get the configuration of the cross-encoder.
+
+        Returns:
+            Cross-encoder configuration
+        """
+        return self.cross_encoder_config.model_dump()
+
     def fit(
         self,
         utterances: list[str],
@@ -96,8 +120,10 @@ def fit(
         Raises:
             ValueError: If descriptions contain None values or embeddings mismatch utterances
         """
-        if hasattr(self, "_embedder"):
+        if hasattr(self, "_embedder") and self._embedder is not None:
             self._embedder.clear_ram()
+        if hasattr(self, "_cross_encoder") and self._cross_encoder is not None:
+            self._cross_encoder.clear_ram()
 
         self._validate_task(labels)
 
@@ -108,10 +134,17 @@ def fit(
             )
             raise ValueError(error_text)
 
-        embedder = Embedder(self.embedder_config)
-
-        self._description_vectors = embedder.embed(descriptions, TaskTypeEnum.sts)
-        self._embedder = embedder
+        if self._encoder_type == "bi":
+            embedder = Embedder(self.embedder_config)
+            self._description_vectors = embedder.embed(descriptions, TaskTypeEnum.sts)
+            self._embedder = embedder
+            self._cross_encoder = None
+            self._description_texts = None
+        else:
+            self._cross_encoder = Ranker(self.cross_encoder_config)
+            self._description_texts = descriptions
+            self._embedder = None
+            self._description_vectors = None
 
     def predict(self, utterances: list[str]) -> NDArray[np.float64]:
         """Predict scores for utterances based on similarity to intent descriptions.
@@ -122,8 +155,32 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]:
         Returns:
             Array of probabilities for each utterance
         """
-        utterance_vectors = self._embedder.embed(utterances, TaskTypeEnum.sts)
-        similarities: NDArray[np.float64] = cosine_similarity(utterance_vectors, self._description_vectors)
+        if self._encoder_type == "bi":
+            if self._description_vectors is None:
+                error_text = "Description vectors are not initialized. Call fit() before predict()."
+                raise RuntimeError(error_text)
+
+            if self._embedder is None:
+                error_text = "Embedder is not initialized. Call fit() before predict()."
+                raise RuntimeError(error_text)
+
+            utterance_vectors = self._embedder.embed(utterances, TaskTypeEnum.sts)
+            similarities: NDArray[np.float64] = np.array(
+                self._embedder.similarity(utterance_vectors, self._description_vectors), dtype=np.float64
+            )
+        else:
+            if self._cross_encoder is None:
+                error_text = "Cross encoder is not initialized. Call fit() before predict()."
+                raise RuntimeError(error_text)
+
+            if self._description_texts is None:
+                error_text = "Description texts are not initialized. Call fit() before predict()."
+                raise RuntimeError(error_text)
+
+            pairs = [(utterance, description) for utterance in utterances for description in self._description_texts]
+
+            scores = self._cross_encoder.predict(pairs)
+            similarities = np.array(scores, dtype=np.float64).reshape(len(utterances), len(self._description_texts))
 
         if self._multilabel:
             probabilities = scipy.special.expit(similarities / self.temperature)
@@ -132,8 +189,11 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]:
         return probabilities  # type: ignore[no-any-return]
 
     def clear_cache(self) -> None:
-        """Clear cached data in memory used by the embedder."""
-        self._embedder.clear_ram()
+        """Clear cached data in memory used by the embedder or cross-encoder."""
+        if self._embedder is not None:
+            self._embedder.clear_ram()
+        if self._cross_encoder is not None:
+            self._cross_encoder.clear_ram()
 
     def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels, list[str]]:
         """Get training data from context.
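The temperature divisor rescales similarities before they are squashed into probabilities: the multilabel branch shown above applies `scipy.special.expit(similarities / self.temperature)`, and the multiclass branch (outside the shown hunks) presumably applies a softmax over classes. A small sketch with made-up similarity scores showing how `temperature` changes the sharpness of the output:

```python
import numpy as np
import scipy.special

# Made-up similarities for one utterance against three intent descriptions.
similarities = np.array([[0.62, 0.55, 0.10]])

for temperature in (1.0, 0.1):
    # Multilabel branch from the diff: an independent sigmoid per class.
    multilabel_probs = scipy.special.expit(similarities / temperature)
    # Assumed multiclass branch: softmax over classes.
    multiclass_probs = scipy.special.softmax(similarities / temperature, axis=1)
    print(f"T={temperature}: expit={multilabel_probs.round(3)} softmax={multiclass_probs.round(3)}")
```

Lower temperatures sharpen the distribution toward the best-matching description; higher temperatures flatten it.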