Skip to content

Commit 95be22a

Browse files
authored
Fix default prompt (#226)
* fix default prompt * allow to use default prompt with override
1 parent 57eafbb commit 95be22a

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

autointent/configs/_transformers.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Any, Literal
33

44
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
5-
from typing_extensions import Self, assert_never
5+
from typing_extensions import Self
66

77
from autointent.custom_types import FloatFromZeroToOne
88
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
@@ -91,7 +91,7 @@ def get_prompt_config(self) -> dict[str, str] | None:
9191
prompts[TaskTypeEnum.sts.value] = self.sts_prompt
9292
return prompts if len(prompts) > 0 else None
9393

94-
def get_prompt(self, prompt_type: TaskTypeEnum | None) -> str | None: # noqa: PLR0911
94+
def get_prompt(self, prompt_type: TaskTypeEnum | None) -> str | None:
9595
"""Get the prompt type for the given task type.
9696
9797
Args:
@@ -100,21 +100,17 @@ def get_prompt(self, prompt_type: TaskTypeEnum | None) -> str | None: # noqa: P
100100
Returns:
101101
The prompt for the given task type.
102102
"""
103-
if prompt_type is None:
104-
return self.default_prompt
105-
if prompt_type == TaskTypeEnum.classification:
103+
if prompt_type == TaskTypeEnum.classification and self.classification_prompt is not None:
106104
return self.classification_prompt
107-
if prompt_type == TaskTypeEnum.cluster:
105+
if prompt_type == TaskTypeEnum.cluster and self.classification_prompt is not None:
108106
return self.cluster_prompt
109-
if prompt_type == TaskTypeEnum.query:
107+
if prompt_type == TaskTypeEnum.query and self.query_prompt is not None:
110108
return self.query_prompt
111-
if prompt_type == TaskTypeEnum.passage:
109+
if prompt_type == TaskTypeEnum.passage and self.passage_prompt is not None:
112110
return self.passage_prompt
113-
if prompt_type == TaskTypeEnum.sts:
111+
if prompt_type == TaskTypeEnum.sts and self.sts_prompt is not None:
114112
return self.sts_prompt
115-
if prompt_type == TaskTypeEnum.default:
116-
return self.default_prompt
117-
assert_never(prompt_type)
113+
return self.default_prompt
118114

119115

120116
class CrossEncoderConfig(HFModelConfig):

0 commit comments

Comments
 (0)