Skip to content

Commit fe2b79e

Browse files
add prompt logging (#220)
* add prompt logging
* Update optimizer_config.schema.json
* fix

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent d61ee10 commit fe2b79e

File tree

5 files changed

+16
-14
lines changed

5 files changed

+16
-14
lines changed

autointent/_embedder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,15 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
198198
return np.load(embeddings_path) # type: ignore[no-any-return]
199199

200200
self._load_model()
201+
prompt = self.config.get_prompt(task_type)
201202

202203
logger.debug(
203-
"Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
204+
"Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",
204205
self.config.model_name,
205206
self.config.batch_size,
206207
str(self.config.tokenizer_config.max_length),
207208
self.config.device,
209+
prompt,
208210
)
209211

210212
if self.config.tokenizer_config.max_length is not None:
@@ -215,7 +217,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
215217
convert_to_numpy=True,
216218
batch_size=self.config.batch_size,
217219
normalize_embeddings=True,
218-
prompt=self.config.get_prompt_type(task_type),
220+
prompt=prompt,
219221
)
220222

221223
if self.config.use_cache:

autointent/configs/_transformers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class EmbedderConfig(HFModelConfig):
5959
default_prompt: str | None = Field(
6060
None, description="Default prompt for the model. This is used when no task specific prompt is not provided."
6161
)
62-
classifier_prompt: str | None = Field(None, description="Prompt for classifier.")
62+
classification_prompt: str | None = Field(None, description="Prompt for classifier.")
6363
cluster_prompt: str | None = Field(None, description="Prompt for clustering.")
6464
sts_prompt: str | None = Field(None, description="Prompt for finding most similar sentences.")
6565
query_prompt: str | None = Field(None, description="Prompt for query.")
@@ -79,8 +79,8 @@ def get_prompt_config(self) -> dict[str, str] | None:
7979
prompts = {}
8080
if self.default_prompt:
8181
prompts[TaskTypeEnum.default.value] = self.default_prompt
82-
if self.classifier_prompt:
83-
prompts[TaskTypeEnum.classification.value] = self.classifier_prompt
82+
if self.classification_prompt:
83+
prompts[TaskTypeEnum.classification.value] = self.classification_prompt
8484
if self.cluster_prompt:
8585
prompts[TaskTypeEnum.cluster.value] = self.cluster_prompt
8686
if self.query_prompt:
@@ -91,7 +91,7 @@ def get_prompt_config(self) -> dict[str, str] | None:
9191
prompts[TaskTypeEnum.sts.value] = self.sts_prompt
9292
return prompts if len(prompts) > 0 else None
9393

94-
def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # noqa: PLR0911
94+
def get_prompt(self, prompt_type: TaskTypeEnum | None) -> str | None: # noqa: PLR0911
9595
"""Get the prompt type for the given task type.
9696
9797
Args:
@@ -103,7 +103,7 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # no
103103
if prompt_type is None:
104104
return self.default_prompt
105105
if prompt_type == TaskTypeEnum.classification:
106-
return self.classifier_prompt
106+
return self.classification_prompt
107107
if prompt_type == TaskTypeEnum.cluster:
108108
return self.cluster_prompt
109109
if prompt_type == TaskTypeEnum.query:

docs/optimizer_config.schema.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@
160160
"description": "Default prompt for the model. This is used when no task specific prompt is not provided.",
161161
"title": "Default Prompt"
162162
},
163-
"classifier_prompt": {
163+
"classification_prompt": {
164164
"anyOf": [
165165
{
166166
"type": "string"
@@ -171,7 +171,7 @@
171171
],
172172
"default": null,
173173
"description": "Prompt for classifier.",
174-
"title": "Classifier Prompt"
174+
"title": "Classification Prompt"
175175
},
176176
"cluster_prompt": {
177177
"anyOf": [
@@ -459,7 +459,7 @@
459459
},
460460
"trust_remote_code": false,
461461
"default_prompt": null,
462-
"classifier_prompt": null,
462+
"classification_prompt": null,
463463
"cluster_prompt": null,
464464
"sts_prompt": null,
465465
"query_prompt": null,

docs/optimizer_search_space_config.schema.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@
404404
"description": "Default prompt for the model. This is used when no task specific prompt is not provided.",
405405
"title": "Default Prompt"
406406
},
407-
"classifier_prompt": {
407+
"classification_prompt": {
408408
"anyOf": [
409409
{
410410
"type": "string"

tests/callback/test_callback.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_pipeline_callbacks(dataset):
136136
"module_kwargs": {
137137
"embedder_config": {
138138
"batch_size": 32,
139-
"classifier_prompt": None,
139+
"classification_prompt": None,
140140
"cluster_prompt": None,
141141
"default_prompt": None,
142142
"device": None,
@@ -173,7 +173,7 @@ def test_pipeline_callbacks(dataset):
173173
"module_kwargs": {
174174
"embedder_config": {
175175
"batch_size": 32,
176-
"classifier_prompt": None,
176+
"classification_prompt": None,
177177
"cluster_prompt": None,
178178
"default_prompt": None,
179179
"device": None,
@@ -210,7 +210,7 @@ def test_pipeline_callbacks(dataset):
210210
"module_kwargs": {
211211
"embedder_config": {
212212
"batch_size": 32,
213-
"classifier_prompt": None,
213+
"classification_prompt": None,
214214
"cluster_prompt": None,
215215
"default_prompt": None,
216216
"device": None,

0 commit comments

Comments (0)