Skip to content

Commit 9d4e2da

Browse files
committed
try to fix zero shot scorers
1 parent 14b205f commit 9d4e2da

File tree

4 files changed

+24
-21
lines changed

4 files changed

+24
-21
lines changed

autointent/modules/scoring/_description/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ class BaseDescriptionScorer(BaseScorer, ABC):
2222
2323
Args:
2424
temperature: Temperature parameter for scaling logits, defaults to 1.0
25+
multilabel: Flag indicating whether the task is multilabel classification, defaults to False
2526
"""
2627

2728
supports_multiclass = True
2829
supports_multilabel = True
2930

30-
def __init__(self, temperature: PositiveFloat = 1.0) -> None:
31+
def __init__(self, temperature: PositiveFloat = 1.0, multilabel: bool = False) -> None:
3132
self.temperature = temperature
33+
self._multilabel = multilabel
3234
self._validate_temperature()
3335

3436
def _validate_temperature(self) -> None:
@@ -82,16 +84,14 @@ def fit(
8284
Raises:
8385
ValueError: If descriptions contain None values
8486
"""
85-
self._validate_task(labels)
8687
self._validate_descriptions(descriptions)
87-
self._fit_implementation(utterances, descriptions)
88+
self._fit_implementation(descriptions)
8889

8990
@abstractmethod
90-
def _fit_implementation(self, utterances: list[str], descriptions: list[str]) -> None:
91+
def _fit_implementation(self, descriptions: list[str]) -> None:
9192
"""Implementation-specific fitting logic.
9293
9394
Args:
94-
utterances: List of utterances to process
9595
descriptions: List of intent descriptions
9696
"""
9797

autointent/modules/scoring/_description/bi_encoder.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class BiEncoderDescriptionScorer(BaseDescriptionScorer):
2626
Args:
2727
embedder_config: Configuration for the embedder model (HuggingFace model name or config)
2828
temperature: Temperature parameter for scaling logits before softmax/sigmoid (default: 1.0)
29+
multilabel: Whether the task is multilabel classification (default: False)
2930
3031
Example:
3132
--------
@@ -60,8 +61,9 @@ def __init__(
6061
self,
6162
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
6263
temperature: PositiveFloat = 1.0,
64+
multilabel: bool = False,
6365
) -> None:
64-
super().__init__(temperature)
66+
super().__init__(temperature=temperature, multilabel=multilabel)
6567
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
6668
self._embedder: Embedder | None = None
6769
self._description_vectors: NDArray[Any] | None = None
@@ -86,16 +88,13 @@ def from_context(
8688
if embedder_config is None:
8789
embedder_config = context.resolve_embedder()
8890

89-
return cls(
90-
temperature=temperature,
91-
embedder_config=embedder_config,
92-
)
91+
return cls(temperature=temperature, embedder_config=embedder_config, multilabel=context.is_multilabel())
9392

9493
def get_implicit_initialization_params(self) -> dict[str, Any]:
9594
"""Get implicit initialization parameters for this scorer."""
9695
return {"embedder_config": self.embedder_config.model_dump()}
9796

98-
def _fit_implementation(self, utterances: list[str], descriptions: list[str]) -> None:
97+
def _fit_implementation(self, descriptions: list[str]) -> None:
9998
"""Fit the bi-encoder by embedding descriptions.
10099
101100
Args:

autointent/modules/scoring/_description/cross_encoder.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class CrossEncoderDescriptionScorer(BaseDescriptionScorer):
2828
Args:
2929
cross_encoder_config: Configuration for the cross-encoder model (HuggingFace model name or config)
3030
temperature: Temperature parameter for scaling logits before softmax/sigmoid (default: 1.0)
31+
multilabel: Whether the task is multilabel classification (default: False)
3132
3233
Example:
3334
--------
@@ -48,8 +49,8 @@ class CrossEncoderDescriptionScorer(BaseDescriptionScorer):
4849
"User asks about weather conditions or forecasts"
4950
]
5051
51-
# Fit using descriptions only (zero-shot approach)
52-
scorer.fit([], [], descriptions)
52+
# Fit using descriptions only (zero-shot approach)
53+
scorer.fit([], [], descriptions)
5354
5455
# Make predictions on new utterances
5556
test_utterances = ["Reserve a hotel room", "Delete my booking"]
@@ -62,8 +63,9 @@ def __init__(
6263
self,
6364
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
6465
temperature: PositiveFloat = 1.0,
66+
multilabel: bool = False,
6567
) -> None:
66-
super().__init__(temperature)
68+
super().__init__(temperature=temperature, multilabel=multilabel)
6769
self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
6870
self._cross_encoder: Ranker | None = None
6971
self._description_texts: list[str] | None = None
@@ -89,15 +91,14 @@ def from_context(
8991
cross_encoder_config = context.resolve_ranker()
9092

9193
return cls(
92-
temperature=temperature,
93-
cross_encoder_config=cross_encoder_config,
94+
temperature=temperature, cross_encoder_config=cross_encoder_config, multilabel=context.is_multilabel()
9495
)
9596

9697
def get_implicit_initialization_params(self) -> dict[str, Any]:
9798
"""Get implicit initialization parameters for this scorer."""
9899
return {"cross_encoder_config": self.cross_encoder_config.model_dump()}
99100

100-
def _fit_implementation(self, utterances: list[str], descriptions: list[str]) -> None:
101+
def _fit_implementation(self, descriptions: list[str]) -> None:
101102
"""Fit the cross-encoder by storing descriptions.
102103
103104
Args:

autointent/modules/scoring/_description/llm_encoder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class LLMDescriptionScorer(BaseDescriptionScorer):
6262
max_concurrent: Maximum number of concurrent async calls to LLM (default: 15)
6363
max_per_second: Maximum number of API calls per second for rate limiting (default: 10)
6464
max_retries: Maximum number of retry attempts for failed API calls (default: 3)
65+
multilabel: Whether the task is multilabel classification (default: False)
6566
6667
Example:
6768
--------
@@ -84,8 +85,8 @@ class LLMDescriptionScorer(BaseDescriptionScorer):
8485
"User asks about weather conditions or forecasts"
8586
]
8687
87-
# Fit using descriptions only (zero-shot approach)
88-
scorer.fit([], [], descriptions)
88+
# Fit using descriptions only (zero-shot approach)
89+
scorer.fit([], [], descriptions)
8990
9091
# Make predictions on new utterances
9192
test_utterances = ["Reserve a hotel room", "Delete my booking"]
@@ -101,8 +102,9 @@ def __init__(
101102
max_concurrent: PositiveInt | None = 15,
102103
max_per_second: PositiveInt = 10,
103104
max_retries: PositiveInt = 3,
105+
multilabel: bool = False,
104106
) -> None:
105-
super().__init__(temperature=temperature)
107+
super().__init__(temperature=temperature, multilabel=multilabel)
106108

107109
self.generator_config = generator_config or {}
108110
self.max_concurrent = max_concurrent
@@ -125,12 +127,13 @@ def from_context(
125127
max_concurrent=max_concurrent,
126128
max_per_second=max_per_second,
127129
max_retries=max_retries,
130+
multilabel=context.is_multilabel(),
128131
)
129132

130133
def get_implicit_initialization_params(self) -> dict[str, Any]:
131134
return {}
132135

133-
def _fit_implementation(self, utterances: list[str], descriptions: list[str]) -> None:
136+
def _fit_implementation(self, descriptions: list[str]) -> None:
134137
"""Fit the LLM scorer by initializing the generator and storing descriptions.
135138
136139
Args:

0 commit comments

Comments
 (0)