Skip to content

Commit 03dcbdb

Browse files
committed
allow to use default prompt with override
1 parent 6dab63b commit 03dcbdb

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

autointent/configs/_transformers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Any, Literal
33

44
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
5-
from typing_extensions import Self, assert_never
5+
from typing_extensions import Self
66

77
from autointent.custom_types import FloatFromZeroToOne
88
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
@@ -91,7 +91,7 @@ def get_prompt_config(self) -> dict[str, str] | None:
9191
prompts[TaskTypeEnum.sts.value] = self.sts_prompt
9292
return prompts if len(prompts) > 0 else None
9393

94-
def get_prompt(self, prompt_type: TaskTypeEnum | None) -> str | None: # noqa: PLR0911
94+
def get_prompt(self, prompt_type: TaskTypeEnum | None) -> str | None:
9595
"""Get the prompt type for the given task type.
9696
9797
Args:
@@ -100,15 +100,15 @@ def get_prompt(self, prompt_type: TaskTypeEnum | None) -> str | None: # noqa: P
100100
Returns:
101101
The prompt for the given task type.
102102
"""
103-
if prompt_type == TaskTypeEnum.classification:
103+
if prompt_type == TaskTypeEnum.classification and self.classification_prompt is not None:
104104
return self.classification_prompt
105-
if prompt_type == TaskTypeEnum.cluster:
105+
if prompt_type == TaskTypeEnum.cluster and self.classification_prompt is not None:
106106
return self.cluster_prompt
107-
if prompt_type == TaskTypeEnum.query:
107+
if prompt_type == TaskTypeEnum.query and self.query_prompt is not None:
108108
return self.query_prompt
109-
if prompt_type == TaskTypeEnum.passage:
109+
if prompt_type == TaskTypeEnum.passage and self.passage_prompt is not None:
110110
return self.passage_prompt
111-
if prompt_type == TaskTypeEnum.sts:
111+
if prompt_type == TaskTypeEnum.sts and self.sts_prompt is not None:
112112
return self.sts_prompt
113113
return self.default_prompt
114114

0 commit comments

Comments
 (0)